From b03d01f22df3c5e27014bf32748baacc10c7d360 Mon Sep 17 00:00:00 2001
From: Patrick Simianer 
Date: Thu, 13 Oct 2011 23:50:28 +0200
Subject: fixed approx bleu
---
 dtrain/dtrain.cc               |  7 ++++--
 dtrain/kbestget.h              | 21 ++++++++--------
 dtrain/ksampler.h              |  2 +-
 dtrain/score.cc                | 56 ++++++++++++++++++------------------------
 dtrain/score.h                 | 42 +++++++++++++++----------------
 dtrain/test/example/dtrain.ini |  1 -
 6 files changed, 61 insertions(+), 68 deletions(-)
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 25858738..f679c9f6 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -105,12 +105,13 @@ main(int argc, char** argv)
   string scorer_str = cfg["scorer"].as();
   LocalScorer* scorer;
   if (scorer_str == "bleu") {
+    scorer = dynamic_cast(new BleuScorer);
   } else if (scorer_str == "stupid_bleu") {
     scorer = dynamic_cast(new StupidBleuScorer);
   } else if (scorer_str == "smooth_bleu") {
     scorer = dynamic_cast(new SmoothBleuScorer);
   } else if (scorer_str == "approx_bleu") {
-    scorer = dynamic_cast(new StupidBleuScorer); // FIXME
+    scorer = dynamic_cast(new ApproxBleuScorer(N));
   } else {
     cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
     exit(1);
@@ -145,7 +146,7 @@ main(int argc, char** argv)
   // input
   string input_fn = cfg["input"].as();
   ReadFile input(input_fn);
-    // buffer input for t > 0
+  // buffer input for t > 0
   vector src_str_buf;          // source strings
   vector > ref_ids_buf; // references as WordID vecs
   vector weights_files;        // remember weights for each iteration
@@ -341,6 +342,8 @@ main(int argc, char** argv)
 
   } // input loop
 
+  if (scorer_str == "approx_bleu") scorer->Reset();
+
   if (t == 0) {
     in_sz = ii; // remember size of input (# lines)
     grammar_buf_out.close();
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index c0fd3f47..d141da60 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -14,7 +14,7 @@ namespace dtrain
 {
 
 
-typedef double score_t; // float
+typedef double score_t;
 
 struct ScoredHyp
 {
@@ -31,9 +31,11 @@ struct LocalScorer
   vector w_;
 
   virtual score_t
-  Score(vector& hyp, vector& ref)=0;
+  Score(vector& hyp, vector& ref, const unsigned rank)=0;
 
-  void
+  void Reset() {} // only for approx bleu
+
+  inline void
   Init(unsigned N, vector weights)
   {
     assert(N > 0);
@@ -42,7 +44,7 @@ struct LocalScorer
     else w_ = weights;
   }
 
-  score_t
+  inline score_t
   brevity_penaly(const unsigned hyp_len, const unsigned ref_len)
   {
     if (hyp_len > ref_len) return 1;
@@ -55,11 +57,10 @@ struct HypSampler : public DecoderObserver
   LocalScorer* scorer_;
   vector* ref_;
   virtual vector* GetSamples()=0;
-  void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
-  void SetRef(vector& ref) { ref_ = &ref; } 
+  inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
+  inline void SetRef(vector& ref) { ref_ = &ref; } 
 };
-/////////////////////////////////////////////////////////////////////
-// wtf
+///////////////////////////////////////////////////////////////////////////////
 
 
 
@@ -107,7 +108,7 @@ struct KBestGetter : public HypSampler
       h.f = d->feature_values;
       h.model = log(d->score);
       h.rank = i;
-      h.score = scorer_->Score(h.w, *ref_);
+      h.score = scorer_->Score(h.w, *ref_, i);
       s_.push_back(h);
     }
   }
@@ -126,7 +127,7 @@ struct KBestGetter : public HypSampler
       h.f = d->feature_values;
       h.model = log(d->score);
       h.rank = i;
-      h.score = scorer_->Score(h.w, *ref_);
+      h.score = scorer_->Score(h.w, *ref_, i);
       s_.push_back(h);
     }
   }
