From 105a52a8d37497fe69a01a7de771ef9b9300cd71 Mon Sep 17 00:00:00 2001
From: Chris Dyer <cdyer@cs.cmu.edu>
Date: Fri, 11 Nov 2011 17:12:39 -0500
Subject: optionally sample from the forest to get training instances, rather
 than extracting a k-best list

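When -f (--sample_forest) is given, the observer draws k_best_size
hypotheses from the decoder's forest with HypergraphSampler instead of
enumerating a k-best list; the current model best still comes from the
Viterbi derivation. With -x (--sample_forest_unit_weight_vector), the
weight vector is rescaled to unit L2 norm before decoding and restored
afterwards, which may improve the quality of the samples.

A sketch of the new usage (flags as defined in InitCommandLine; the
metric name and file paths are illustrative, and other required options
are elided):

  mira/kbest_mira -c cdec.ini --source dev.src --mt_metric ibm_bleu \
    -k 500 -f -x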
---
 mira/kbest_mira.cc | 85 ++++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 66 insertions(+), 19 deletions(-)

diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc
index 459a5e6f..904eba74 100644
--- a/mira/kbest_mira.cc
+++ b/mira/kbest_mira.cc
@@ -10,6 +10,7 @@
 #include <boost/program_options.hpp>
 #include <boost/program_options/variables_map.hpp>
 
+#include "hg_sampler.h"
 #include "sentence_metadata.h"
 #include "scorer.h"
 #include "verbose.h"
@@ -54,6 +55,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
         ("max_step_size,C", po::value<double>()->default_value(0.01), "regularization strength (C)")
         ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Amount to scale MT loss function by")
         ("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles")
+        ("sample_forest,f", "Instead of a k-best list, sample k hypotheses from the decoder's forest")
+        ("sample_forest_unit_weight_vector,x", "Before sampling (must use -f option), rescale the weight vector used so it has unit length; this may improve the quality of the samples")
         ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
         ("decoder_config,c",po::value<string>(),"Decoder configuration file");
   po::options_description clo("Command line options");
@@ -91,11 +94,12 @@ struct GoodBadOracle {
 };
 
 struct TrainingObserver : public DecoderObserver {
-  TrainingObserver(const int k, const DocScorer& d, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k) {}
+  TrainingObserver(const int k, const DocScorer& d, bool sf, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {}
   const DocScorer& ds;
   vector<GoodBadOracle>& oracles;
   shared_ptr<HypothesisInfo> cur_best;
   const int kbest_size;
+  const bool sample_forest;
 
   const HypothesisInfo& GetCurrentBestHypothesis() const {
     return *cur_best;
@@ -116,24 +120,47 @@ struct TrainingObserver : public DecoderObserver {
     shared_ptr<HypothesisInfo>& cur_good = oracles[sent_id].good;
     shared_ptr<HypothesisInfo>& cur_bad = oracles[sent_id].bad;
     cur_bad.reset();  // TODO get rid of??
-    KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size);
-    for (int i = 0; i < kbest_size; ++i) {
-      const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
-        kbest.LazyKthBest(forest.nodes_.size() - 1, i);
-      if (!d) break;
-      float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore();
-      if (invert_score) sentscore *= -1.0;
-      // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl;
-      if (i == 0)
-        cur_best = MakeHypothesisInfo(d->feature_values, sentscore);
-      if (!cur_good || sentscore > cur_good->mt_metric)
-        cur_good = MakeHypothesisInfo(d->feature_values, sentscore);
-      if (!cur_bad || sentscore < cur_bad->mt_metric)
-        cur_bad = MakeHypothesisInfo(d->feature_values, sentscore);
+
+    if (sample_forest) {
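+      // the current model best comes from the Viterbi derivation, not the samples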
+      vector<WordID> cur_prediction;
+      ViterbiESentence(forest, &cur_prediction);
+      float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore();
+      if (invert_score) sentscore *= -1.0;
+      cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore);
+
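+      // draw kbest_size samples from the forest and track the best/worst by metric score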
+      vector<HypergraphSampler::Hypothesis> samples;
+      HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples);
+      for (unsigned i = 0; i < samples.size(); ++i) {
+        sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore();
+        if (invert_score) sentscore *= -1.0;
+        if (!cur_good || sentscore > cur_good->mt_metric)
+          cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore);
+        if (!cur_bad || sentscore < cur_bad->mt_metric)
+          cur_bad = MakeHypothesisInfo(samples[i].fmap, sentscore);
+      }
+    } else {
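+      // original behavior: enumerate the k-best list to find oracles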
+      KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size);
+      for (int i = 0; i < kbest_size; ++i) {
+        const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+          kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+        if (!d) break;
+        float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore();
+        if (invert_score) sentscore *= -1.0;
+        // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl;
+        if (i == 0)
+          cur_best = MakeHypothesisInfo(d->feature_values, sentscore);
+        if (!cur_good || sentscore > cur_good->mt_metric)
+          cur_good = MakeHypothesisInfo(d->feature_values, sentscore);
+        if (!cur_bad || sentscore < cur_bad->mt_metric)
+          cur_bad = MakeHypothesisInfo(d->feature_values, sentscore);
+      }
+      //cerr << "GOOD: " << cur_good->mt_metric << endl;
+      //cerr << " CUR: " << cur_best->mt_metric << endl;
+      //cerr << " BAD: " << cur_bad->mt_metric << endl;
     }
-    //cerr << "GOOD: " << cur_good->mt_metric << endl;
-    //cerr << " CUR: " << cur_best->mt_metric << endl;
-    //cerr << " BAD: " << cur_bad->mt_metric << endl;
   }
 };
 
@@ -164,6 +191,12 @@ int main(int argc, char** argv) {
     rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
   else
     rng.reset(new MT19937);
+  const bool sample_forest = conf.count("sample_forest") > 0;
+  const bool sample_forest_unit_weight_vector = conf.count("sample_forest_unit_weight_vector") > 0;
+  if (sample_forest_unit_weight_vector && !sample_forest) {
+    cerr << "--sample_forest_unit_weight_vector (-x) requires --sample_forest (-f)" << endl;
+    return 1;
+  }
   vector<string> corpus;
   ReadTrainingCorpus(conf["source"].as<string>(), &corpus);
   const string metric_name = conf["mt_metric"].as<string>();
@@ -195,7 +228,7 @@ int main(int argc, char** argv) {
   assert(corpus.size() > 0);
   vector<GoodBadOracle> oracles(corpus.size());
 
-  TrainingObserver observer(conf["k_best_size"].as<int>(), ds, &oracles);
+  TrainingObserver observer(conf["k_best_size"].as<int>(), ds, sample_forest, &oracles);
   int cur_sent = 0;
   int lcount = 0;
   int normalizer = 0;
@@ -234,7 +267,21 @@ int main(int argc, char** argv) {
       cerr << "PASS " << (lcount / corpus.size() + 1) << endl;
     }
     decoder.SetId(order[cur_sent]);
+    double sc = 1.0;
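+    // optionally rescale the weights to unit L2 norm before decoding (-x)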
+    if (sample_forest_unit_weight_vector) {
+      sc = lambdas.l2norm();
+      if (sc > 0) {
+        for (unsigned i = 0; i < dense_weights.size(); ++i)
+          dense_weights[i] /= sc;
+      }
+    }
     decoder.Decode(corpus[order[cur_sent]], &observer);  // update oracles
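+    // restore the original weight scale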
+    if (sc && sc != 1.0) {
+      for (unsigned i = 0; i < dense_weights.size(); ++i)
+        dense_weights[i] *= sc;
+    }
     const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis();
     const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good;
     const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad;
-- 
cgit v1.2.3