diff options
Diffstat (limited to 'dtrain/score.h')
-rw-r--r-- | dtrain/score.h | 120 |
1 files changed, 100 insertions, 20 deletions
diff --git a/dtrain/score.h b/dtrain/score.h index 3e5d82a9..f87d708c 100644 --- a/dtrain/score.h +++ b/dtrain/score.h @@ -7,6 +7,8 @@ #include <cassert> #include <cmath> +#include "kbestget.h" + #include "wordid.h" // cdec using namespace std; @@ -15,15 +17,13 @@ namespace dtrain { -typedef double score_t; // float - struct NgramCounts { unsigned N_; map<unsigned, unsigned> clipped; map<unsigned, unsigned> sum; - NgramCounts(const unsigned N) : N_(N) { reset(); } + NgramCounts(const unsigned N) : N_(N) { Zero(); } void operator+=(const NgramCounts& rhs) @@ -44,20 +44,19 @@ struct NgramCounts } void - add(unsigned count, unsigned ref_count, unsigned i) + Add(unsigned count, unsigned ref_count, unsigned i) { assert(i < N_); if (count > ref_count) { clipped[i] += ref_count; - sum[i] += count; } else { clipped[i] += count; - sum[i] += count; } + sum[i] += count; } void - reset() + Zero() { unsigned i; for (i = 0; i < N_; i++) { @@ -67,7 +66,7 @@ struct NgramCounts } void - print() + Print() { for (unsigned i = 0; i < N_; i++) { cout << i+1 << "grams (clipped):\t" << clipped[i] << endl; @@ -78,18 +77,99 @@ struct NgramCounts typedef map<vector<WordID>, unsigned> Ngrams; -Ngrams make_ngrams(vector<WordID>& s, unsigned N); -NgramCounts make_ngram_counts(vector<WordID> hyp, vector<WordID> ref, unsigned N); - -score_t brevity_penaly(const unsigned hyp_len, const unsigned ref_len); -score_t bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len, const unsigned N, - vector<score_t> weights = vector<score_t>()); -score_t stupid_bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len, unsigned N, - vector<score_t> weights = vector<score_t>()); -score_t smooth_bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len, const unsigned N, - vector<score_t> weights = vector<score_t>()); -score_t approx_bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len, const unsigned N, - vector<score_t> weights = vector<score_t>()); +inline Ngrams +make_ngrams(const vector<WordID>& s, const unsigned N) +{ + Ngrams ngrams; + vector<WordID> ng; + for (size_t i = 0; i < s.size(); i++) { + ng.clear(); + for (unsigned j = i; j < min(i+N, s.size()); j++) { + ng.push_back(s[j]); + ngrams[ng]++; + } + } + return ngrams; +} + +inline NgramCounts +make_ngram_counts(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned N) +{ + Ngrams hyp_ngrams = make_ngrams(hyp, N); + Ngrams ref_ngrams = make_ngrams(ref, N); + NgramCounts counts(N); + Ngrams::iterator it; + Ngrams::iterator ti; + for (it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++) { + ti = ref_ngrams.find(it->first); + if (ti != ref_ngrams.end()) { + counts.Add(it->second, ti->second, it->first.size() - 1); + } else { + counts.Add(it->second, 0, it->first.size() - 1); + } + } + return counts; +} + +struct LocalScorer +{ + unsigned N_; + vector<score_t> w_; + + virtual score_t + Score(ScoredHyp& hyp, vector<WordID>& ref_ids, unsigned id)=0; + + void + Init(unsigned N, vector<score_t> weights) + { + assert(N > 0); + N_ = N; + if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); + else w_ = weights; + } + + score_t + brevity_penaly(const unsigned hyp_len, const unsigned ref_len) + { + if (hyp_len > ref_len) return 1; + return exp(1 - (score_t)ref_len/hyp_len); + } +}; + +struct BleuScorer : public LocalScorer +{ + score_t Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len); + score_t Score(ScoredHyp& hyp, vector<WordID>& ref_ids, unsigned id); +}; + +struct StupidBleuScorer : public LocalScorer +{ + score_t Score(ScoredHyp& hyp, vector<WordID>& ref_ids, unsigned id); +}; + +struct SmoothBleuScorer : public LocalScorer +{ + score_t Score(ScoredHyp& hyp, vector<WordID>& ref_ids, unsigned id); +}; + +// FIXME +/*struct ApproxBleuScorer : public LocalScorer +{ + NgramCounts glob_onebest_counts; + unsigned glob_hyp_len, glob_ref_len; + + void Prep(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len); + void Reset(); + score_t Score(ScoredHyp& hyp, vector<WordID>& ref_ids, unsigned id); + + ApproxBleuScorer() + { + glob_onebest_counts.Zero(); + glob_hyp_len = 0; + glob_ref_len = 0; + } +};*/ + } // namespace |