diff options
Diffstat (limited to 'training/dtrain/score.h')
-rw-r--r-- | training/dtrain/score.h | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/training/dtrain/score.h b/training/dtrain/score.h new file mode 100644 index 00000000..bddaa071 --- /dev/null +++ b/training/dtrain/score.h @@ -0,0 +1,217 @@ +#ifndef _DTRAIN_SCORE_H_ +#define _DTRAIN_SCORE_H_ + +#include "kbestget.h" + +using namespace std; + +namespace dtrain +{ + + +struct NgramCounts +{ + unsigned N_; + map<unsigned, score_t> clipped_; + map<unsigned, score_t> sum_; + + NgramCounts(const unsigned N) : N_(N) { Zero(); } + + inline void + operator+=(const NgramCounts& rhs) + { + if (rhs.N_ > N_) Resize(rhs.N_); + for (unsigned i = 0; i < N_; i++) { + this->clipped_[i] += rhs.clipped_.find(i)->second; + this->sum_[i] += rhs.sum_.find(i)->second; + } + } + + inline const NgramCounts + operator+(const NgramCounts &other) const + { + NgramCounts result = *this; + result += other; + return result; + } + + inline void + operator*=(const score_t rhs) + { + for (unsigned i = 0; i < N_; i++) { + this->clipped_[i] *= rhs; + this->sum_[i] *= rhs; + } + } + + inline void + Add(const unsigned count, const unsigned ref_count, const unsigned i) + { + assert(i < N_); + if (count > ref_count) { + clipped_[i] += ref_count; + } else { + clipped_[i] += count; + } + sum_[i] += count; + } + + inline void + Zero() + { + for (unsigned i = 0; i < N_; i++) { + clipped_[i] = 0.; + sum_[i] = 0.; + } + } + + inline void + One() + { + for (unsigned i = 0; i < N_; i++) { + clipped_[i] = 1.; + sum_[i] = 1.; + } + } + + inline void + Print() + { + for (unsigned i = 0; i < N_; i++) { + cout << i+1 << "grams (clipped):\t" << clipped_[i] << endl; + cout << i+1 << "grams:\t\t\t" << sum_[i] << endl; + } + } + + inline void Resize(unsigned N) + { + if (N == N_) return; + else if (N > N_) { + for (unsigned i = N_; i < N; i++) { + clipped_[i] = 0.; + sum_[i] = 0.; + } + } else { // N < N_ + for (unsigned i = N_-1; i > N-1; i--) { + clipped_.erase(i); + sum_.erase(i); + } + } + N_ = N; + } +}; + +typedef map<vector<WordID>, unsigned> Ngrams; + +inline Ngrams +make_ngrams(const vector<WordID>& s, const unsigned N) +{ + Ngrams ngrams; + vector<WordID> ng; + for (size_t i = 0; i < s.size(); i++) { + ng.clear(); + for (unsigned j = i; j < min(i+N, s.size()); j++) { + ng.push_back(s[j]); + ngrams[ng]++; + } + } + return ngrams; +} + +inline NgramCounts +make_ngram_counts(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned N) +{ + Ngrams hyp_ngrams = make_ngrams(hyp, N); + Ngrams ref_ngrams = make_ngrams(ref, N); + NgramCounts counts(N); + Ngrams::iterator it; + Ngrams::iterator ti; + for (it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++) { + ti = ref_ngrams.find(it->first); + if (ti != ref_ngrams.end()) { + counts.Add(it->second, ti->second, it->first.size() - 1); + } else { + counts.Add(it->second, 0, it->first.size() - 1); + } + } + return counts; +} + +struct BleuScorer : public LocalScorer +{ + score_t Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len); + score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct StupidBleuScorer : public LocalScorer +{ + score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct FixedStupidBleuScorer : public LocalScorer +{ + score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct SmoothBleuScorer : public LocalScorer +{ + score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct SumBleuScorer : public LocalScorer +{ + score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct SumExpBleuScorer : public LocalScorer +{ + score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct SumWhateverBleuScorer : public LocalScorer +{ + score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct ApproxBleuScorer : public BleuScorer +{ + NgramCounts glob_onebest_counts_; + unsigned glob_hyp_len_, glob_ref_len_, glob_src_len_; + score_t discount_; + + ApproxBleuScorer(unsigned N, score_t d) : glob_onebest_counts_(NgramCounts(N)), discount_(d) + { + glob_hyp_len_ = glob_ref_len_ = glob_src_len_ = 0; + } + + inline void Reset() { + glob_onebest_counts_.Zero(); + glob_hyp_len_ = glob_ref_len_ = glob_src_len_ = 0.; + } + + score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank, const unsigned src_len); +}; + +struct LinearBleuScorer : public BleuScorer +{ + unsigned onebest_len_; + NgramCounts onebest_counts_; + + LinearBleuScorer(unsigned N) : onebest_len_(1), onebest_counts_(N) + { + onebest_counts_.One(); + } + + score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank, const unsigned /*src_len*/); + + inline void Reset() { + onebest_len_ = 1; + onebest_counts_.One(); + } +}; + + +} // namespace + +#endif + |