From b355e5e8a2a17ae6d8f241d8f6e19e4b4c528397 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 7 Mar 2012 20:25:53 -0500 Subject: lattice builder --- gi/pf/Makefile.am | 7 +- gi/pf/align-tl.cc | 334 ++++++++++++++++++++++++++++++++++++++++++++++ gi/pf/conditional_pseg.h | 11 +- gi/pf/nuisance_test.cc | 161 ++++++++++++++++++++++ gi/pf/transliterations.cc | 193 +++++++++++++++++++++++++++ gi/pf/transliterations.h | 20 +++ 6 files changed, 723 insertions(+), 3 deletions(-) create mode 100644 gi/pf/align-tl.cc create mode 100644 gi/pf/nuisance_test.cc create mode 100644 gi/pf/transliterations.cc create mode 100644 gi/pf/transliterations.h (limited to 'gi') diff --git a/gi/pf/Makefile.am b/gi/pf/Makefile.am index 7cf9c14d..5e89f02a 100644 --- a/gi/pf/Makefile.am +++ b/gi/pf/Makefile.am @@ -1,12 +1,17 @@ -bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive condnaive align-lexonly align-lexonly-pyp learn_cfg pyp_lm +bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive condnaive align-lexonly align-lexonly-pyp learn_cfg pyp_lm nuisance_test align-tl noinst_LIBRARIES = libpf.a + libpf_a_SOURCES = base_distributions.cc reachability.cc cfg_wfst_composer.cc corpus.cc unigrams.cc ngram_base.cc +nuisance_test_SOURCES = nuisance_test.cc transliterations.cc + align_lexonly_SOURCES = align-lexonly.cc align_lexonly_pyp_SOURCES = align-lexonly-pyp.cc +align_tl_SOURCES = align-tl.cc transliterations.cc + itg_SOURCES = itg.cc pyp_lm_SOURCES = pyp_lm.cc diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc new file mode 100644 index 00000000..0e0454e5 --- /dev/null +++ b/gi/pf/align-tl.cc @@ -0,0 +1,334 @@ +#include +#include +#include + +#include +#include +#include + +#include "array2d.h" +#include "base_distributions.h" +#include "monotonic_pseg.h" +#include "conditional_pseg.h" +#include "trule.h" +#include "tdict.h" +#include "stringlib.h" +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "mfcr.h" +#include "corpus.h" +#include "ngram_base.h" +#include "transliterations.h" + +using namespace std; +using namespace tr1; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("input,i",po::value(),"Read parallel data from") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +shared_ptr prng; + +struct LexicalAlignment { + unsigned char src_index; + bool is_transliteration; + vector > derivation; +}; + +struct AlignedSentencePair { + vector src; + vector trg; + vector a; + Array2D posterior; +}; + +struct HierarchicalWordBase { + explicit HierarchicalWordBase(const unsigned vocab_e_size) : + base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-log(vocab_e_size)), l(1,prob_t::One()), v(1, prob_t::Zero()) {} + + void ResampleHyperparameters(MT19937* rng) { + r.resample_hyperparameters(rng); + } + + inline double logp0(const vector& s) const { + return Md::log_poisson(s.size(), 7.5) + s.size() * u0; + } + + // return p0 of rule.e_ + prob_t operator()(const TRule& rule) const { + v[0].logeq(logp0(rule.e_)); + return r.prob(rule.e_, v.begin(), l.begin()); + } + + void Increment(const TRule& rule) { + v[0].logeq(logp0(rule.e_)); + if (r.increment(rule.e_, v.begin(), l.begin(), &*prng).count) { + base *= v[0] * l[0]; + } + } + + void Decrement(const TRule& rule) { + if (r.decrement(rule.e_, &*prng).count) { + base /= prob_t(exp(logp0(rule.e_))); + } + } + + prob_t Likelihood() const { + prob_t p; p.logeq(r.log_crp_prob()); + p *= base; + return p; + } + + void Summary() const { + cerr << "NUMBER OF CUSTOMERS: " << r.num_customers() << " (d=" << r.discount() << ",s=" << r.strength() << ')' << endl; + for (MFCR<1,vector >::const_iterator it = r.begin(); it != r.end(); ++it) + cerr << " " << it->second.total_dish_count_ << " (on " << it->second.table_counts_.size() << " tables) " << TD::GetString(it->first) << endl; + } + + prob_t base; + MFCR<1,vector > r; + const double u0; + const vector l; + mutable vector v; +}; + +struct BasicLexicalAlignment { + explicit BasicLexicalAlignment(const vector >& lets, + const unsigned words_e, + const unsigned letters_e, + vector* corp) : + letters(lets), + corpus(*corp), + //up0(words_e), + //up0("en.chars.1gram", letters_e), + //up0("en.words.1gram"), + up0(letters_e), + //up0("en.chars.2gram"), + tmodel(up0) { + } + + void InstantiateRule(const WordID src, + const WordID trg, + TRule* rule) const { + static const WordID kX = TD::Convert("X") * -1; + rule->lhs_ = kX; + rule->e_ = letters[trg]; + rule->f_ = letters[src]; + } + + void InitializeRandom() { + const WordID kNULL = TD::Convert("NULL"); + cerr << "Initializing with random alignments ...\n"; + for (unsigned i = 0; i < corpus.size(); ++i) { + AlignedSentencePair& asp = corpus[i]; + asp.a.resize(asp.trg.size()); + for (unsigned j = 0; j < asp.trg.size(); ++j) { + const unsigned char a_j = prng->next() * (1 + asp.src.size()); + const WordID f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); + TRule r; + InstantiateRule(f_a_j, asp.trg[j], &r); + asp.a[j].is_transliteration = false; + asp.a[j].src_index = a_j; + if (tmodel.IncrementRule(r, &*prng)) + up0.Increment(r); + } + } + cerr << " LLH = " << Likelihood() << endl; + } + + prob_t Likelihood() const { + prob_t p = tmodel.Likelihood(); + p *= up0.Likelihood(); + return p; + } + + void ResampleHyperparemeters() { + tmodel.ResampleHyperparameters(&*prng); + up0.ResampleHyperparameters(&*prng); + cerr << " (base d=" << up0.r.discount() << ",s=" << up0.r.strength() << ")\n"; + } + + void ResampleCorpus(); + + const vector >& letters; // spelling dictionary + vector& corpus; + //PhraseConditionalUninformativeBase up0; + //PhraseConditionalUninformativeUnigramBase up0; + //UnigramWordBase up0; + //HierarchicalUnigramBase up0; + HierarchicalWordBase up0; + //CompletelyUniformBase up0; + //FixedNgramBase up0; + //ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + MConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; +}; + +void BasicLexicalAlignment::ResampleCorpus() { + static const WordID kNULL = TD::Convert("NULL"); + for (unsigned i = 0; i < corpus.size(); ++i) { + AlignedSentencePair& asp = corpus[i]; + SampleSet ss; ss.resize(asp.src.size() + 1); + for (unsigned j = 0; j < asp.trg.size(); ++j) { + TRule r; + unsigned char& a_j = asp.a[j].src_index; + WordID f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); + InstantiateRule(f_a_j, asp.trg[j], &r); + if (tmodel.DecrementRule(r, &*prng)) + up0.Decrement(r); + + for (unsigned prop_a_j = 0; prop_a_j <= asp.src.size(); ++prop_a_j) { + const WordID prop_f = (prop_a_j ? asp.src[prop_a_j - 1] : kNULL); + InstantiateRule(prop_f, asp.trg[j], &r); + ss[prop_a_j] = tmodel.RuleProbability(r); + } + a_j = prng->SelectSample(ss); + f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); + InstantiateRule(f_a_j, asp.trg[j], &r); + if (tmodel.IncrementRule(r, &*prng)) + up0.Increment(r); + } + } + cerr << " LLH = " << Likelihood() << endl; +} + +void ExtractLetters(const set& v, vector >* l, set* letset = NULL) { + for (set::const_iterator it = v.begin(); it != v.end(); ++it) { + vector& letters = (*l)[*it]; + if (letters.size()) continue; // if e and f have the same word + + const string& w = TD::Convert(*it); + + size_t cur = 0; + while (cur < w.size()) { + const size_t len = UTF8Len(w[cur]); + letters.push_back(TD::Convert(w.substr(cur, len))); + if (letset) letset->insert(letters.back()); + cur += len; + } + } +} + +void Debug(const AlignedSentencePair& asp) { + cerr << TD::GetString(asp.src) << endl << TD::GetString(asp.trg) << endl; + Array2D a(asp.src.size(), asp.trg.size()); + for (unsigned j = 0; j < asp.trg.size(); ++j) + if (asp.a[j].src_index) a(asp.a[j].src_index - 1, j) = true; + cerr << a << endl; +} + +void AddSample(AlignedSentencePair* asp) { + for (unsigned j = 0; j < asp->trg.size(); ++j) + asp->posterior(asp->a[j].src_index, j)++; +} + +void WriteAlignments(const AlignedSentencePair& asp) { + bool first = true; + for (unsigned j = 0; j < asp.trg.size(); ++j) { + int src_index = -1; + int mc = -1; + for (unsigned i = 0; i <= asp.src.size(); ++i) { + if (asp.posterior(i, j) > mc) { + mc = asp.posterior(i, j); + src_index = i; + } + } + + if (src_index) { + if (first) first = false; else cout << ' '; + cout << (src_index - 1) << '-' << j; + } + } + cout << endl; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); +// MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + corpus::ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "f-Corpus size: " << corpusf.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabf.size() << " types\n"; + cerr << "f-Corpus size: " << corpuse.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabe.size() << " types\n"; + assert(corpusf.size() == corpuse.size()); + + vector corpus(corpuse.size()); + for (unsigned i = 0; i < corpuse.size(); ++i) { + corpus[i].src.swap(corpusf[i]); + corpus[i].trg.swap(corpuse[i]); + corpus[i].posterior.resize(corpus[i].src.size() + 1, corpus[i].trg.size()); + } + corpusf.clear(); corpuse.clear(); + + vocabf.insert(TD::Convert("NULL")); + vector > letters(TD::NumWords()); + set letset; + ExtractLetters(vocabe, &letters, &letset); + ExtractLetters(vocabf, &letters, NULL); + letters[TD::Convert("NULL")].clear(); + + Transliterations tl; + + // TODO CONFIGURE THIS + int min_trans_src = 4; + + cerr << "Initializing transliteration DPs ...\n"; + for (int i = 0; i < corpus.size(); ++i) { + const vector& src = corpus[i].src; + const vector& trg = corpus[i].trg; + cerr << '.' << flush; + if (i % 80 == 79) cerr << endl; + for (int j = 0; j < src.size(); ++j) { + const vector& src_let = letters[src[j]]; + for (int k = 0; k < trg.size(); ++k) { + const vector& trg_let = letters[trg[k]]; + if (src_let.size() < min_trans_src) + tl.Forbid(src[j], trg[k]); + else + tl.Initialize(src[j], src_let, trg[k], trg_let); + } + } + } + cerr << endl; + tl.GraphSummary(); + + return 0; +} diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h index 8202778b..81ddb206 100644 --- a/gi/pf/conditional_pseg.h +++ b/gi/pf/conditional_pseg.h @@ -56,6 +56,12 @@ struct MConditionalTranslationModel { }; void ResampleHyperparameters(MT19937* rng) { + typename std::tr1::unordered_map, MFCR<1,TRule>, boost::hash > >::iterator it; +#if 1 + for (it = r.begin(); it != r.end(); ++it) { + it->second.resample_hyperparameters(rng); + } +#else const unsigned nloop = 5; const unsigned niterations = 10; DiscountResampler dr(*this); @@ -70,12 +76,12 @@ struct MConditionalTranslationModel { } strength = slice_sampler1d(ar, strength, *rng, -d, std::numeric_limits::infinity(), 0.0, niterations, 100*niterations); - typename std::tr1::unordered_map, MFCR<1,TRule>, boost::hash > >::iterator it; std::cerr << "MConditionalTranslationModel(d=" << d << ",s=" << strength << ") = " << log_likelihood(d, strength) << std::endl; for (it = r.begin(); it != r.end(); ++it) { it->second.set_discount(d); it->second.set_strength(strength); } +#endif } int DecrementRule(const TRule& rule, MT19937* rng) { @@ -91,7 +97,8 @@ struct MConditionalTranslationModel { int IncrementRule(const TRule& rule, MT19937* rng) { RuleModelHash::iterator it = r.find(rule.f_); if (it == r.end()) { - it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(d, strength))).first; + //it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(d, strength))).first; + it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(1,1,1,1,0.6, -0.12))).first; } p0s[0] = rp0(rule); TableCount delta = it->second.increment(rule, p0s.begin(), lambdas.begin(), rng); diff --git a/gi/pf/nuisance_test.cc b/gi/pf/nuisance_test.cc new file mode 100644 index 00000000..0f44fe95 --- /dev/null +++ b/gi/pf/nuisance_test.cc @@ -0,0 +1,161 @@ +#include "ccrp.h" + +#include +#include + +#include "tdict.h" +#include "transliterations.h" + +using namespace std; + +MT19937 rng; + +ostream& operator<<(ostream&os, const vector& v) { + os << '[' << v[0]; + if (v.size() == 2) os << ' ' << v[1]; + return os << ']'; +} + +struct Base { + Base() : llh(), v(2), v1(1), v2(1), crp(0.25, 0.5) {} + inline double p0(const vector& x) const { + double p = 0.75; + if (x.size() == 2) p = 0.25; + p *= 1.0 / 3.0; + if (x.size() == 2) p *= 1.0 / 3.0; + return p; + } + double est_deriv_prob(int a, int b, int seg) const { + assert(a > 0 && a < 4); // a \in {1,2,3} + assert(b > 0 && b < 4); // b \in {1,2,3} + assert(seg == 0 || seg == 1); // seg \in {0,1} + if (seg == 0) { + v[0] = a; + v[1] = b; + return crp.prob(v, p0(v)); + } else { + v1[0] = a; + v2[0] = b; + return crp.prob(v1, p0(v1)) * crp.prob(v2, p0(v2)); + } + } + double est_marginal_prob(int a, int b) const { + return est_deriv_prob(a,b,0) + est_deriv_prob(a,b,1); + } + int increment(int a, int b, double* pw = NULL) { + double p1 = est_deriv_prob(a, b, 0); + double p2 = est_deriv_prob(a, b, 1); + //p1 = 0.5; p2 = 0.5; + int seg = rng.SelectSample(p1,p2); + double tmp = 0; + if (!pw) pw = &tmp; + double& w = *pw; + if (seg == 0) { + v[0] = a; + v[1] = b; + w = crp.prob(v, p0(v)) / p1; + if (crp.increment(v, p0(v), &rng)) { + llh += log(p0(v)); + } + } else { + v1[0] = a; + w = crp.prob(v1, p0(v1)) / p2; + if (crp.increment(v1, p0(v1), &rng)) { + llh += log(p0(v1)); + } + v2[0] = b; + w *= crp.prob(v2, p0(v2)); + if (crp.increment(v2, p0(v2), &rng)) { + llh += log(p0(v2)); + } + } + return seg; + } + void increment(int a, int b, int seg) { + if (seg == 0) { + v[0] = a; + v[1] = b; + if (crp.increment(v, p0(v), &rng)) { + llh += log(p0(v)); + } + } else { + v1[0] = a; + if (crp.increment(v1, p0(v1), &rng)) { + llh += log(p0(v1)); + } + v2[0] = b; + if (crp.increment(v2, p0(v2), &rng)) { + llh += log(p0(v2)); + } + } + } + void decrement(int a, int b, int seg) { + if (seg == 0) { + v[0] = a; + v[1] = b; + if (crp.decrement(v, &rng)) { + llh -= log(p0(v)); + } + } else { + v1[0] = a; + if (crp.decrement(v1, &rng)) { + llh -= log(p0(v1)); + } + v2[0] = b; + if (crp.decrement(v2, &rng)) { + llh -= log(p0(v2)); + } + } + } + double log_likelihood() const { + return llh + crp.log_crp_prob(); + } + double llh; + mutable vector v, v1, v2; + CCRP > crp; +}; + +int main(int argc, char** argv) { + double tl = 0; + const int ITERS = 1000; + const int PARTICLES = 20; + const int DATAPOINTS = 50; + WordID x = TD::Convert("souvenons"); + WordID y = TD::Convert("remember"); + vector src; TD::ConvertSentence("s o u v e n o n s", &src); + vector trg; TD::ConvertSentence("r e m e m b e r", &trg); + Transliterations xx; + xx.Initialize(x, src, y, trg); + return 1; + + for (int j = 0; j < ITERS; ++j) { + Base b; + vector segs(DATAPOINTS); + SampleSet ss; + vector sss; + for (int i = 0; i < DATAPOINTS; i++) { + ss.clear(); + sss.clear(); + int x = ((i / 10) % 3) + 1; + int y = (i % 3) + 1; + //double ep = b.est_marginal_prob(x,y); + //cerr << "est p(" << x << "," << y << ") = " << ep << endl; + for (int n = 0; n < PARTICLES; ++n) { + double w; + int seg = b.increment(x,y,&w); + //cerr << seg << " w=" << w << endl; + ss.add(w); + sss.push_back(seg); + b.decrement(x,y,seg); + } + int seg = sss[rng.SelectSample(ss)]; + b.increment(x, y, seg); + //cerr << "Selected: " << seg << endl; + //return 1; + segs[i] = seg; + } + tl += b.log_likelihood(); + } + cerr << "LLH=" << tl / ITERS << endl; +} + diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc new file mode 100644 index 00000000..6e0c2e93 --- /dev/null +++ b/gi/pf/transliterations.cc @@ -0,0 +1,193 @@ +#include "transliterations.h" + +#include +#include +#include + +#include "grammar.h" +#include "bottom_up_parser.h" +#include "hg.h" +#include "hg_intersect.h" +#include "filelib.h" +#include "ccrp.h" +#include "m.h" +#include "lattice.h" +#include "verbose.h" + +using namespace std; +using namespace std::tr1; + +static WordID kX; +static int kMAX_SRC_SIZE = 0; +static vector > cur_trg_chunks; + +vector tlttofreelist; + +static void InitTargetChunks(int max_size, const vector& trg) { + cur_trg_chunks.clear(); + vector tmp; + unordered_set, boost::hash > > u; + for (int len = 1; len <= max_size; ++len) { + int end = trg.size() + 1; + end -= len; + for (int i = 0; i < end; ++i) { + tmp.clear(); + for (int j = 0; j < len; ++j) + tmp.push_back(trg[i + j]); + if (u.insert(tmp).second) cur_trg_chunks.push_back(tmp); + } + } +} + +struct TransliterationGrammarIter : public GrammarIter, public RuleBin { + TransliterationGrammarIter() { tlttofreelist.push_back(this); } + TransliterationGrammarIter(const TRulePtr& inr, int symbol) { + if (inr) { + r.reset(new TRule(*inr)); + } else { + r.reset(new TRule); + } + TRule& rr = *r; + rr.lhs_ = kX; + rr.f_.push_back(symbol); + tlttofreelist.push_back(this); + } + virtual int GetNumRules() const { + if (!r) return 0; + return cur_trg_chunks.size(); + } + virtual TRulePtr GetIthRule(int i) const { + TRulePtr nr(new TRule(*r)); + nr->e_ = cur_trg_chunks[i]; + //cerr << nr->AsString() << endl; + return nr; + } + virtual int Arity() const { + return 0; + } + virtual const RuleBin* GetRules() const { + if (!r) return NULL; else return this; + } + virtual const GrammarIter* Extend(int symbol) const { + if (symbol <= 0) return NULL; + if (!r || !kMAX_SRC_SIZE || r->f_.size() < kMAX_SRC_SIZE) + return new TransliterationGrammarIter(r, symbol); + else + return NULL; + } + TRulePtr r; +}; + +struct TransliterationGrammar : public Grammar { + virtual const GrammarIter* GetRoot() const { + return new TransliterationGrammarIter; + } + virtual bool HasRuleForSpan(int, int, int distance) const { + return (distance < kMAX_SRC_SIZE); + } +}; + +struct TInfo { + TInfo() : initialized(false) {} + bool initialized; + Hypergraph lattice; // may be empty if transliteration is not possible + prob_t est_prob; // will be zero if not possible +}; + +struct TransliterationsImpl { + TransliterationsImpl() { + kX = TD::Convert("X")*-1; + kMAX_SRC_SIZE = 4; + grammars.push_back(GrammarPtr(new TransliterationGrammar)); + grammars.push_back(GrammarPtr(new GlueGrammar("S", "X"))); + SetSilent(true); + } + + void Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { + if (src >= graphs.size()) graphs.resize(src + 1); + if (graphs[src][trg].initialized) return; + int kMAX_TRG_SIZE = 4; + InitTargetChunks(kMAX_TRG_SIZE, trg_lets); + ExhaustiveBottomUpParser parser("S", grammars); + Lattice lat(src_lets.size()), tlat(trg_lets.size()); + for (unsigned i = 0; i < src_lets.size(); ++i) + lat[i].push_back(LatticeArc(src_lets[i], 0.0, 1)); + for (unsigned i = 0; i < trg_lets.size(); ++i) + tlat[i].push_back(LatticeArc(trg_lets[i], 0.0, 1)); + //cerr << "Creating lattice for: " << TD::Convert(src) << " --> " << TD::Convert(trg) << endl; + //cerr << "'" << TD::GetString(src_lets) << "' --> " << TD::GetString(trg_lets) << endl; + if (!parser.Parse(lat, &graphs[src][trg].lattice)) { + //cerr << "Failed to parse " << TD::GetString(src_lets) << endl; + abort(); + } + if (HG::Intersect(tlat, &graphs[src][trg].lattice)) { + graphs[src][trg].est_prob = prob_t(1e-4); + } else { + graphs[src][trg].lattice.clear(); + //cerr << "Failed to intersect " << TD::GetString(src_lets) << " ||| " << TD::GetString(trg_lets) << endl; + graphs[src][trg].est_prob = prob_t::Zero(); + } + for (unsigned i = 0; i < tlttofreelist.size(); ++i) + delete tlttofreelist[i]; + tlttofreelist.clear(); + //cerr << "Number of paths: " << graphs[src][trg].lattice.NumberOfPaths() << endl; + graphs[src][trg].initialized = true; + } + + const prob_t& EstimateProbability(WordID src, WordID trg) const { + assert(src < graphs.size()); + const unordered_map& um = graphs[src]; + const unordered_map::const_iterator it = um.find(trg); + assert(it != um.end()); + assert(it->second.initialized); + return it->second.est_prob; + } + + void Forbid(WordID src, WordID trg) { + if (src >= graphs.size()) graphs.resize(src + 1); + graphs[src][trg].est_prob = prob_t::Zero(); + graphs[src][trg].initialized = true; + } + + void GraphSummary() const { + double tlp = 0; + int tt = 0; + for (int i = 0; i < graphs.size(); ++i) { + const unordered_map& um = graphs[i]; + unordered_map::const_iterator it; + for (it = um.begin(); it != um.end(); ++it) { + if (it->second.lattice.empty()) continue; + //cerr << TD::Convert(i) << " --> " << TD::Convert(it->first) << ": " << it->second.lattice.NumberOfPaths() << endl; + tlp += log(it->second.lattice.NumberOfPaths()); + tt++; + } + } + tlp /= tt; + cerr << "E[log paths] = " << tlp << endl; + cerr << "exp(E[log paths]) = " << exp(tlp) << endl; + } + + vector > graphs; + vector grammars; +}; + +Transliterations::Transliterations() : pimpl_(new TransliterationsImpl) {} +Transliterations::~Transliterations() { delete pimpl_; } + +void Transliterations::Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { + pimpl_->Initialize(src, src_lets, trg, trg_lets); +} + +prob_t Transliterations::EstimateProbability(WordID src, WordID trg) const { + return pimpl_->EstimateProbability(src,trg); +} + +void Transliterations::Forbid(WordID src, WordID trg) { + pimpl_->Forbid(src, trg); +} + +void Transliterations::GraphSummary() const { + pimpl_->GraphSummary(); +} + + diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h new file mode 100644 index 00000000..a548aacf --- /dev/null +++ b/gi/pf/transliterations.h @@ -0,0 +1,20 @@ +#ifndef _TRANSLITERATIONS_H_ +#define _TRANSLITERATIONS_H_ + +#include +#include "wordid.h" +#include "prob.h" + +struct TransliterationsImpl; +struct Transliterations { + explicit Transliterations(); + ~Transliterations(); + void Initialize(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); + void Forbid(WordID src, WordID trg); + void GraphSummary() const; + prob_t EstimateProbability(WordID src, WordID trg) const; + TransliterationsImpl* pimpl_; +}; + +#endif + -- cgit v1.2.3