diff options
Diffstat (limited to 'training/dtrain/score.h')
-rw-r--r-- | training/dtrain/score.h | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/training/dtrain/score.h b/training/dtrain/score.h index 1cdd3fa9..7d88cb61 100644 --- a/training/dtrain/score.h +++ b/training/dtrain/score.h @@ -117,20 +117,25 @@ make_ngrams(const vector<WordID>& s, const unsigned N) } inline NgramCounts -make_ngram_counts(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned N) +make_ngram_counts(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned N) { Ngrams hyp_ngrams = make_ngrams(hyp, N); - Ngrams ref_ngrams = make_ngrams(ref, N); + vector<Ngrams> refs_ngrams; + for (auto r: refs) { + Ngrams r_ng = make_ngrams(r, N); + refs_ngrams.push_back(r_ng); + } NgramCounts counts(N); Ngrams::iterator it; Ngrams::iterator ti; for (it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++) { - ti = ref_ngrams.find(it->first); - if (ti != ref_ngrams.end()) { - counts.Add(it->second, ti->second, it->first.size() - 1); - } else { - counts.Add(it->second, 0, it->first.size() - 1); + unsigned max_ref_count = 0; + for (auto ref_ngrams: refs_ngrams) { + ti = ref_ngrams.find(it->first); + if (ti != ref_ngrams.end()) + max_ref_count = max(max_ref_count, ti->second); } + counts.Add(it->second, max_ref_count, it->first.size() - 1); } return counts; } @@ -138,43 +143,43 @@ make_ngram_counts(const vector<WordID>& hyp, const vector<WordID>& ref, const un struct BleuScorer : public LocalScorer { score_t Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len); - score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); + score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/); void Reset() {} }; struct StupidBleuScorer : public LocalScorer { - score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); + score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/); void Reset() {} }; struct FixedStupidBleuScorer : public LocalScorer { - score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); + score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/); void Reset() {} }; struct SmoothBleuScorer : public LocalScorer { - score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); + score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/); void Reset() {} }; struct SumBleuScorer : public LocalScorer { - score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); + score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/); void Reset() {} }; struct SumExpBleuScorer : public LocalScorer { - score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); + score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/); void Reset() {} }; struct SumWhateverBleuScorer : public LocalScorer { - score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); + score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/); void Reset() {}; }; @@ -194,7 +199,7 @@ struct ApproxBleuScorer : public BleuScorer glob_hyp_len_ = glob_ref_len_ = glob_src_len_ = 0.; } - score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned rank, const unsigned src_len); + score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned rank, const unsigned src_len); }; struct LinearBleuScorer : public BleuScorer @@ -207,7 +212,7 @@ struct LinearBleuScorer : public BleuScorer onebest_counts_.One(); } - score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned rank, const unsigned /*src_len*/); + score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned rank, const unsigned /*src_len*/); inline void Reset() { onebest_len_ = 1; |