diff options
Diffstat (limited to 'training/dtrain/score.h')
-rw-r--r-- | training/dtrain/score.h | 51 |
1 files changed, 22 insertions, 29 deletions
diff --git a/training/dtrain/score.h b/training/dtrain/score.h index d51aef82..06dbc5a4 100644 --- a/training/dtrain/score.h +++ b/training/dtrain/score.h @@ -34,15 +34,6 @@ struct NgramCounts } inline void - operator*=(const weight_t rhs) - { - for (size_t i = 0; i < N_; i++) { - this->clipped_[i] *= rhs; - this->sum_[i] *= rhs; - } - } - - inline void Add(const size_t count, const size_t ref_count, const size_t i) { assert(i < N_); @@ -64,15 +55,7 @@ struct NgramCounts } inline void - Print(ostream& os=cerr) - { - for (size_t i = 0; i < N_; i++) { - os << i+1 << "grams (clipped):\t" << clipped_[i] << endl; - os << i+1 << "grams:\t\t\t" << sum_[i] << endl; - } - } - - inline void Resize(size_t N) + Resize(size_t N) { if (N == N_) return; else if (N > N_) { @@ -158,16 +141,13 @@ struct PerSentenceBleuScorer return exp(1 - (weight_t)rl/hl); } - weight_t - Score(const vector<WordID>& hyp, - const vector<Ngrams>& ref_ngs, - const vector<size_t>& ref_ls) + inline size_t + BestMatchLength(const size_t hl, + const vector<size_t>& ref_ls) { - size_t hl = hyp.size(), rl = 0; - if (hl == 0) return 0.; - // best match reference length + size_t m; if (ref_ls.size() == 1) { - rl = ref_ls.front(); + m = ref_ls.front(); } else { size_t i = 0, best_idx = 0; size_t best = numeric_limits<size_t>::max(); @@ -179,8 +159,20 @@ struct PerSentenceBleuScorer } i += 1; } - rl = ref_ls[best_idx]; + m = ref_ls[best_idx]; } + + return m; + } + + weight_t + Score(const vector<WordID>& hyp, + const vector<Ngrams>& ref_ngs, + const vector<size_t>& ref_ls) + { + size_t hl = hyp.size(), rl = 0; + if (hl == 0) return 0.; + rl = BestMatchLength(hl, ref_ls); if (rl == 0) return 0.; NgramCounts counts = MakeNgramCounts(hyp, ref_ngs, N_); size_t M = N_; @@ -192,8 +184,9 @@ struct PerSentenceBleuScorer weight_t sum = 0, add = 0; for (size_t i = 0; i < M; i++) { if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.; - if (i == 1) add = 1; - sum += v[i] * log(((weight_t)counts.clipped_[i] + add)/((counts.sum_[i] + add))); + if (i > 0) add = 1; + sum += v[i] * log(((weight_t)counts.clipped_[i] + add) + / ((counts.sum_[i] + add))); } return BrevityPenalty(hl, rl+1) * exp(sum); |