From 4c2360119def2fb624d2691b355b1908c511f004 Mon Sep 17 00:00:00 2001
From: Chris Dyer <cdyer@cs.cmu.edu>
Date: Tue, 24 Jan 2012 22:26:44 -0500
Subject: more models

---
 training/model1.cc | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 61 insertions(+), 3 deletions(-)

(limited to 'training')

diff --git a/training/model1.cc b/training/model1.cc
index 346c0033..40249aa3 100644
--- a/training/model1.cc
+++ b/training/model1.cc
@@ -14,6 +14,11 @@
 namespace po = boost::program_options;
 using namespace std;
 
+inline double log_poisson(unsigned x, const double& lambda) {
+  assert(lambda > 0.0);
+  return log(lambda) * x - lgamma(x + 1) - lambda;
+}
+
 bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   po::options_description opts("Configuration options");
   opts.add_options()
@@ -25,6 +30,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
         ("diagonal_tension,T", po::value<double>()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)")
         ("prob_align_null", po::value<double>()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?")
         ("variational_bayes,v","Add a symmetric Dirichlet prior and infer VB estimate of weights")
+        ("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model")
         ("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior")
         ("no_add_viterbi,V","Do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)");
   po::options_description clo("Command line options");
@@ -63,6 +69,8 @@ int main(int argc, char** argv) {
   const bool write_alignments = (conf.count("write_alignments") > 0);
   const double diagonal_tension = conf["diagonal_tension"].as<double>();
   const double prob_align_null = conf["prob_align_null"].as<double>();
+  string testset;
+  if (conf.count("testset")) testset = conf["testset"].as<string>();
   const double prob_align_not_null = 1.0 - prob_align_null;
   const double alpha = conf["alpha"].as<double>();
   const bool favor_diagonal = conf.count("favor_diagonal");
@@ -73,6 +81,8 @@ int main(int argc, char** argv) {
 
   TTable tt;
   TTable::Word2Word2Double was_viterbi;
+  double tot_len_ratio = 0;
+  double mean_srclen_multiplier = 0;
   for (int iter = 0; iter < ITERATIONS; ++iter) {
     const bool final_iteration = (iter == (ITERATIONS - 1));
     cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;
@@ -83,13 +93,13 @@ int main(int argc, char** argv) {
     int lc = 0;
     bool flag = false;
     string line;
+    string ssrc, strg;
     while(true) {
       getline(in, line);
       if (!in) break;
       ++lc;
       if (lc % 1000 == 0) { cerr << '.'; flag = true; }
       if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; }
-      string ssrc, strg;
       ParseTranslatorInput(line, &ssrc, &strg);
       Lattice src, trg;
       LatticeTools::ConvertTextToLattice(ssrc, &src);
@@ -99,9 +109,10 @@ int main(int argc, char** argv) {
         assert(src.size() > 0);
         assert(trg.size() > 0);
       }
+      if (iter == 0)
+        tot_len_ratio += static_cast<double>(trg.size()) / static_cast<double>(src.size());
       denom += trg.size();
       vector<double> probs(src.size() + 1);
-      const double src_logprob = -log(src.size() + 1);
       bool first_al = true;  // used for write_alignments
       for (int j = 0; j < trg.size(); ++j) {
         const WordID& f_j = trg[j][0].label;
@@ -156,7 +167,7 @@ int main(int argc, char** argv) {
           for (int i = 1; i <= src.size(); ++i)
             tt.Increment(src[i-1][0].label, f_j, probs[i] / sum);
         }
-        likelihood += log(sum) + src_logprob;
+        likelihood += log(sum);
       }
       if (write_alignments && final_iteration) cout << endl;
     }
@@ -165,6 +176,10 @@ int main(int argc, char** argv) {
     double base2_likelihood = likelihood / log(2);
 
     if (flag) { cerr << endl; }
+    if (iter == 0) {
+      mean_srclen_multiplier = tot_len_ratio / lc;
+      cerr << "expected target length = source length * " << mean_srclen_multiplier << endl;
+    }
     cerr << "  log_e likelihood: " << likelihood << endl;
     cerr << "  log_2 likelihood: " << base2_likelihood << endl;
     cerr << "   cross entropy: " << (-base2_likelihood / denom) << endl;
@@ -176,6 +191,49 @@ int main(int argc, char** argv) {
         tt.Normalize();
     }
   }
+  if (testset.size()) {
+    ReadFile rf(testset);
+    istream& in = *rf.stream();
+    int lc = 0;
+    double tlp = 0;
+    string ssrc, strg, line;
+    while (getline(in, line)) {
+      ++lc;
+      ParseTranslatorInput(line, &ssrc, &strg);
+      Lattice src, trg;
+      LatticeTools::ConvertTextToLattice(ssrc, &src);
+      LatticeTools::ConvertTextToLattice(strg, &trg);
+      double log_prob = log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier);
+
+      // compute likelihood
+      for (int j = 0; j < trg.size(); ++j) {
+        const WordID& f_j = trg[j][0].label;
+        double sum = 0;
+        const double j_over_ts = double(j) / trg.size();
+        double prob_a_i = 1.0 / (src.size() + use_null);  // uniform (model 1)
+        if (use_null) {
+          if (favor_diagonal) prob_a_i = prob_align_null;
+          sum += tt.prob(kNULL, f_j) * prob_a_i;
+        }
+        double az = 0;
+        if (favor_diagonal) {
+          for (int ta = 0; ta < src.size(); ++ta)
+            az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension);
+          az /= prob_align_not_null;
+        }
+        for (int i = 1; i <= src.size(); ++i) {
+          if (favor_diagonal)
+            prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az;
+          sum += tt.prob(src[i-1][0].label, f_j) * prob_a_i;
+        }
+        log_prob += log(sum);
+      }
+      tlp += log_prob;
+      cerr << ssrc << " ||| " << strg << " ||| " << log_prob << endl;
+    }
+    cerr << "TOTAL LOG PROB " << tlp << endl;
+  }
+
   if (write_alignments) return 0;
 
   for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) {
-- 
cgit v1.2.3