diff --git a/dtrain/ksampler.h b/dtrain/ksampler.h
index 7567f43a..276f2cc9 100644
--- a/dtrain/ksampler.h
+++ b/dtrain/ksampler.h
@@ -37,7 +37,7 @@ struct KSampler : public HypSampler
       h.f = samples[i].fmap;
       h.model = log(samples[i].model_score); 
       h.rank = i;
-      h.score = scorer_->Score(h.w, *ref_);
+      h.score = scorer_->Score(h.w, *ref_, i);
       s_.push_back(h);
     }
   }
diff --git a/dtrain/score.cc b/dtrain/score.cc
index 93c4e80b..f5e920a0 100644
--- a/dtrain/score.cc
+++ b/dtrain/score.cc
@@ -28,7 +28,8 @@ BleuScorer::Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref
 }
 
 score_t
-BleuScorer::Score(vector& hyp, vector& ref)
+BleuScorer::Score(vector& hyp, vector& ref,
+                  const unsigned rank)
 {
   unsigned hyp_len = hyp.size(), ref_len = ref.size();
   if (hyp_len == 0 || ref_len == 0) return 0;
@@ -47,7 +48,8 @@ BleuScorer::Score(vector& hyp, vector& ref)
  * NOTE: 0 iff no 1gram match
  */
 score_t
-StupidBleuScorer::Score(vector& hyp, vector& ref)
+StupidBleuScorer::Score(vector& hyp, vector& ref,
+                        const unsigned rank)
 {
   unsigned hyp_len = hyp.size(), ref_len = ref.size();
   if (hyp_len == 0 || ref_len == 0) return 0;
@@ -72,7 +74,8 @@ StupidBleuScorer::Score(vector& hyp, vector& ref)
  * NOTE: max is 0.9375
  */
 score_t
-SmoothBleuScorer::Score(vector& hyp, vector& ref)
+SmoothBleuScorer::Score(vector& hyp, vector& ref,
+                        const unsigned rank)
 {
   unsigned hyp_len = hyp.size(), ref_len = ref.size();
   if (hyp_len == 0 || ref_len == 0) return 0;
@@ -87,7 +90,6 @@ SmoothBleuScorer::Score(vector& hyp, vector& ref)
   return brevity_penaly(hyp_len, ref_len) * sum;
 }
 
-// FIXME
 /*
  * approx. bleu
  *
@@ -95,38 +97,28 @@ SmoothBleuScorer::Score(vector& hyp, vector& ref)
  *        and Structural Translation Features"
  * (Chiang et al. '08)
  */
-/*void
-ApproxBleuScorer::Prep(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len)
-{
-  glob_onebest_counts += counts;
-  glob_hyp_len += hyp_len;
-  glob_ref_len += ref_len;
-}
-
-void
-ApproxBleuScorer::Reset()
-{
-  glob_onebest_counts.Zero();
-  glob_hyp_len = 0;
-  glob_ref_len = 0;
-}
-
 score_t
-ApproxBleuScorer::Score(ScoredHyp& hyp, vector& ref_ids, unsigned id)
+ApproxBleuScorer::Score(vector& hyp, vector& ref,
+                        const unsigned rank)
 {
-  NgramCounts counts = make_ngram_counts(hyp.w, ref_ids, N_);
-  if (id == 0) reset();
-  unsigned hyp_len = 0, ref_len = 0;
-  if (hyp.rank == 0) { // 'context of 1best translations'
-    scorer->prep(counts, hyp.w.size(), ref_ids.size()); 
-    counts.reset();
+  unsigned hyp_len = hyp.size(), ref_len = ref.size();
+  if (hyp_len == 0 || ref_len == 0) return 0;
+  NgramCounts counts = make_ngram_counts(hyp, ref, N_);
+  NgramCounts tmp(N_);
+  if (rank == 0) { // 'context of 1best translations'
+    glob_onebest_counts += counts;
+    glob_hyp_len += hyp_len;
+    glob_ref_len += ref_len;
+    hyp_len = glob_hyp_len;
+    ref_len = glob_ref_len;
+    tmp = glob_onebest_counts;
   } else {
-    hyp_len = hyp.w.size();
-    ref_len = ref_ids.size();
+    hyp_len = hyp.size();
+    ref_len = ref.size();
+    tmp = glob_onebest_counts + counts;
   }
-  return 0.9 * BleuScorer::Bleu(glob_onebest_counts + counts,
-                                glob_hyp_len + hyp_len, glob_ref_len + ref_len);
-}*/
+  return 0.9 * Bleu(tmp, hyp_len, ref_len);
+}
 
 
 } // namespace
