diff options
Diffstat (limited to 'training/dtrain/score.h')
-rw-r--r-- | training/dtrain/score.h | 137 |
1 files changed, 67 insertions, 70 deletions
diff --git a/training/dtrain/score.h b/training/dtrain/score.h index c727dd30..d51aef82 100644 --- a/training/dtrain/score.h +++ b/training/dtrain/score.h @@ -8,17 +8,17 @@ namespace dtrain struct NgramCounts { - unsigned N_; - map<unsigned, score_t> clipped_; - map<unsigned, score_t> sum_; + size_t N_; + map<size_t, weight_t> clipped_; + map<size_t, weight_t> sum_; - NgramCounts(const unsigned N) : N_(N) { Zero(); } + NgramCounts(const size_t N) : N_(N) { Zero(); } inline void operator+=(const NgramCounts& rhs) { if (rhs.N_ > N_) Resize(rhs.N_); - for (unsigned i = 0; i < N_; i++) { + for (size_t i = 0; i < N_; i++) { this->clipped_[i] += rhs.clipped_.find(i)->second; this->sum_[i] += rhs.sum_.find(i)->second; } @@ -34,16 +34,16 @@ struct NgramCounts } inline void - operator*=(const score_t rhs) + operator*=(const weight_t rhs) { - for (unsigned i = 0; i < N_; i++) { + for (size_t i = 0; i < N_; i++) { this->clipped_[i] *= rhs; this->sum_[i] *= rhs; } } inline void - Add(const unsigned count, const unsigned ref_count, const unsigned i) + Add(const size_t count, const size_t ref_count, const size_t i) { assert(i < N_); if (count > ref_count) { @@ -57,40 +57,31 @@ struct NgramCounts inline void Zero() { - for (unsigned i = 0; i < N_; i++) { + for (size_t i = 0; i < N_; i++) { clipped_[i] = 0.; sum_[i] = 0.; } } inline void - One() + Print(ostream& os=cerr) { - for (unsigned i = 0; i < N_; i++) { - clipped_[i] = 1.; - sum_[i] = 1.; + for (size_t i = 0; i < N_; i++) { + os << i+1 << "grams (clipped):\t" << clipped_[i] << endl; + os << i+1 << "grams:\t\t\t" << sum_[i] << endl; } } - inline void - Print() - { - for (unsigned i = 0; i < N_; i++) { - cout << i+1 << "grams (clipped):\t" << clipped_[i] << endl; - cout << i+1 << "grams:\t\t\t" << sum_[i] << endl; - } - } - - inline void Resize(unsigned N) + inline void Resize(size_t N) { if (N == N_) return; else if (N > N_) { - for (unsigned i = N_; i < N; i++) { + for (size_t i = N_; i < N; i++) { clipped_[i] = 0.; sum_[i] = 0.; } } else { // N < N_ - for (unsigned i = N_-1; i > N-1; i--) { + for (size_t i = N_-1; i > N-1; i--) { clipped_.erase(i); sum_.erase(i); } @@ -99,16 +90,16 @@ struct NgramCounts } }; -typedef map<vector<WordID>, unsigned> Ngrams; +typedef map<vector<WordID>, size_t> Ngrams; inline Ngrams -MakeNgrams(const vector<WordID>& s, const unsigned N) +MakeNgrams(const vector<WordID>& s, const size_t N) { Ngrams ngrams; vector<WordID> ng; for (size_t i = 0; i < s.size(); i++) { ng.clear(); - for (unsigned j = i; j < min(i+N, s.size()); j++) { + for (size_t j = i; j < min(i+N, s.size()); j++) { ng.push_back(s[j]); ngrams[ng]++; } @@ -118,24 +109,21 @@ MakeNgrams(const vector<WordID>& s, const unsigned N) } inline NgramCounts -MakeNgramCounts(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned N) +MakeNgramCounts(const vector<WordID>& hyp, + const vector<Ngrams>& ref, + const size_t N) { Ngrams hyp_ngrams = MakeNgrams(hyp, N); - vector<Ngrams> refs_ngrams; - for (auto r: refs) { - Ngrams r_ng = MakeNgrams(r, N); - refs_ngrams.push_back(r_ng); - } NgramCounts counts(N); Ngrams::iterator it, ti; for (it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++) { - unsigned max_ref_count = 0; - for (auto ref_ngrams: refs_ngrams) { - ti = ref_ngrams.find(it->first); - if (ti != ref_ngrams.end()) + size_t max_ref_count = 0; + for (auto r: ref) { + ti = r.find(it->first); + if (ti != r.end()) max_ref_count = max(max_ref_count, ti->second); } - counts.Add(it->second, min(it->second, max_ref_count), it->first.size() - 1); + counts.Add(it->second, min(it->second, max_ref_count), it->first.size()-1); } return counts; @@ -150,56 +138,65 @@ MakeNgramCounts(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, * [simply add 1 to reference length for calculation of BP] * */ - struct PerSentenceBleuScorer { - const unsigned N_; - vector<score_t> w_; + const size_t N_; + vector<weight_t> w_; - PerSentenceBleuScorer(unsigned n) : N_(n) + PerSentenceBleuScorer(size_t n) : N_(n) { - for (auto i = 1; i <= N_; i++) + for (size_t i = 1; i <= N_; i++) w_.push_back(1.0/N_); } - inline score_t - BrevityPenalty(const unsigned hyp_len, const unsigned ref_len) + inline weight_t + BrevityPenalty(const size_t hl, const size_t rl) { - if (hyp_len > ref_len) return 1; - return exp(1 - (score_t)ref_len/hyp_len); + if (hl > rl) + return 1; + + return exp(1 - (weight_t)rl/hl); } - score_t - Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs) + weight_t + Score(const vector<WordID>& hyp, + const vector<Ngrams>& ref_ngs, + const vector<size_t>& ref_ls) { - unsigned hyp_len = hyp.size(), ref_len = 0; + size_t hl = hyp.size(), rl = 0; + if (hl == 0) return 0.; // best match reference length - if (refs.size() == 1) { - ref_len = refs[0].size(); + if (ref_ls.size() == 1) { + rl = ref_ls.front(); } else { - unsigned i = 0, best_idx = 0; - unsigned best = std::numeric_limits<unsigned>::max(); - for (auto r: refs) { - unsigned d = abs(hyp_len-r.size()); - if (best > d) best_idx = i; + size_t i = 0, best_idx = 0; + size_t best = numeric_limits<size_t>::max(); + for (auto l: ref_ls) { + size_t d = abs(hl-l); + if (d < best) { + best_idx = i; + best = d; + } + i += 1; } - ref_len = refs[best_idx].size(); + rl = ref_ls[best_idx]; } - if (hyp_len == 0 || ref_len == 0) return 0.; - NgramCounts counts = MakeNgramCounts(hyp, refs, N_); - unsigned M = N_; - vector<score_t> v = w_; - if (ref_len < N_) { - M = ref_len; - for (unsigned i = 0; i < M; i++) v[i] = 1/((score_t)M); + if (rl == 0) return 0.; + NgramCounts counts = MakeNgramCounts(hyp, ref_ngs, N_); + size_t M = N_; + vector<weight_t> v = w_; + if (rl < N_) { + M = rl; + for (size_t i = 0; i < M; i++) v[i] = 1/((weight_t)M); } - score_t sum = 0, add = 0; - for (unsigned i = 0; i < M; i++) { + weight_t sum = 0, add = 0; + for (size_t i = 0; i < M; i++) { if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.; if (i == 1) add = 1; - sum += v[i] * log(((score_t)counts.clipped_[i] + add)/((counts.sum_[i] + add))); + sum += v[i] * log(((weight_t)counts.clipped_[i] + add)/((counts.sum_[i] + add))); } - return BrevityPenalty(hyp_len, ref_len+1) * exp(sum); + + return BrevityPenalty(hl, rl+1) * exp(sum); } }; |