From 5a5b00f2ad1ef2cb50e9c58bcb77246f3ed99057 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Thu, 13 Oct 2011 23:50:28 +0200 Subject: fixed approx bleu --- dtrain/dtrain.cc | 7 ++++-- dtrain/kbestget.h | 21 ++++++++-------- dtrain/ksampler.h | 2 +- dtrain/score.cc | 56 ++++++++++++++++++------------------------ dtrain/score.h | 42 +++++++++++++++---------------- dtrain/test/example/dtrain.ini | 1 - 6 files changed, 61 insertions(+), 68 deletions(-) diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 25858738..f679c9f6 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -105,12 +105,13 @@ main(int argc, char** argv) string scorer_str = cfg["scorer"].as(); LocalScorer* scorer; if (scorer_str == "bleu") { + scorer = dynamic_cast(new BleuScorer); } else if (scorer_str == "stupid_bleu") { scorer = dynamic_cast(new StupidBleuScorer); } else if (scorer_str == "smooth_bleu") { scorer = dynamic_cast(new SmoothBleuScorer); } else if (scorer_str == "approx_bleu") { - scorer = dynamic_cast(new StupidBleuScorer); // FIXME + scorer = dynamic_cast(new ApproxBleuScorer(N)); } else { cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl; exit(1); @@ -145,7 +146,7 @@ main(int argc, char** argv) // input string input_fn = cfg["input"].as(); ReadFile input(input_fn); - // buffer input for t > 0 + // buffer input for t > 0 vector src_str_buf; // source strings vector > ref_ids_buf; // references as WordID vecs vector weights_files; // remember weights for each iteration @@ -341,6 +342,8 @@ main(int argc, char** argv) } // input loop + if (scorer_str == "approx_bleu") scorer->Reset(); + if (t == 0) { in_sz = ii; // remember size of input (# lines) grammar_buf_out.close(); diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h index c0fd3f47..d141da60 100644 --- a/dtrain/kbestget.h +++ b/dtrain/kbestget.h @@ -14,7 +14,7 @@ namespace dtrain { -typedef double score_t; // float +typedef double score_t; struct ScoredHyp { @@ -31,9 +31,11 @@ struct LocalScorer vector w_; virtual score_t - Score(vector& hyp, vector& ref)=0; + Score(vector& hyp, vector& ref, const unsigned rank)=0; - void + void Reset() {} // only for approx bleu + + inline void Init(unsigned N, vector weights) { assert(N > 0); @@ -42,7 +44,7 @@ struct LocalScorer else w_ = weights; } - score_t + inline score_t brevity_penaly(const unsigned hyp_len, const unsigned ref_len) { if (hyp_len > ref_len) return 1; @@ -55,11 +57,10 @@ struct HypSampler : public DecoderObserver LocalScorer* scorer_; vector* ref_; virtual vector* GetSamples()=0; - void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } - void SetRef(vector& ref) { ref_ = &ref; } + inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } + inline void SetRef(vector& ref) { ref_ = &ref; } }; -///////////////////////////////////////////////////////////////////// -// wtf +/////////////////////////////////////////////////////////////////////////////// @@ -107,7 +108,7 @@ struct KBestGetter : public HypSampler h.f = d->feature_values; h.model = log(d->score); h.rank = i; - h.score = scorer_->Score(h.w, *ref_); + h.score = scorer_->Score(h.w, *ref_, i); s_.push_back(h); } } @@ -126,7 +127,7 @@ struct KBestGetter : public HypSampler h.f = d->feature_values; h.model = log(d->score); h.rank = i; - h.score = scorer_->Score(h.w, *ref_); + h.score = scorer_->Score(h.w, *ref_, i); s_.push_back(h); } } diff --git a/dtrain/ksampler.h b/dtrain/ksampler.h index 7567f43a..276f2cc9 100644 --- a/dtrain/ksampler.h +++ b/dtrain/ksampler.h @@ -37,7 +37,7 @@ struct KSampler : public HypSampler h.f = samples[i].fmap; h.model = log(samples[i].model_score); h.rank = i; - h.score = scorer_->Score(h.w, *ref_); + h.score = scorer_->Score(h.w, *ref_, i); s_.push_back(h); } } diff --git a/dtrain/score.cc b/dtrain/score.cc index 93c4e80b..f5e920a0 100644 --- a/dtrain/score.cc +++ b/dtrain/score.cc @@ -28,7 +28,8 @@ BleuScorer::Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref } score_t -BleuScorer::Score(vector& hyp, vector& ref) +BleuScorer::Score(vector& hyp, vector& ref, + const unsigned rank) { unsigned hyp_len = hyp.size(), ref_len = ref.size(); if (hyp_len == 0 || ref_len == 0) return 0; @@ -47,7 +48,8 @@ BleuScorer::Score(vector& hyp, vector& ref) * NOTE: 0 iff no 1gram match */ score_t -StupidBleuScorer::Score(vector& hyp, vector& ref) +StupidBleuScorer::Score(vector& hyp, vector& ref, + const unsigned rank) { unsigned hyp_len = hyp.size(), ref_len = ref.size(); if (hyp_len == 0 || ref_len == 0) return 0; @@ -72,7 +74,8 @@ StupidBleuScorer::Score(vector& hyp, vector& ref) * NOTE: max is 0.9375 */ score_t -SmoothBleuScorer::Score(vector& hyp, vector& ref) +SmoothBleuScorer::Score(vector& hyp, vector& ref, + const unsigned rank) { unsigned hyp_len = hyp.size(), ref_len = ref.size(); if (hyp_len == 0 || ref_len == 0) return 0; @@ -87,7 +90,6 @@ SmoothBleuScorer::Score(vector& hyp, vector& ref) return brevity_penaly(hyp_len, ref_len) * sum; } -// FIXME /* * approx. bleu * @@ -95,38 +97,28 @@ SmoothBleuScorer::Score(vector& hyp, vector& ref) * and Structural Translation Features" * (Chiang et al. '08) */ -/*void -ApproxBleuScorer::Prep(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len) -{ - glob_onebest_counts += counts; - glob_hyp_len += hyp_len; - glob_ref_len += ref_len; -} - -void -ApproxBleuScorer::Reset() -{ - glob_onebest_counts.Zero(); - glob_hyp_len = 0; - glob_ref_len = 0; -} - score_t -ApproxBleuScorer::Score(ScoredHyp& hyp, vector& ref_ids, unsigned id) +ApproxBleuScorer::Score(vector& hyp, vector& ref, + const unsigned rank) { - NgramCounts counts = make_ngram_counts(hyp.w, ref_ids, N_); - if (id == 0) reset(); - unsigned hyp_len = 0, ref_len = 0; - if (hyp.rank == 0) { // 'context of 1best translations' - scorer->prep(counts, hyp.w.size(), ref_ids.size()); - counts.reset(); + unsigned hyp_len = hyp.size(), ref_len = ref.size(); + if (hyp_len == 0 || ref_len == 0) return 0; + NgramCounts counts = make_ngram_counts(hyp, ref, N_); + NgramCounts tmp(N_); + if (rank == 0) { // 'context of 1best translations' + glob_onebest_counts += counts; + glob_hyp_len += hyp_len; + glob_ref_len += ref_len; + hyp_len = glob_hyp_len; + ref_len = glob_ref_len; + tmp = glob_onebest_counts; } else { - hyp_len = hyp.w.size(); - ref_len = ref_ids.size(); + hyp_len = hyp.size(); + ref_len = ref.size(); + tmp = glob_onebest_counts + counts; } - return 0.9 * BleuScorer::Bleu(glob_onebest_counts + counts, - glob_hyp_len + hyp_len, glob_ref_len + ref_len); -}*/ + return 0.9 * Bleu(tmp, hyp_len, ref_len); +} } // namespace diff --git a/dtrain/score.h b/dtrain/score.h index 9af56ef9..85cd0317 100644 --- a/dtrain/score.h +++ b/dtrain/score.h @@ -17,7 +17,7 @@ struct NgramCounts NgramCounts(const unsigned N) : N_(N) { Zero(); } - void + inline void operator+=(const NgramCounts& rhs) { assert(N_ == rhs.N_); @@ -27,7 +27,7 @@ struct NgramCounts } } - const NgramCounts + inline const NgramCounts operator+(const NgramCounts &other) const { NgramCounts result = *this; @@ -35,8 +35,8 @@ struct NgramCounts return result; } - void - Add(unsigned count, unsigned ref_count, unsigned i) + inline void + Add(const unsigned count, const unsigned ref_count, const unsigned i) { assert(i < N_); if (count > ref_count) { @@ -47,7 +47,7 @@ struct NgramCounts sum[i] += count; } - void + inline void Zero() { unsigned i; @@ -57,7 +57,7 @@ struct NgramCounts } } - void + inline void Print() { for (unsigned i = 0; i < N_; i++) { @@ -106,38 +106,36 @@ make_ngram_counts(const vector& hyp, const vector& ref, const un struct BleuScorer : public LocalScorer { score_t Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len); - score_t Score(vector& hyp, vector& ref_ids); + score_t Score(vector& hyp, vector& ref, const unsigned rank); }; struct StupidBleuScorer : public LocalScorer { - score_t Score(vector& hyp, vector& ref); + score_t Score(vector& hyp, vector& ref, const unsigned rank); }; struct SmoothBleuScorer : public LocalScorer { - score_t Score(vector& hyp, vector& ref); + score_t Score(vector& hyp, vector& ref, const unsigned rank); }; -// FIXME -/*struct ApproxBleuScorer : public LocalScorer +struct ApproxBleuScorer : public BleuScorer { - bool prepped; - - NgramCounts* glob_onebest_counts; + NgramCounts glob_onebest_counts; unsigned glob_hyp_len, glob_ref_len; - void Prep(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len); - void Reset(); - score_t Score(ScoredHyp& hyp, vector& ref_ids, unsigned id); - - ApproxBleuScorer() + ApproxBleuScorer(unsigned N) : glob_onebest_counts(NgramCounts(N)) { + glob_hyp_len = glob_ref_len = 0; + } + + inline void Reset() { glob_onebest_counts.Zero(); - glob_hyp_len = 0; - glob_ref_len = 0; + glob_hyp_len = glob_ref_len = 0; } -};*/ + + score_t Score(vector& hyp, vector& ref, const unsigned rank); +}; } // namespace diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini index 3e5c2cd1..1e841824 100644 --- a/dtrain/test/example/dtrain.ini +++ b/dtrain/test/example/dtrain.ini @@ -2,7 +2,6 @@ decoder_config=test/example/cdec.ini k=100 N=3 gamma=0 -#gamma=0.00001 epochs=4 input=test/example/nc-1k-tabs.gz scorer=stupid_bleu -- cgit v1.2.3