From f0bdd4de6455855d705d9056deb2e90c999dc740 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 20 Jan 2012 15:35:47 -0500 Subject: 'pseudo model 2' that strictly favors a diagonal, with tunable parameters for p(null) and how sharp/flat the alignment distribution is around the diagonal --- training/model1.cc | 39 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) (limited to 'training/model1.cc') diff --git a/training/model1.cc b/training/model1.cc index b9590ece..346c0033 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -20,6 +20,10 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("iterations,i",po::value()->default_value(5),"Number of iterations of EM training") ("beam_threshold,t",po::value()->default_value(-4),"log_10 of beam threshold (-10000 to include everything, 0 max)") ("no_null_word,N","Do not generate from the null token") + ("write_alignments,A", "Write alignments instead of parameters") + ("favor_diagonal,d", "Use a static alignment distribution that assigns higher probabilities to alignments near the diagonal") + ("diagonal_tension,T", po::value()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)") + ("prob_align_null", po::value()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?") ("variational_bayes,v","Add a symmetric Dirichlet prior and infer VB estimate of weights") ("alpha,a", po::value()->default_value(0.01), "Hyperparameter for optional Dirichlet prior") ("no_add_viterbi,V","Do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)"); @@ -56,7 +60,12 @@ int main(int argc, char** argv) { const WordID kNULL = TD::Convert(""); const bool add_viterbi = (conf.count("no_add_viterbi") == 0); const bool variational_bayes = (conf.count("variational_bayes") > 0); + const bool write_alignments = (conf.count("write_alignments") > 0); + const double diagonal_tension = conf["diagonal_tension"].as(); + const double prob_align_null = conf["prob_align_null"].as(); + const double prob_align_not_null = 1.0 - prob_align_null; const double alpha = conf["alpha"].as(); + const bool favor_diagonal = conf.count("favor_diagonal"); if (variational_bayes && alpha <= 0.0) { cerr << "--alpha must be > 0\n"; return 1; @@ -93,31 +102,52 @@ int main(int argc, char** argv) { denom += trg.size(); vector probs(src.size() + 1); const double src_logprob = -log(src.size() + 1); + bool first_al = true; // used for write_alignments for (int j = 0; j < trg.size(); ++j) { const WordID& f_j = trg[j][0].label; double sum = 0; + const double j_over_ts = double(j) / trg.size(); + double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1) if (use_null) { - probs[0] = tt.prob(kNULL, f_j); + if (favor_diagonal) prob_a_i = prob_align_null; + probs[0] = tt.prob(kNULL, f_j) * prob_a_i; sum += probs[0]; } + double az = 0; + if (favor_diagonal) { + for (int ta = 0; ta < src.size(); ++ta) + az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); + az /= prob_align_not_null; + } for (int i = 1; i <= src.size(); ++i) { - probs[i] = tt.prob(src[i-1][0].label, f_j); + if (favor_diagonal) + prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az; + probs[i] = tt.prob(src[i-1][0].label, f_j) * prob_a_i; sum += probs[i]; } if (final_iteration) { - if (add_viterbi) { + if (add_viterbi || write_alignments) { WordID max_i = 0; double max_p = -1; + int max_index = -1; if (use_null) { max_i = kNULL; + max_index = 0; max_p = probs[0]; } for (int i = 1; i <= src.size(); ++i) { if (probs[i] > max_p) { + max_index = i; max_p = probs[i]; max_i = src[i-1][0].label; } } + if (write_alignments) { + if (max_index > 0) { + if (first_al) first_al = false; else cout << ' '; + cout << (max_index - 1) << "-" << j; + } + } was_viterbi[max_i][f_j] = 1.0; } } else { @@ -128,6 +158,7 @@ int main(int argc, char** argv) { } likelihood += log(sum) + src_logprob; } + if (write_alignments && final_iteration) cout << endl; } // log(e) = 1.0 @@ -145,6 +176,8 @@ int main(int argc, char** argv) { tt.Normalize(); } } + if (write_alignments) return 0; + for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) { const TTable::Word2Double& cpd = ei->second; const TTable::Word2Double& vit = was_viterbi[ei->first]; -- cgit v1.2.3 From 4c2360119def2fb624d2691b355b1908c511f004 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 24 Jan 2012 22:26:44 -0500 Subject: more models --- gi/pf/align-lexonly.cc | 14 +++++++---- gi/pf/base_measures.cc | 2 +- gi/pf/base_measures.h | 27 ++++++++++++++++++++- training/model1.cc | 64 +++++++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 98 insertions(+), 9 deletions(-) (limited to 'training/model1.cc') diff --git a/gi/pf/align-lexonly.cc b/gi/pf/align-lexonly.cc index e9f1e7b6..76e2e009 100644 --- a/gi/pf/align-lexonly.cc +++ b/gi/pf/align-lexonly.cc @@ -122,10 +122,11 @@ struct BasicLexicalAlignment { vector* corp) : letters(lets), corpus(*corp), + up0("fr-en.10k.translit-base.txt.gz"), //up0(words_e), //up0("en.chars.1gram", letters_e), //up0("en.words.1gram"), - up0(letters_e), + //up0(letters_e), //up0("en.chars.2gram"), tmodel(up0) { } @@ -180,14 +181,18 @@ struct BasicLexicalAlignment { //PhraseConditionalUninformativeUnigramBase up0; //UnigramWordBase up0; //HierarchicalUnigramBase up0; - HierarchicalWordBase up0; + TableLookupBase up0; + //HierarchicalWordBase up0; + //PoissonUniformUninformativeBase up0; //CompletelyUniformBase up0; //FixedNgramBase up0; //ConditionalTranslationModel tmodel; //ConditionalTranslationModel tmodel; //ConditionalTranslationModel tmodel; //ConditionalTranslationModel tmodel; - ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + ConditionalTranslationModel tmodel; //ConditionalTranslationModel tmodel; //ConditionalTranslationModel tmodel; }; @@ -222,6 +227,7 @@ void BasicLexicalAlignment::ResampleCorpus() { void ExtractLetters(const set& v, vector >* l, set* letset = NULL) { for (set::const_iterator it = v.begin(); it != v.end(); ++it) { + if (*it >= l->size()) { l->resize(*it + 1); } vector& letters = (*l)[*it]; if (letters.size()) continue; // if e and f have the same word @@ -308,7 +314,7 @@ int main(int argc, char** argv) { x.InitializeRandom(); const unsigned samples = conf["samples"].as(); for (int i = 0; i < samples; ++i) { - for (int j = 4995; j < 4997; ++j) Debug(corpus[j]); + for (int j = 395; j < 397; ++j) Debug(corpus[j]); cerr << i << "\t" << x.tmodel.r.size() << "\t"; if (i % 10 == 0) x.ResampleHyperparemeters(); x.ResampleCorpus(); diff --git a/gi/pf/base_measures.cc b/gi/pf/base_measures.cc index 7894d3e7..4b1863fa 100644 --- a/gi/pf/base_measures.cc +++ b/gi/pf/base_measures.cc @@ -37,7 +37,7 @@ TableLookupBase::TableLookupBase(const string& fname) { } else if (cc == 1) { x.e_.push_back(cur); } else if (cc == 2) { - table[x] = atof(TD::Convert(cur)); + table[x].logeq(atof(TD::Convert(cur))); ++cc; } else { if (flag) cerr << endl; diff --git a/gi/pf/base_measures.h b/gi/pf/base_measures.h index 7214aa22..b0495bfd 100644 --- a/gi/pf/base_measures.h +++ b/gi/pf/base_measures.h @@ -51,6 +51,22 @@ struct Model1 { std::vector > ttable; }; +struct PoissonUniformUninformativeBase { + explicit PoissonUniformUninformativeBase(const unsigned ves) : kUNIFORM(1.0 / ves) {} + prob_t operator()(const TRule& r) const { + prob_t p; p.logeq(log_poisson(r.e_.size(), 1.0)); + prob_t q = kUNIFORM; q.poweq(r.e_.size()); + p *= q; + return p; + } + void Summary() const {} + void ResampleHyperparameters(MT19937*) {} + void Increment(const TRule&) {} + void Decrement(const TRule&) {} + prob_t Likelihood() const { return prob_t::One(); } + const prob_t kUNIFORM; +}; + struct CompletelyUniformBase { explicit CompletelyUniformBase(const unsigned ves) : kUNIFORM(1.0 / ves) {} prob_t operator()(const TRule&) const { @@ -83,10 +99,19 @@ struct TableLookupBase { prob_t operator()(const TRule& rule) const { const std::tr1::unordered_map::const_iterator it = table.find(rule); - assert(it != table.end()); + if (it == table.end()) { + std::cerr << rule << " not found\n"; + abort(); + } return it->second; } + void ResampleHyperparameters(MT19937*) {} + void Increment(const TRule&) {} + void Decrement(const TRule&) {} + prob_t Likelihood() const { return prob_t::One(); } + void Summary() const {} + std::tr1::unordered_map table; }; diff --git a/training/model1.cc b/training/model1.cc index 346c0033..40249aa3 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -14,6 +14,11 @@ namespace po = boost::program_options; using namespace std; +inline double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -25,6 +30,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("diagonal_tension,T", po::value()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)") ("prob_align_null", po::value()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?") ("variational_bayes,v","Add a symmetric Dirichlet prior and infer VB estimate of weights") + ("testset,x", po::value(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model") ("alpha,a", po::value()->default_value(0.01), "Hyperparameter for optional Dirichlet prior") ("no_add_viterbi,V","Do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)"); po::options_description clo("Command line options"); @@ -63,6 +69,8 @@ int main(int argc, char** argv) { const bool write_alignments = (conf.count("write_alignments") > 0); const double diagonal_tension = conf["diagonal_tension"].as(); const double prob_align_null = conf["prob_align_null"].as(); + string testset; + if (conf.count("testset")) testset = conf["testset"].as(); const double prob_align_not_null = 1.0 - prob_align_null; const double alpha = conf["alpha"].as(); const bool favor_diagonal = conf.count("favor_diagonal"); @@ -73,6 +81,8 @@ int main(int argc, char** argv) { TTable tt; TTable::Word2Word2Double was_viterbi; + double tot_len_ratio = 0; + double mean_srclen_multiplier = 0; for (int iter = 0; iter < ITERATIONS; ++iter) { const bool final_iteration = (iter == (ITERATIONS - 1)); cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; @@ -83,13 +93,13 @@ int main(int argc, char** argv) { int lc = 0; bool flag = false; string line; + string ssrc, strg; while(true) { getline(in, line); if (!in) break; ++lc; if (lc % 1000 == 0) { cerr << '.'; flag = true; } if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; } - string ssrc, strg; ParseTranslatorInput(line, &ssrc, &strg); Lattice src, trg; LatticeTools::ConvertTextToLattice(ssrc, &src); @@ -99,9 +109,10 @@ int main(int argc, char** argv) { assert(src.size() > 0); assert(trg.size() > 0); } + if (iter == 0) + tot_len_ratio += static_cast(trg.size()) / static_cast(src.size()); denom += trg.size(); vector probs(src.size() + 1); - const double src_logprob = -log(src.size() + 1); bool first_al = true; // used for write_alignments for (int j = 0; j < trg.size(); ++j) { const WordID& f_j = trg[j][0].label; @@ -156,7 +167,7 @@ int main(int argc, char** argv) { for (int i = 1; i <= src.size(); ++i) tt.Increment(src[i-1][0].label, f_j, probs[i] / sum); } - likelihood += log(sum) + src_logprob; + likelihood += log(sum); } if (write_alignments && final_iteration) cout << endl; } @@ -165,6 +176,10 @@ int main(int argc, char** argv) { double base2_likelihood = likelihood / log(2); if (flag) { cerr << endl; } + if (iter == 0) { + mean_srclen_multiplier = tot_len_ratio / lc; + cerr << "expected target length = source length * " << mean_srclen_multiplier << endl; + } cerr << " log_e likelihood: " << likelihood << endl; cerr << " log_2 likelihood: " << base2_likelihood << endl; cerr << " cross entropy: " << (-base2_likelihood / denom) << endl; @@ -176,6 +191,49 @@ int main(int argc, char** argv) { tt.Normalize(); } } + if (testset.size()) { + ReadFile rf(testset); + istream& in = *rf.stream(); + int lc = 0; + double tlp = 0; + string ssrc, strg, line; + while (getline(in, line)) { + ++lc; + ParseTranslatorInput(line, &ssrc, &strg); + Lattice src, trg; + LatticeTools::ConvertTextToLattice(ssrc, &src); + LatticeTools::ConvertTextToLattice(strg, &trg); + double log_prob = log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier); + + // compute likelihood + for (int j = 0; j < trg.size(); ++j) { + const WordID& f_j = trg[j][0].label; + double sum = 0; + const double j_over_ts = double(j) / trg.size(); + double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1) + if (use_null) { + if (favor_diagonal) prob_a_i = prob_align_null; + sum += tt.prob(kNULL, f_j) * prob_a_i; + } + double az = 0; + if (favor_diagonal) { + for (int ta = 0; ta < src.size(); ++ta) + az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); + az /= prob_align_not_null; + } + for (int i = 1; i <= src.size(); ++i) { + if (favor_diagonal) + prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az; + sum += tt.prob(src[i-1][0].label, f_j) * prob_a_i; + } + log_prob += log(sum); + } + tlp += log_prob; + cerr << ssrc << " ||| " << strg << " ||| " << log_prob << endl; + } + cerr << "TOTAL LOG PROB " << tlp << endl; + } + if (write_alignments) return 0; for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) { -- cgit v1.2.3 From a38b3fa383412e56eb958db998662c026bc08f4b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 17 Feb 2012 13:01:54 -0500 Subject: boost version checking, check for Eigen, get rid of old digamma stuff --- configure.ac | 21 +++++++++++++++------ training/em_utils.h | 24 ------------------------ training/model1.cc | 1 - training/mr_em_adapted_reduce.cc | 6 +++--- training/ttables.h | 4 ++-- utils/m.h | 6 ++++++ 6 files changed, 26 insertions(+), 36 deletions(-) delete mode 100644 training/em_utils.h (limited to 'training/model1.cc') diff --git a/configure.ac b/configure.ac index cd78ee72..aa79027f 100644 --- a/configure.ac +++ b/configure.ac @@ -9,7 +9,7 @@ esac AC_PROG_CC AC_PROG_CXX AC_LANG_CPLUSPLUS -BOOST_REQUIRE +BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS AC_ARG_ENABLE(mpi, [ --enable-mpi Build MPI binaries, assumes mpi.h is present ], @@ -38,7 +38,7 @@ then CPPFLAGS="$CPPFLAGS -I${with_cmph}/include" AC_CHECK_HEADER(cmph.h, - [AC_DEFINE([HAVE_CMPH], [], [flag for cmph perfect hashing library])], + [AC_DEFINE([HAVE_CMPH], [1], [flag for cmph perfect hashing library])], [AC_MSG_ERROR([Cannot find cmph library!])]) LDFLAGS="$LDFLAGS -L${with_cmph}/lib" @@ -46,6 +46,18 @@ then AM_CONDITIONAL([HAVE_CMPH], true) fi +if test "x$with_eigen" != 'xno' +then + SAVE_CPPFLAGS="$CPPFLAGS" + CPPFLAGS="$CPPFLAGS -I${with_eigen}" + + AC_CHECK_HEADER(Eigen, + [AC_DEFINE([HAVE_EIGEN], [1], [flag for Eigen linear algebra library])], + [AC_MSG_ERROR([Cannot find Eigen!])]) + + AM_CONDITIONAL([HAVE_EIGEN], true) +fi + #BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS" @@ -53,11 +65,8 @@ LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS" LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS" # $BOOST_THREAD_LIBS" -AC_CHECK_HEADER(boost/math/special_functions/digamma.hpp, - [AC_DEFINE([HAVE_BOOST_DIGAMMA], [], [flag for boost::math::digamma])]) - AC_CHECK_HEADER(google/dense_hash_map, - [AC_DEFINE([HAVE_SPARSEHASH], [], [flag for google::dense_hash_map])]) + [AC_DEFINE([HAVE_SPARSEHASH], [1], [flag for google::dense_hash_map])]) AC_PROG_INSTALL GTEST_LIB_CHECK(1.0) diff --git a/training/em_utils.h b/training/em_utils.h deleted file mode 100644 index 37762978..00000000 --- a/training/em_utils.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _EM_UTILS_H_ -#define _EM_UTILS_H_ - -#include "config.h" -#ifdef HAVE_BOOST_DIGAMMA -#include -using boost::math::digamma; -#else -#warning Using Mark Johnsons digamma() -#include -inline double digamma(double x) { - double result = 0, xx, xx2, xx4; - assert(x > 0); - for ( ; x < 7; ++x) - result -= 1/x; - x -= 1.0/2.0; - xx = 1.0/x; - xx2 = xx*xx; - xx4 = xx2*xx2; - result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; - return result; -} -#endif -#endif diff --git a/training/model1.cc b/training/model1.cc index 40249aa3..a87d388f 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -9,7 +9,6 @@ #include "filelib.h" #include "ttables.h" #include "tdict.h" -#include "em_utils.h" namespace po = boost::program_options; using namespace std; diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc index d4c16a2f..f65b5440 100644 --- a/training/mr_em_adapted_reduce.cc +++ b/training/mr_em_adapted_reduce.cc @@ -10,7 +10,7 @@ #include "fdict.h" #include "weights.h" #include "sparse_vector.h" -#include "em_utils.h" +#include "m.h" using namespace std; namespace po = boost::program_options; @@ -63,11 +63,11 @@ void Maximize(const bool use_vb, assert(tot > 0.0); double ltot = log(tot); if (use_vb) - ltot = digamma(tot + total_event_types * alpha); + ltot = Md::digamma(tot + total_event_types * alpha); for (SparseVector::const_iterator it = counts.begin(); it != counts.end(); ++it) { if (use_vb) { - pc->set_value(it->first, NoZero(digamma(it->second + alpha) - ltot)); + pc->set_value(it->first, NoZero(Md::digamma(it->second + alpha) - ltot)); } else { pc->set_value(it->first, NoZero(log(it->second) - ltot)); } diff --git a/training/ttables.h b/training/ttables.h index 50d85a68..bf3351d2 100644 --- a/training/ttables.h +++ b/training/ttables.h @@ -4,9 +4,9 @@ #include #include +#include "m.h" #include "wordid.h" #include "tdict.h" -#include "em_utils.h" class TTable { public: @@ -39,7 +39,7 @@ class TTable { for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) tot += it->second + alpha; for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) - it->second = exp(digamma(it->second + alpha) - digamma(tot)); + it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot)); } counts.clear(); } diff --git a/utils/m.h b/utils/m.h index b25248c2..5e45efee 100644 --- a/utils/m.h +++ b/utils/m.h @@ -3,6 +3,7 @@ #include #include +#include template struct M { @@ -81,6 +82,11 @@ struct M { } } + // digamma is the first derivative of the log-gamma function + static inline F digamma(const F& x) { + return boost::math::digamma(x); + } + }; typedef M Md; -- cgit v1.2.3 From d3ccf26cf501cb15ed300bc0ad17596a4e59fbeb Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 18 Feb 2012 15:16:17 -0500 Subject: fix diagonal model --- configure.ac | 2 +- training/model1.cc | 29 +++++++++++++++++------------ 2 files changed, 18 insertions(+), 13 deletions(-) (limited to 'training/model1.cc') diff --git a/configure.ac b/configure.ac index aa79027f..026dad01 100644 --- a/configure.ac +++ b/configure.ac @@ -51,7 +51,7 @@ then SAVE_CPPFLAGS="$CPPFLAGS" CPPFLAGS="$CPPFLAGS -I${with_eigen}" - AC_CHECK_HEADER(Eigen, + AC_CHECK_HEADER(Eigen/Dense, [AC_DEFINE([HAVE_EIGEN], [1], [flag for Eigen linear algebra library])], [AC_MSG_ERROR([Cannot find Eigen!])]) diff --git a/training/model1.cc b/training/model1.cc index a87d388f..73104304 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -4,6 +4,7 @@ #include #include +#include "m.h" #include "lattice.h" #include "stringlib.h" #include "filelib.h" @@ -13,11 +14,6 @@ namespace po = boost::program_options; using namespace std; -inline double log_poisson(unsigned x, const double& lambda) { - assert(lambda > 0.0); - return log(lambda) * x - lgamma(x + 1) - lambda; -} - bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -82,6 +78,7 @@ int main(int argc, char** argv) { TTable::Word2Word2Double was_viterbi; double tot_len_ratio = 0; double mean_srclen_multiplier = 0; + vector unnormed_a_i; for (int iter = 0; iter < ITERATIONS; ++iter) { const bool final_iteration = (iter == (ITERATIONS - 1)); cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; @@ -108,6 +105,8 @@ int main(int argc, char** argv) { assert(src.size() > 0); assert(trg.size() > 0); } + if (src.size() > unnormed_a_i.size()) + unnormed_a_i.resize(src.size()); if (iter == 0) tot_len_ratio += static_cast(trg.size()) / static_cast(src.size()); denom += trg.size(); @@ -125,13 +124,15 @@ int main(int argc, char** argv) { } double az = 0; if (favor_diagonal) { - for (int ta = 0; ta < src.size(); ++ta) - az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); + for (int ta = 0; ta < src.size(); ++ta) { + unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); + az += unnormed_a_i[ta]; + } az /= prob_align_not_null; } for (int i = 1; i <= src.size(); ++i) { if (favor_diagonal) - prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az; + prob_a_i = unnormed_a_i[i-1] / az; probs[i] = tt.prob(src[i-1][0].label, f_j) * prob_a_i; sum += probs[i]; } @@ -202,7 +203,9 @@ int main(int argc, char** argv) { Lattice src, trg; LatticeTools::ConvertTextToLattice(ssrc, &src); LatticeTools::ConvertTextToLattice(strg, &trg); - double log_prob = log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier); + double log_prob = Md::log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier); + if (src.size() > unnormed_a_i.size()) + unnormed_a_i.resize(src.size()); // compute likelihood for (int j = 0; j < trg.size(); ++j) { @@ -216,13 +219,15 @@ int main(int argc, char** argv) { } double az = 0; if (favor_diagonal) { - for (int ta = 0; ta < src.size(); ++ta) - az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); + for (int ta = 0; ta < src.size(); ++ta) { + unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); + az += unnormed_a_i[ta]; + } az /= prob_align_not_null; } for (int i = 1; i <= src.size(); ++i) { if (favor_diagonal) - prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az; + prob_a_i = unnormed_a_i[i-1] / az; sum += tt.prob(src[i-1][0].label, f_j) * prob_a_i; } log_prob += log(sum); -- cgit v1.2.3