diff --git a/dtrain/score.h b/dtrain/score.h
index 9af56ef9..85cd0317 100644
--- a/dtrain/score.h
+++ b/dtrain/score.h
@@ -17,7 +17,7 @@ struct NgramCounts
 
   NgramCounts(const unsigned N) : N_(N) { Zero(); } 
 
-  void
+  inline void
   operator+=(const NgramCounts& rhs)
   {
     assert(N_ == rhs.N_);
@@ -27,7 +27,7 @@ struct NgramCounts
     }
   }
 
-  const NgramCounts
+  inline const NgramCounts
   operator+(const NgramCounts &other) const
   {
     NgramCounts result = *this;
@@ -35,8 +35,8 @@ struct NgramCounts
     return result;
   }
 
-  void
-  Add(unsigned count, unsigned ref_count, unsigned i)
+  inline void
+  Add(const unsigned count, const unsigned ref_count, const unsigned i)
   {
     assert(i < N_);
     if (count > ref_count) {
@@ -47,7 +47,7 @@ struct NgramCounts
     sum[i] += count;
   }
 
-  void
+  inline void
   Zero()
   {
     unsigned i;
@@ -57,7 +57,7 @@ struct NgramCounts
     }
   }
 
-  void
+  inline void
   Print()
   {
     for (unsigned i = 0; i < N_; i++) {
@@ -106,38 +106,36 @@ make_ngram_counts(const vector& hyp, const vector& ref, const un
 struct BleuScorer : public LocalScorer
 {
   score_t Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len);
-  score_t Score(vector& hyp, vector& ref_ids);
+  score_t Score(vector& hyp, vector& ref, const unsigned rank);
 };
 
 struct StupidBleuScorer : public LocalScorer
 {
-  score_t Score(vector& hyp, vector& ref);
+  score_t Score(vector& hyp, vector& ref, const unsigned rank);
 };
 
 struct SmoothBleuScorer : public LocalScorer
 {
-  score_t Score(vector& hyp, vector& ref);
+  score_t Score(vector& hyp, vector& ref, const unsigned rank);
 };
 
-// FIXME
-/*struct ApproxBleuScorer : public LocalScorer
+struct ApproxBleuScorer : public BleuScorer
 {
-  bool prepped;
-
-  NgramCounts* glob_onebest_counts;
+  NgramCounts glob_onebest_counts;
   unsigned glob_hyp_len, glob_ref_len;
 
-  void Prep(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len);
-  void Reset();
-  score_t Score(ScoredHyp& hyp, vector& ref_ids, unsigned id);
-
-  ApproxBleuScorer() 
+  ApproxBleuScorer(unsigned N) : glob_onebest_counts(NgramCounts(N))
   {
+    glob_hyp_len = glob_ref_len = 0;
+  }
+
+  inline void Reset() {
     glob_onebest_counts.Zero();
-    glob_hyp_len = 0;
-    glob_ref_len = 0;
+    glob_hyp_len = glob_ref_len = 0;
   }
-};*/
+
+  score_t Score(vector& hyp, vector& ref, const unsigned rank);
+};
 
 
 } // namespace
diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini
index 3e5c2cd1..1e841824 100644
--- a/dtrain/test/example/dtrain.ini
+++ b/dtrain/test/example/dtrain.ini
@@ -2,7 +2,6 @@ decoder_config=test/example/cdec.ini
 k=100
 N=3
 gamma=0
-#gamma=0.00001
 epochs=4
 input=test/example/nc-1k-tabs.gz
 scorer=stupid_bleu
-- 
cgit v1.2.3