From 1cfa8735f4cd7264e70cc6918bbd58c86a015ee4 Mon Sep 17 00:00:00 2001
From: graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>
Date: Fri, 16 Jul 2010 01:56:55 +0000
Subject: oracle is_null

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@279 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 decoder/cdec.cc       |  2 +-
 decoder/oracle_bleu.h | 56 ++++++++++++++++++++++++++-------------------------
 decoder/tdict.h       |  8 ++++++++
 3 files changed, 38 insertions(+), 28 deletions(-)

(limited to 'decoder')

diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 75c907b1..77179948 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -604,7 +604,7 @@ int main(int argc, char** argv) {
 
     /*Oracle Rescoring*/
     if(get_oracle_forest) {
-      Oracles o=oracles.ComputeOracles(smeta,forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>());
+      Oracles o=oracles.ComputeOracles(smeta,&forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>());
       cerr << "  +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
       cerr << "  +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
       o.hope.Print(cerr,"  +Oracle BLEU");
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 6708b02e..b58117c1 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -16,6 +16,7 @@
 #include "sentence_metadata.h"
 #include "apply_models.h"
 #include "kbest.h"
+#include "timing_stats.h"
 
 //TODO: put function impls into .cc
 //TODO: disentangle
@@ -51,15 +52,15 @@ struct Oracles {
   Translation model,fear,hope;
   // feature 0 will be the error rate in fear and hope
   // move toward hope
-  WeightVector ModelHopeGradient() {
-    WeightVector r=hope-model;
-    r[0]=0;
+  FeatureVector ModelHopeGradient() {
+    FeatureVector r=hope.features-model.features;
+    r.set_value(0,0);
     return r;
   }
   // move toward hope from fear
-  WeightVector FearHopeGradient() {
-    Weightvector r=hope-fear;
-    r[0]=0;
+  FeatureVector FearHopeGradient() {
+    FeatureVector r=hope.features-fear.features;
+    r.set_value(0,0);
     return r;
   }
 };
@@ -100,6 +101,12 @@ struct OracleBleu {
 
   typedef boost::shared_ptr<Score> ScoreP;
   ScoreP doc_score,sentscore; // made from factory, so we delete them
+  ScoreP GetScore(Sentence const& sentence,int sent_id) {
+    return ScoreP(ds[sent_id]->ScoreCandidate(sentence));
+  }
+  ScoreP GetScore(Hypergraph const& forest,int sent_id) {
+    return GetScore(Translation(forest).sentence,sent_id);
+  }
 
   double bleu_weight;
   void UseConf(boost::program_options::variables_map const& conf) {
@@ -134,44 +141,39 @@ struct OracleBleu {
   SentenceMetadata MakeMetadata(Hypergraph const& forest,int sent_id) {
     std::vector<WordID> srcsent;
     ViterbiFSentence(forest,&srcsent);
-    SentenceMetadata sm(sent_id,Lattice()); //TODO: make reference from refs?
-    sm.SetSourceLength(srcsent.size());
+    SentenceMetadata smeta(sent_id,Lattice()); //TODO: make reference from refs?
+    smeta.SetSourceLength(srcsent.size());
 	smeta.SetScore(doc_score.get());
 	smeta.SetDocScorer(&ds);
 	smeta.SetDocLen(doc_src_length);
-    return sm;
+    return smeta;
   }
 
-  Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph const& forest,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
+  // destroys forest (replaces it w/ rescored oracle one)
+  Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
+    Hypergraph &forest=*forest_in_out;
     Oracles r;
     int sent_id=smeta.GetSentenceID();
     r.model=Translation(forest);
-
     if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output);
     {
       Timer t("Forest Oracle rescoring:");
       Hypergraph oracle_forest;
-      Rescore(smeta,forest,&oracle_forest,feature_weights,blue_weight,log);
+      Rescore(smeta,forest,&oracle_forest,feature_weights,bleu_weight,log);
       forest.swap(oracle_forest);
     }
     r.hope=Translation(forest);
     if (kbest) DumpKBest("oracle",sent_id, forest, kbest, true, forest_output);
-    oracle.ReweightBleu(&forest,-blue_weight);
+    ReweightBleu(&forest,-bleu_weight);
     r.fear=Translation(forest);
     if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output);
     return r;
   }
-
-  ScoreP Score(Sentence const& sentence,int sent_id) {
-    return ds[sent_id]->ScoreCandidate(sentence);
-  }
-  ScoreP Score(Hypergraph const& forest,int sent_id) {
-    return Score(model_trans(forest).translation,sent_id);
-  }
+  typedef std::vector<WordID> Sentence;
 
   void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
     // the sentence bleu stats will get added to doc only if you call IncludeLastScore
-    sentscore=Score(forest,smeta.GetSentenceID());
+    sentscore=GetScore(forest,smeta.GetSentenceID());
 	if (!doc_score) { doc_score.reset(sentscore->GetOne()); }
 	tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from?
     using namespace std;
@@ -219,13 +221,13 @@ struct OracleBleu {
     float curr_src_length = doc_src_length + tmp_src_length;
 
     if (unique) {
-      KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
+      KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
       for (int i = 0; i < k; ++i) {
-        const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
+        const KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
           kbest.LazyKthBest(forest.nodes_.size() - 1, i);
         if (!d) break;
         //calculate score in context of psuedo-doc
-        Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield);
+        Score* sentscore = GetScore(d->yield,sent_id);
         sentscore->PlusEquals(*doc_score,float(1));
         float bleu = curr_src_length * sentscore->ComputeScore();
         kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
@@ -234,9 +236,9 @@ struct OracleBleu {
         //     << d->feature_values << " ||| " << log(d->score) << endl;
       }
     } else {
-      KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
+      KBest::KBestDerivations<Sentence, ESentenceTraversal> kbest(forest, k);
       for (int i = 0; i < k; ++i) {
-        const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+        const KBest::KBestDerivations<Sentence, ESentenceTraversal>::Derivation* d =
           kbest.LazyKthBest(forest.nodes_.size() - 1, i);
         if (!d) break;
         cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
@@ -245,7 +247,7 @@ struct OracleBleu {
     }
   }
 
-void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output)
+void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output)
   {
     std::ostringstream kbest_string_stream;
     kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id;
diff --git a/decoder/tdict.h b/decoder/tdict.h
index fd77543d..1fba5179 100644
--- a/decoder/tdict.h
+++ b/decoder/tdict.h
@@ -28,4 +28,12 @@ struct TD {
   static const char* Convert(const WordID& w);
 };
 
+struct ToTD {
+  typedef WordID result_type;
+  result_type operator()(std::string const& t) const {
+    return TD::Convert(t);
+  }
+};
+
+
 #endif
-- 
cgit v1.2.3