-rw-r--r-- | training/dtrain/score.h | 308
1 file changed, 188 insertions(+), 120 deletions(-)
diff --git a/training/dtrain/score.h b/training/dtrain/score.h
index ca3da39b..a9c60b64 100644
--- a/training/dtrain/score.h
+++ b/training/dtrain/score.h
@@ -9,66 +9,77 @@ namespace dtrain
 struct NgramCounts
 {
   size_t N_;
-  map<size_t, weight_t> clipped_;
-  map<size_t, weight_t> sum_;
+  map<size_t, weight_t> clipped;
+  map<size_t, weight_t> sum;
 
   NgramCounts() {}
 
-  NgramCounts(const size_t N) : N_(N) { Zero(); }
+  NgramCounts(const size_t N) : N_(N) { zero(); }
 
   inline void
   operator+=(const NgramCounts& rhs)
   {
-    if (rhs.N_ > N_) Resize(rhs.N_);
+    if (rhs.N_ > N_) resize(rhs.N_);
     for (size_t i = 0; i < N_; i++) {
-      this->clipped_[i] += rhs.clipped_.find(i)->second;
-      this->sum_[i] += rhs.sum_.find(i)->second;
+      this->clipped[i] += rhs.clipped.find(i)->second;
+      this->sum[i] += rhs.sum.find(i)->second;
     }
   }
 
   inline void
   operator*=(const weight_t rhs)
   {
-    for (unsigned i = 0; i < N_; i++) {
-      this->clipped_[i] *= rhs;
-      this->sum_[i] *= rhs;
+    for (size_t i=0; i<N_; i++) {
+      this->clipped[i] *= rhs;
+      this->sum[i] *= rhs;
     }
   }
 
   inline void
-  Add(const size_t count, const size_t ref_count, const size_t i)
+  add(const size_t count,
+      const size_t count_ref,
+      const size_t i)
   {
     assert(i < N_);
-    if (count > ref_count) {
-      clipped_[i] += ref_count;
+    if (count > count_ref) {
+      clipped[i] += count_ref;
     } else {
-      clipped_[i] += count;
+      clipped[i] += count;
     }
-    sum_[i] += count;
+    sum[i] += count;
   }
 
   inline void
-  Zero()
+  zero()
   {
-    for (size_t i = 0; i < N_; i++) {
-      clipped_[i] = 0.;
-      sum_[i] = 0.;
+    for (size_t i=0; i<N_; i++) {
+      clipped[i] = 0.;
+      sum[i] = 0.;
     }
   }
 
   inline void
-  Resize(size_t N)
+  one()
+  {
+    for (size_t i=0; i<N_; i++) {
+      clipped[i] = 1.;
+      sum[i] = 1.;
+    }
+  }
+
+  inline void
+  resize(size_t N)
   {
     if (N == N_) return;
     else if (N > N_) {
       for (size_t i = N_; i < N; i++) {
-        clipped_[i] = 0.;
-        sum_[i] = 0.;
+        clipped[i] = 0.;
+        sum[i] = 0.;
       }
     } else { // N < N_
       for (size_t i = N_-1; i > N-1; i--) {
-        clipped_.erase(i);
-        sum_.erase(i);
+        clipped.erase(i);
+        sum.erase(i);
      }
     }
     N_ = N;
@@ -78,37 +89,38 @@ struct NgramCounts
 typedef map<vector<WordID>, size_t> Ngrams;
 
 inline Ngrams
-MakeNgrams(const vector<WordID>& s, const size_t N)
+ngrams(const vector<WordID>& vw,
+       const size_t N)
 {
-  Ngrams ngrams;
+  Ngrams r;
   vector<WordID> ng;
-  for (size_t i = 0; i < s.size(); i++) {
+  for (size_t i=0; i<vw.size(); i++) {
     ng.clear();
-    for (size_t j = i; j < min(i+N, s.size()); j++) {
-      ng.push_back(s[j]);
-      ngrams[ng]++;
+    for (size_t j=i; j<min(i+N, vw.size()); j++) {
+      ng.push_back(vw[j]);
+      r[ng]++;
     }
   }
 
-  return ngrams;
+  return r;
 }
 
 inline NgramCounts
-MakeNgramCounts(const vector<WordID>& hyp,
-                const vector<Ngrams>& ref,
-                const size_t N)
+ngram_counts(const vector<WordID>& hyp,
+             const vector<Ngrams>& ngrams_ref,
+             const size_t N)
 {
-  Ngrams hyp_ngrams = MakeNgrams(hyp, N);
+  Ngrams ngrams_hyp = ngrams(hyp, N);
   NgramCounts counts(N);
   Ngrams::iterator it, ti;
-  for (it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++) {
+  for (it = ngrams_hyp.begin(); it != ngrams_hyp.end(); it++) {
     size_t max_ref_count = 0;
-    for (auto r: ref) {
+    for (auto r: ngrams_ref) {
       ti = r.find(it->first);
       if (ti != r.end())
         max_ref_count = max(max_ref_count, ti->second);
     }
-    counts.Add(it->second, min(it->second, max_ref_count), it->first.size()-1);
+    counts.add(it->second, min(it->second, max_ref_count), it->first.size()-1);
   }
 
   return counts;
@@ -128,9 +140,9 @@ class Scorer
   }
 
   inline bool
-  Init(const vector<WordID>& hyp,
-       const vector<Ngrams>& ref_ngs,
-       const vector<size_t>& ref_ls,
+  init(const vector<WordID>& hyp,
+       const vector<Ngrams>& reference_ngrams,
+       const vector<size_t>& reference_lengths,
        size_t& hl,
        size_t& rl,
        size_t& M,
@@ -138,10 +150,12 @@ class Scorer
        NgramCounts& counts)
   {
     hl = hyp.size();
-    if (hl == 0) return false;
-    rl = BestMatchLength(hl, ref_ls);
-    if (rl == 0) return false;
-    counts = MakeNgramCounts(hyp, ref_ngs, N_);
+    if (hl == 0)
+      return false;
+    rl = best_match_length(hl, reference_lengths);
+    if (rl == 0)
+      return false;
+    counts = ngram_counts(hyp, reference_ngrams, N_);
     if (rl < N_) {
       M = rl;
       for (size_t i = 0; i < M; i++) v.push_back(1/((weight_t)M));
@@ -154,7 +168,8 @@ class Scorer
   }
 
   inline weight_t
-  BrevityPenalty(const size_t hl, const size_t rl)
+  brevity_penalty(const size_t hl,
+                  const size_t rl)
   {
     if (hl > rl)
       return 1;
@@ -163,16 +178,16 @@ class Scorer
   }
 
   inline size_t
-  BestMatchLength(const size_t hl,
-                  const vector<size_t>& ref_ls)
+  best_match_length(const size_t hl,
+                    const vector<size_t>& reference_lengths)
   {
     size_t m;
-    if (ref_ls.size() == 1) {
-      m = ref_ls.front();
+    if (reference_lengths.size() == 1) {
+      m = reference_lengths.front();
     } else {
       size_t i = 0, best_idx = 0;
       size_t best = numeric_limits<size_t>::max();
-      for (auto l: ref_ls) {
+      for (auto l: reference_lengths) {
         size_t d = abs(hl-l);
         if (d < best) {
           best_idx = i;
@@ -180,61 +195,63 @@ class Scorer
         }
         i += 1;
       }
-      m = ref_ls[best_idx];
+      m = reference_lengths[best_idx];
     }
 
     return m;
   }
 
   virtual weight_t
-  Score(const vector<WordID>&,
+  score(const vector<WordID>&,
         const vector<Ngrams>&,
         const vector<size_t>&) = 0;
 
   void
-  UpdateContext(const vector<WordID>& /*hyp*/,
-                const vector<Ngrams>& /*ref_ngs*/,
-                const vector<size_t>& /*ref_ls*/,
-                weight_t /*decay*/) {}
+  update_context(const vector<WordID>& /*hyp*/,
+                 const vector<Ngrams>& /*reference_ngrams*/,
+                 const vector<size_t>& /*reference_lengths*/,
+                 weight_t /*decay*/) {}
 };
 
 /*
- * 'fixed' per-sentence BLEU
- * simply add 1 to reference length for calculation of BP
+ * ['fixed'] per-sentence BLEU
+ * simply add 'fix' (1) to reference length for calculation of BP
+ * to avoid short translations
  *
  * as in "Optimizing for Sentence-Level BLEU+1
  * Yields Short Translations"
  * (Nakov et al. '12)
  *
 */
-class PerSentenceBleuScorer : public Scorer
+class NakovBleuScorer : public Scorer
 {
+  weight_t fix;
+
   public:
-    PerSentenceBleuScorer(size_t n) : Scorer(n) {}
+    NakovBleuScorer(size_t n, weight_t fix) : Scorer(n), fix(fix) {}
 
     weight_t
-    Score(const vector<WordID>& hyp,
-          const vector<Ngrams>& ref_ngs,
-          const vector<size_t>& ref_ls)
+    score(const vector<WordID>& hyp,
+          const vector<Ngrams>& reference_ngrams,
+          const vector<size_t>& reference_lengths)
     {
       size_t hl, rl, M;
       vector<weight_t> v;
       NgramCounts counts;
-      if (!Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts))
+      if (!init(hyp, reference_ngrams, reference_lengths, hl, rl, M, v, counts))
        return 0.;
      weight_t sum=0, add=0;
-      for (size_t i = 0; i < M; i++) {
-        if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.;
+      for (size_t i=0; i<M; i++) {
+        if (i == 0 && (counts.sum[i]==0 || counts.clipped[i]==0)) return 0.;
        if (i > 0) add = 1;
-        sum += v[i] * log(((weight_t)counts.clipped_[i] + add)
-                          / ((counts.sum_[i] + add)));
+        sum += v[i] * log(((weight_t)counts.clipped[i] + add)
+                          / ((counts.sum[i] + add)));
      }
 
-      return BrevityPenalty(hl, rl+1) * exp(sum);
+      return brevity_penalty(hl, rl+1) * exp(sum);
    }
 };
 
-
 /*
  * BLEU
  * 0 if for one n \in {1..N} count is 0
@@ -244,29 +261,28 @@ class PerSentenceBleuScorer : public Scorer
  * (Papineni et al. '02)
  *
 */
-
-class BleuScorer : public Scorer
+class PapineniBleuScorer : public Scorer
 {
   public:
-    BleuScorer(size_t n) : Scorer(n) {}
+    PapineniBleuScorer(size_t n) : Scorer(n) {}
 
     weight_t
-    Score(const vector<WordID>& hyp,
-          const vector<Ngrams>& ref_ngs,
-          const vector<size_t>& ref_ls)
+    score(const vector<WordID>& hyp,
+          const vector<Ngrams>& reference_ngrams,
+          const vector<size_t>& reference_lengths)
    {
      size_t hl, rl, M;
      vector<weight_t> v;
      NgramCounts counts;
-      if (!Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts))
+      if (!init(hyp, reference_ngrams, reference_lengths, hl, rl, M, v, counts))
        return 0.;
      weight_t sum = 0;
-      for (size_t i = 0; i < M; i++) {
-        if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) return 0.;
-        sum += v[i] * log((weight_t)counts.clipped_[i]/counts.sum_[i]);
+      for (size_t i=0; i<M; i++) {
+        if (counts.sum[i] == 0 || counts.clipped[i] == 0) return 0.;
+        sum += v[i] * log((weight_t)counts.clipped[i]/counts.sum[i]);
      }
 
-      return BrevityPenalty(hl, rl) * exp(sum);
+      return brevity_penalty(hl, rl) * exp(sum);
    }
 };
@@ -280,29 +296,30 @@ class BleuScorer : public Scorer
  * (Lin & Och '04)
  *
 */
-class OriginalPerSentenceBleuScorer : public Scorer
+class LinBleuScorer : public Scorer
 {
   public:
-    OriginalPerSentenceBleuScorer(size_t n) : Scorer(n) {}
+    LinBleuScorer(size_t n) : Scorer(n) {}
 
     weight_t
-    Score(const vector<WordID>& hyp,
-          const vector<Ngrams>& ref_ngs,
-          const vector<size_t>& ref_ls)
+    score(const vector<WordID>& hyp,
+          const vector<Ngrams>& reference_ngrams,
+          const vector<size_t>& reference_lengths)
    {
      size_t hl, rl, M;
      vector<weight_t> v;
      NgramCounts counts;
-      if (!Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts))
+      if (!init(hyp, reference_ngrams, reference_lengths, hl, rl, M, v, counts))
        return 0.;
      weight_t sum=0, add=0;
-      for (size_t i = 0; i < M; i++) {
-        if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.;
+      for (size_t i=0; i<M; i++) {
+        if (i == 0 && (counts.sum[i]==0 || counts.clipped[i]==0)) return 0.;
        if (i == 1) add = 1;
-        sum += v[i] * log(((weight_t)counts.clipped_[i] + add)/((counts.sum_[i] + add)));
+        sum += v[i] * log(((weight_t)counts.clipped[i] + add)
+                          / ((counts.sum[i] + add)));
      }
 
-      return BrevityPenalty(hl, rl) * exp(sum);
+      return brevity_penalty(hl, rl) * exp(sum);
    }
 };
@@ -315,30 +332,30 @@ class OriginalPerSentenceBleuScorer : public Scorer
  * (Liang et al. '06)
  *
 */
-class SmoothPerSentenceBleuScorer : public Scorer
+class LiangBleuScorer : public Scorer
 {
   public:
-    SmoothPerSentenceBleuScorer(size_t n) : Scorer(n) {}
+    LiangBleuScorer(size_t n) : Scorer(n) {}
 
     weight_t
-    Score(const vector<WordID>& hyp,
-          const vector<Ngrams>& ref_ngs,
-          const vector<size_t>& ref_ls)
+    score(const vector<WordID>& hyp,
+          const vector<Ngrams>& reference_ngrams,
+          const vector<size_t>& reference_lengths)
    {
-      size_t hl=hyp.size(), rl=BestMatchLength(hl, ref_ls);
+      size_t hl=hyp.size(), rl=best_match_length(hl, reference_lengths);
      if (hl == 0 || rl == 0) return 0.;
-      NgramCounts counts = MakeNgramCounts(hyp, ref_ngs, N_);
+      NgramCounts counts = ngram_counts(hyp, reference_ngrams, N_);
      size_t M = N_;
      if (rl < N_) M = rl;
      weight_t sum = 0.;
      vector<weight_t> i_bleu;
-      for (size_t i=0; i < M; i++)
+      for (size_t i=0; i<M; i++)
        i_bleu.push_back(0.);
-      for (size_t i=0; i < M; i++) {
-        if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) {
+      for (size_t i=0; i<M; i++) {
+        if (counts.sum[i]==0 || counts.clipped[i]==0) {
          break;
        } else {
-          weight_t i_score = log((weight_t)counts.clipped_[i]/counts.sum_[i]);
+          weight_t i_score = log((weight_t)counts.clipped[i]/counts.sum[i]);
          for (size_t j=i; j < M; j++) {
            i_bleu[j] += (1/((weight_t)j+1)) * i_score;
          }
@@ -346,23 +363,24 @@ class SmoothPerSentenceBleuScorer : public Scorer
        sum += exp(i_bleu[i])/pow(2.0, (double)(N_-i+2));
      }
 
-      return BrevityPenalty(hl, hl) * sum;
+      return brevity_penalty(hl, hl) * sum;
    }
 };
 
 /*
  * approx. bleu
  * Needs some more code in dtrain.cc .
- * We do not scaling by source lengths, as hypotheses are compared only
- * within an kbest list, not globally.
+ * We do not scale by source length, as hypotheses are compared only
+ * within single k-best lists, not globally (as in batch algorithms).
+ * TODO: reset after one iteration?
+ * TODO: maybe scale by source length?
  *
  * as in "Online Large-Margin Training of Syntactic
  * and Structural Translation Features"
  * (Chiang et al. '08)
  *
-
 */
-class ApproxBleuScorer : public Scorer
+class ChiangBleuScorer : public Scorer
 {
   private:
     NgramCounts context;
@@ -370,39 +388,39 @@ class ApproxBleuScorer : public Scorer
     weight_t ref_sz_sum;
 
   public:
-    ApproxBleuScorer(size_t n) :
+    ChiangBleuScorer(size_t n) :
      Scorer(n), context(n), hyp_sz_sum(0), ref_sz_sum(0) {}
 
     weight_t
-    Score(const vector<WordID>& hyp,
-          const vector<Ngrams>& ref_ngs,
-          const vector<size_t>& ref_ls)
+    score(const vector<WordID>& hyp,
+          const vector<Ngrams>& reference_ngrams,
+          const vector<size_t>& reference_lengths)
    {
      size_t hl, rl, M;
      vector<weight_t> v;
      NgramCounts counts;
-      if (!Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts))
+      if (!init(hyp, reference_ngrams, reference_lengths, hl, rl, M, v, counts))
        return 0.;
      counts += context;
      weight_t sum = 0;
      for (size_t i = 0; i < M; i++) {
-        if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) return 0.;
-        sum += v[i] * log((weight_t)counts.clipped_[i]/counts.sum_[i]);
+        if (counts.sum[i]==0 || counts.clipped[i]==0) return 0.;
+        sum += v[i] * log((weight_t)counts.clipped[i] / counts.sum[i]);
      }
 
-      return BrevityPenalty(hyp_sz_sum+hl, ref_sz_sum+rl) * exp(sum);
+      return brevity_penalty(hyp_sz_sum+hl, ref_sz_sum+rl) * exp(sum);
    }
 
     void
-    UpdateContext(const vector<WordID>& hyp,
-                  const vector<Ngrams>& ref_ngs,
-                  const vector<size_t>& ref_ls,
-                  weight_t decay=0.9)
+    update_context(const vector<WordID>& hyp,
+                   const vector<Ngrams>& reference_ngrams,
+                   const vector<size_t>& reference_lengths,
+                   weight_t decay=0.9)
    {
      size_t hl, rl, M;
      vector<weight_t> v;
      NgramCounts counts;
-      Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts);
+      init(hyp, reference_ngrams, reference_lengths, hl, rl, M, v, counts);
 
      context += counts;
      context *= decay;
@@ -413,6 +431,56 @@ class ApproxBleuScorer : public Scorer
    }
 };
 
+/*
+ * 'sum' bleu
+ *
+ * Merely sum up Ngram precisions
+ */
+class SumBleuScorer : public Scorer
+{
+  public:
+    SumBleuScorer(size_t n) : Scorer(n) {}
+
+    weight_t
+    score(const vector<WordID>& hyp,
+          const vector<Ngrams>& reference_ngrams,
+          const vector<size_t>& reference_lengths)
+    {
+      size_t hl, rl, M;
+      vector<weight_t> v;
+      NgramCounts counts;
+      if (!init(hyp, reference_ngrams, reference_lengths, hl, rl, M, v, counts))
+        return 0.;
+      weight_t sum = 0.;
+      size_t j = 1;
+      for (size_t i=0; i<M; i++) {
+        if (counts.sum[i]==0 || counts.clipped[i]==0) break;
+        sum += ((weight_t)counts.clipped[i]/counts.sum[i])
+               / pow(2.0, (weight_t) (N_-j+1));
+        //sum += exp(((score_t)counts.clipped[i]/counts.sum[i]))
+        //       / pow(2.0, (weight_t) (N_-j+1));
+        //sum += exp(v[i] * log(((score_t)counts.clipped[i]/counts.sum[i])))
+        //       / pow(2.0, (weight_t) (N_-j+1));
+        j++;
+      }
+
+      return brevity_penalty(hl, rl) * sum;
+    }
+};
+
+/*
+ * Linear (Corpus) Bleu
+ * TODO
+ *
+ * as in "Lattice Minimum Bayes-Risk Decoding
+ * for Statistical Machine Translation"
+ * (Tromble et al. '08)
+ * or "Hope and fear for discriminative training of
+ * statistical translation models"
+ * (Chiang '12)
+ *
+ */
+
 } // namespace
 
 #endif
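
For reference, the per-sentence BLEU+1 that NakovBleuScorer computes above can be sketched in isolation. The snippet below is a minimal standalone illustration, not code from this commit: it assumes plain std::string tokens and double weights in place of dtrain's WordID and weight_t, a single reference instead of a vector of reference n-gram maps, and a hypothetical helper name sentence_bleu_plus_one. It mirrors the clipped n-gram precisions, the add-one smoothing for orders above unigram, and the brevity penalty taken against reference length + 1.

// Standalone sketch of per-sentence BLEU+1 (Nakov et al. '12), simplified types.
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Token = std::string;
using Ngram = std::vector<Token>;

// Count all n-grams of s up to order N, as ngrams() does in score.h.
std::map<Ngram, size_t> count_ngrams(const std::vector<Token>& s, size_t N)
{
  std::map<Ngram, size_t> counts;
  for (size_t i = 0; i < s.size(); i++) {
    Ngram ng;
    for (size_t j = i; j < std::min(i + N, s.size()); j++) {
      ng.push_back(s[j]);
      counts[ng]++;
    }
  }
  return counts;
}

// Hypothetical helper: smoothed per-sentence BLEU against a single reference.
double sentence_bleu_plus_one(const std::vector<Token>& hyp,
                              const std::vector<Token>& ref,
                              size_t N = 4)
{
  if (hyp.empty() || ref.empty()) return 0.;
  const size_t M = std::min(N, ref.size()); // cap order by reference length
  auto hyp_ngrams = count_ngrams(hyp, M);
  auto ref_ngrams = count_ngrams(ref, M);

  // clipped[n] and sum[n] per order n (0-based), as NgramCounts tracks them
  std::vector<double> clipped(M, 0.), sum(M, 0.);
  for (const auto& kv : hyp_ngrams) {
    const size_t n = kv.first.size() - 1;
    const auto it = ref_ngrams.find(kv.first);
    const size_t ref_count = (it == ref_ngrams.end()) ? 0 : it->second;
    clipped[n] += std::min(kv.second, ref_count);
    sum[n] += kv.second;
  }

  double log_prec = 0.;
  for (size_t n = 0; n < M; n++) {
    if (n == 0 && (sum[0] == 0. || clipped[0] == 0.)) return 0.;
    const double add = (n > 0) ? 1. : 0.; // "+1" smoothing for n >= 2 only
    log_prec += (1. / M) * std::log((clipped[n] + add) / (sum[n] + add));
  }

  // brevity penalty against reference length + 1, as in NakovBleuScorer::score
  const double rl = ref.size() + 1., hl = hyp.size();
  const double bp = (hl > rl) ? 1. : std::exp(1. - rl / hl);
  return bp * std::exp(log_prec);
}

int main()
{
  std::vector<Token> hyp = {"the", "cat", "sat", "on", "the", "mat"};
  std::vector<Token> ref = {"the", "cat", "is", "on", "the", "mat"};
  std::cout << sentence_bleu_plus_one(hyp, ref) << "\n"; // roughly 0.41 here
  return 0;
}

Leaving the unigram precision unsmoothed while adding one to the higher-order counts, and padding the reference length in the brevity penalty, both push the metric away from rewarding very short hypotheses, which is the point of the Nakov et al. variant.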