diff options
-rw-r--r-- | gi/pf/align-lexonly.cc | 14 | ||||
-rw-r--r-- | gi/pf/base_measures.cc | 2 | ||||
-rw-r--r-- | gi/pf/base_measures.h | 27 | ||||
-rw-r--r-- | training/model1.cc | 64 |
4 files changed, 98 insertions, 9 deletions
diff --git a/gi/pf/align-lexonly.cc b/gi/pf/align-lexonly.cc index e9f1e7b6..76e2e009 100644 --- a/gi/pf/align-lexonly.cc +++ b/gi/pf/align-lexonly.cc @@ -122,10 +122,11 @@ struct BasicLexicalAlignment { vector<AlignedSentencePair>* corp) : letters(lets), corpus(*corp), + up0("fr-en.10k.translit-base.txt.gz"), //up0(words_e), //up0("en.chars.1gram", letters_e), //up0("en.words.1gram"), - up0(letters_e), + //up0(letters_e), //up0("en.chars.2gram"), tmodel(up0) { } @@ -180,14 +181,18 @@ struct BasicLexicalAlignment { //PhraseConditionalUninformativeUnigramBase up0; //UnigramWordBase up0; //HierarchicalUnigramBase up0; - HierarchicalWordBase up0; + TableLookupBase up0; + //HierarchicalWordBase up0; + //PoissonUniformUninformativeBase up0; //CompletelyUniformBase up0; //FixedNgramBase up0; //ConditionalTranslationModel<PhraseConditionalUninformativeBase> tmodel; //ConditionalTranslationModel<PhraseConditionalUninformativeUnigramBase> tmodel; //ConditionalTranslationModel<UnigramWordBase> tmodel; //ConditionalTranslationModel<HierarchicalUnigramBase> tmodel; - ConditionalTranslationModel<HierarchicalWordBase> tmodel; + //ConditionalTranslationModel<HierarchicalWordBase> tmodel; + //ConditionalTranslationModel<PoissonUniformUninformativeBase> tmodel; + ConditionalTranslationModel<TableLookupBase> tmodel; //ConditionalTranslationModel<FixedNgramBase> tmodel; //ConditionalTranslationModel<CompletelyUniformBase> tmodel; }; @@ -222,6 +227,7 @@ void BasicLexicalAlignment::ResampleCorpus() { void ExtractLetters(const set<WordID>& v, vector<vector<WordID> >* l, set<WordID>* letset = NULL) { for (set<WordID>::const_iterator it = v.begin(); it != v.end(); ++it) { + if (*it >= l->size()) { l->resize(*it + 1); } vector<WordID>& letters = (*l)[*it]; if (letters.size()) continue; // if e and f have the same word @@ -308,7 +314,7 @@ int main(int argc, char** argv) { x.InitializeRandom(); const unsigned samples = conf["samples"].as<unsigned>(); for (int i = 0; i < samples; ++i) { - for (int j = 4995; j < 4997; ++j) Debug(corpus[j]); + for (int j = 395; j < 397; ++j) Debug(corpus[j]); cerr << i << "\t" << x.tmodel.r.size() << "\t"; if (i % 10 == 0) x.ResampleHyperparemeters(); x.ResampleCorpus(); diff --git a/gi/pf/base_measures.cc b/gi/pf/base_measures.cc index 7894d3e7..4b1863fa 100644 --- a/gi/pf/base_measures.cc +++ b/gi/pf/base_measures.cc @@ -37,7 +37,7 @@ TableLookupBase::TableLookupBase(const string& fname) { } else if (cc == 1) { x.e_.push_back(cur); } else if (cc == 2) { - table[x] = atof(TD::Convert(cur)); + table[x].logeq(atof(TD::Convert(cur))); ++cc; } else { if (flag) cerr << endl; diff --git a/gi/pf/base_measures.h b/gi/pf/base_measures.h index 7214aa22..b0495bfd 100644 --- a/gi/pf/base_measures.h +++ b/gi/pf/base_measures.h @@ -51,6 +51,22 @@ struct Model1 { std::vector<std::map<WordID, prob_t> > ttable; }; +struct PoissonUniformUninformativeBase { + explicit PoissonUniformUninformativeBase(const unsigned ves) : kUNIFORM(1.0 / ves) {} + prob_t operator()(const TRule& r) const { + prob_t p; p.logeq(log_poisson(r.e_.size(), 1.0)); + prob_t q = kUNIFORM; q.poweq(r.e_.size()); + p *= q; + return p; + } + void Summary() const {} + void ResampleHyperparameters(MT19937*) {} + void Increment(const TRule&) {} + void Decrement(const TRule&) {} + prob_t Likelihood() const { return prob_t::One(); } + const prob_t kUNIFORM; +}; + struct CompletelyUniformBase { explicit CompletelyUniformBase(const unsigned ves) : kUNIFORM(1.0 / ves) {} prob_t operator()(const TRule&) const { @@ -83,10 +99,19 @@ struct TableLookupBase { prob_t operator()(const TRule& rule) const { const std::tr1::unordered_map<TRule,prob_t>::const_iterator it = table.find(rule); - assert(it != table.end()); + if (it == table.end()) { + std::cerr << rule << " not found\n"; + abort(); + } return it->second; } + void ResampleHyperparameters(MT19937*) {} + void Increment(const TRule&) {} + void Decrement(const TRule&) {} + prob_t Likelihood() const { return prob_t::One(); } + void Summary() const {} + std::tr1::unordered_map<TRule,prob_t,RuleHasher> table; }; diff --git a/training/model1.cc b/training/model1.cc index 346c0033..40249aa3 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -14,6 +14,11 @@ namespace po = boost::program_options; using namespace std; +inline double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -25,6 +30,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("diagonal_tension,T", po::value<double>()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)") ("prob_align_null", po::value<double>()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?") ("variational_bayes,v","Add a symmetric Dirichlet prior and infer VB estimate of weights") + ("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model") ("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior") ("no_add_viterbi,V","Do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)"); po::options_description clo("Command line options"); @@ -63,6 +69,8 @@ int main(int argc, char** argv) { const bool write_alignments = (conf.count("write_alignments") > 0); const double diagonal_tension = conf["diagonal_tension"].as<double>(); const double prob_align_null = conf["prob_align_null"].as<double>(); + string testset; + if (conf.count("testset")) testset = conf["testset"].as<string>(); const double prob_align_not_null = 1.0 - prob_align_null; const double alpha = conf["alpha"].as<double>(); const bool favor_diagonal = conf.count("favor_diagonal"); @@ -73,6 +81,8 @@ int main(int argc, char** argv) { TTable tt; TTable::Word2Word2Double was_viterbi; + double tot_len_ratio = 0; + double mean_srclen_multiplier = 0; for (int iter = 0; iter < ITERATIONS; ++iter) { const bool final_iteration = (iter == (ITERATIONS - 1)); cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; @@ -83,13 +93,13 @@ int main(int argc, char** argv) { int lc = 0; bool flag = false; string line; + string ssrc, strg; while(true) { getline(in, line); if (!in) break; ++lc; if (lc % 1000 == 0) { cerr << '.'; flag = true; } if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; } - string ssrc, strg; ParseTranslatorInput(line, &ssrc, &strg); Lattice src, trg; LatticeTools::ConvertTextToLattice(ssrc, &src); @@ -99,9 +109,10 @@ int main(int argc, char** argv) { assert(src.size() > 0); assert(trg.size() > 0); } + if (iter == 0) + tot_len_ratio += static_cast<double>(trg.size()) / static_cast<double>(src.size()); denom += trg.size(); vector<double> probs(src.size() + 1); - const double src_logprob = -log(src.size() + 1); bool first_al = true; // used for write_alignments for (int j = 0; j < trg.size(); ++j) { const WordID& f_j = trg[j][0].label; @@ -156,7 +167,7 @@ int main(int argc, char** argv) { for (int i = 1; i <= src.size(); ++i) tt.Increment(src[i-1][0].label, f_j, probs[i] / sum); } - likelihood += log(sum) + src_logprob; + likelihood += log(sum); } if (write_alignments && final_iteration) cout << endl; } @@ -165,6 +176,10 @@ int main(int argc, char** argv) { double base2_likelihood = likelihood / log(2); if (flag) { cerr << endl; } + if (iter == 0) { + mean_srclen_multiplier = tot_len_ratio / lc; + cerr << "expected target length = source length * " << mean_srclen_multiplier << endl; + } cerr << " log_e likelihood: " << likelihood << endl; cerr << " log_2 likelihood: " << base2_likelihood << endl; cerr << " cross entropy: " << (-base2_likelihood / denom) << endl; @@ -176,6 +191,49 @@ int main(int argc, char** argv) { tt.Normalize(); } } + if (testset.size()) { + ReadFile rf(testset); + istream& in = *rf.stream(); + int lc = 0; + double tlp = 0; + string ssrc, strg, line; + while (getline(in, line)) { + ++lc; + ParseTranslatorInput(line, &ssrc, &strg); + Lattice src, trg; + LatticeTools::ConvertTextToLattice(ssrc, &src); + LatticeTools::ConvertTextToLattice(strg, &trg); + double log_prob = log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier); + + // compute likelihood + for (int j = 0; j < trg.size(); ++j) { + const WordID& f_j = trg[j][0].label; + double sum = 0; + const double j_over_ts = double(j) / trg.size(); + double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1) + if (use_null) { + if (favor_diagonal) prob_a_i = prob_align_null; + sum += tt.prob(kNULL, f_j) * prob_a_i; + } + double az = 0; + if (favor_diagonal) { + for (int ta = 0; ta < src.size(); ++ta) + az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); + az /= prob_align_not_null; + } + for (int i = 1; i <= src.size(); ++i) { + if (favor_diagonal) + prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az; + sum += tt.prob(src[i-1][0].label, f_j) * prob_a_i; + } + log_prob += log(sum); + } + tlp += log_prob; + cerr << ssrc << " ||| " << strg << " ||| " << log_prob << endl; + } + cerr << "TOTAL LOG PROB " << tlp << endl; + } + if (write_alignments) return 0; for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) { |