diff options
Diffstat (limited to 'vest/scorer.cc')
-rw-r--r-- | vest/scorer.cc | 127 |
1 files changed, 114 insertions, 13 deletions
diff --git a/vest/scorer.cc b/vest/scorer.cc index 6c604ab8..524b15a5 100644 --- a/vest/scorer.cc +++ b/vest/scorer.cc @@ -35,6 +35,8 @@ ScoreType ScoreTypeFromString(const string& st) { return AER; if (sl == "bleu" || sl == "ibm_bleu") return IBM_BLEU; + if (sl == "ibm_bleu_3") + return IBM_BLEU_3; if (sl == "nist_bleu") return NIST_BLEU; if (sl == "koehn_bleu") @@ -53,6 +55,7 @@ class SERScore : public Score { friend class SERScorer; public: SERScore() : correct(0), total(0) {} + float ComputePartialScore() const { return 0.0;} float ComputeScore() const { return static_cast<float>(correct) / static_cast<float>(total); } @@ -61,11 +64,18 @@ class SERScore : public Score { os << "SER= " << ComputeScore() << " (" << correct << '/' << total << ')'; *details = os.str(); } - void PlusEquals(const Score& delta) { + void PlusPartialEquals(const Score& delta, int oracle_e_cover, int oracle_f_cover, int src_len){} + + void PlusEquals(const Score& delta, const float scale) { correct += static_cast<const SERScore&>(delta).correct; total += static_cast<const SERScore&>(delta).total; } + void PlusEquals(const Score& delta) { + correct += static_cast<const SERScore&>(delta).correct; + total += static_cast<const SERScore&>(delta).total; + } Score* GetZero() const { return new SERScore; } + Score* GetOne() const { return new SERScore; } void Subtract(const Score& rhs, Score* res) const { SERScore* r = static_cast<SERScore*>(res); r->correct = correct - static_cast<const SERScore&>(rhs).correct; @@ -84,6 +94,10 @@ class SERScore : public Score { class SERScorer : public SentenceScorer { public: SERScorer(const vector<vector<WordID> >& references) : refs_(references) {} + Score* ScoreCCandidate(const vector<WordID>& hyp) const { + Score* a = NULL; + return a; + } Score* ScoreCandidate(const vector<WordID>& hyp) const { SERScore* res = new SERScore; res->total = 1; @@ -101,13 +115,20 @@ class SERScorer : public SentenceScorer { class BLEUScore : public Score { friend class 
BLEUScorerBase; public: - BLEUScore(int n) : correct_ngram_hit_counts(0,n), hyp_ngram_counts(0,n) { + BLEUScore(int n) : correct_ngram_hit_counts(float(0),float(n)), hyp_ngram_counts(float(0),float(n)) { ref_len = 0; hyp_len = 0; } + BLEUScore(int n, int k) : correct_ngram_hit_counts(float(k),float(n)), hyp_ngram_counts(float(k),float(n)) { + ref_len = k; + hyp_len = k; } float ComputeScore() const; + float ComputePartialScore() const; void ScoreDetails(string* details) const; void PlusEquals(const Score& delta); + void PlusEquals(const Score& delta, const float scale); + void PlusPartialEquals(const Score& delta, int oracle_e_cover, int oracle_f_cover, int src_len); Score* GetZero() const; + Score* GetOne() const; void Subtract(const Score& rhs, Score* res) const; void Encode(string* out) const; bool IsAdditiveIdentity() const { @@ -119,10 +140,11 @@ class BLEUScore : public Score { } private: float ComputeScore(vector<float>* precs, float* bp) const; - valarray<int> correct_ngram_hit_counts; - valarray<int> hyp_ngram_counts; + float ComputePartialScore(vector<float>* prec, float* bp) const; + valarray<float> correct_ngram_hit_counts; + valarray<float> hyp_ngram_counts; float ref_len; - int hyp_len; + float hyp_len; }; class BLEUScorerBase : public SentenceScorer { @@ -131,6 +153,7 @@ class BLEUScorerBase : public SentenceScorer { int n ); Score* ScoreCandidate(const vector<WordID>& hyp) const; + Score* ScoreCCandidate(const vector<WordID>& hyp) const; static Score* ScoreFromString(const string& in); protected: @@ -171,8 +194,10 @@ class BLEUScorerBase : public SentenceScorer { } void ComputeNgramStats(const vector<WordID>& sent, - valarray<int>* correct, - valarray<int>* hyp) const { + valarray<float>* correct, + valarray<float>* hyp, + bool clip_counts) + const { assert(correct->size() == n_); assert(hyp->size() == n_); vector<WordID> ngram(n_); @@ -186,10 +211,15 @@ class BLEUScorerBase : public SentenceScorer { for (int i=1; i<=k; ++i) { ngram.push_back(sent[j 
+ i - 1]); pair<int,int>& p = ngrams_[ngram]; - if (p.second < p.first) { - ++p.second; - (*correct)[i-1]++; - } + if(clip_counts){ + if (p.second < p.first) { + ++p.second; + (*correct)[i-1]++; + }} + else { + ++p.second; + (*correct)[i-1]++; + } // if the 1 gram isn't found, don't try to match don't need to match any 2- 3- .. grams: if (!p.first) { for (; i<=k; ++i) @@ -284,7 +314,8 @@ SentenceScorer* SentenceScorer::CreateSentenceScorer(const ScoreType type, const vector<vector<WordID> >& refs, const string& src) { switch (type) { - case IBM_BLEU: return new IBM_BLEUScorer(refs, 4); + case IBM_BLEU: return new IBM_BLEUScorer(refs, 4); + case IBM_BLEU_3 : return new IBM_BLEUScorer(refs,3); case NIST_BLEU: return new NIST_BLEUScorer(refs, 4); case Koehn_BLEU: return new Koehn_BLEUScorer(refs, 4); case AER: return new AERScorer(refs, src); @@ -299,6 +330,7 @@ SentenceScorer* SentenceScorer::CreateSentenceScorer(const ScoreType type, Score* SentenceScorer::CreateScoreFromString(const ScoreType type, const string& in) { switch (type) { case IBM_BLEU: + case IBM_BLEU_3: case NIST_BLEU: case Koehn_BLEU: return BLEUScorerBase::ScoreFromString(in); @@ -423,6 +455,36 @@ float BLEUScore::ComputeScore(vector<float>* precs, float* bp) const { return exp(log_bleu); } + +//compute scaled score for oracle retrieval +float BLEUScore::ComputePartialScore(vector<float>* precs, float* bp) const { + // cerr << "Then here " << endl; + float log_bleu = 0; + if (precs) precs->clear(); + int count = 0; + for (int i = 0; i < hyp_ngram_counts.size(); ++i) { + // cerr << "In CPS " << hyp_ngram_counts[i] << " " << correct_ngram_hit_counts[i] << endl; + if (hyp_ngram_counts[i] > 0) { + float lprec = log(correct_ngram_hit_counts[i]) - log(hyp_ngram_counts[i]); + if (precs) precs->push_back(exp(lprec)); + log_bleu += lprec; + ++count; + } + } + log_bleu /= static_cast<float>(count); + float lbp = 0.0; + if (hyp_len < ref_len) + lbp = (hyp_len - ref_len) / hyp_len; + log_bleu += lbp; + if (bp) 
*bp = exp(lbp); + return exp(log_bleu); +} + +float BLEUScore::ComputePartialScore() const { + // cerr << "In here first " << endl; + return ComputePartialScore(NULL, NULL); +} + float BLEUScore::ComputeScore() const { return ComputeScore(NULL, NULL); } @@ -444,10 +506,37 @@ void BLEUScore::PlusEquals(const Score& delta) { hyp_len += d.hyp_len; } +void BLEUScore::PlusEquals(const Score& delta, const float scale) { + const BLEUScore& d = static_cast<const BLEUScore&>(delta); + correct_ngram_hit_counts = (correct_ngram_hit_counts + d.correct_ngram_hit_counts) * scale; + hyp_ngram_counts = ( hyp_ngram_counts + d.hyp_ngram_counts) * scale; + ref_len = (ref_len + d.ref_len) * scale; + hyp_len = ( hyp_len + d.hyp_len) * scale; + +} + +void BLEUScore::PlusPartialEquals(const Score& delta, int oracle_e_cover, int oracle_f_cover, int src_len){ + const BLEUScore& d = static_cast<const BLEUScore&>(delta); + correct_ngram_hit_counts += d.correct_ngram_hit_counts; + hyp_ngram_counts += d.hyp_ngram_counts; + //scale the reference length according to the size of the input sentence covered by this rule + + ref_len *= (float)oracle_f_cover / src_len; + ref_len += d.ref_len; + + hyp_len = oracle_e_cover; + hyp_len += d.hyp_len; +} + + Score* BLEUScore::GetZero() const { return new BLEUScore(hyp_ngram_counts.size()); } +Score* BLEUScore::GetOne() const { + return new BLEUScore(hyp_ngram_counts.size(),1); +} + void BLEUScore::Encode(string* out) const { ostringstream os; const int n = correct_ngram_hit_counts.size(); @@ -470,12 +559,24 @@ Score* BLEUScorerBase::ScoreCandidate(const vector<WordID>& hyp) const { BLEUScore* bs = new BLEUScore(n_); for (NGramCountMap::iterator i=ngrams_.begin(); i != ngrams_.end(); ++i) i->second.second = 0; - ComputeNgramStats(hyp, &bs->correct_ngram_hit_counts, &bs->hyp_ngram_counts); + ComputeNgramStats(hyp, &bs->correct_ngram_hit_counts, &bs->hyp_ngram_counts, true); bs->ref_len = ComputeRefLength(hyp); bs->hyp_len = hyp.size(); return bs; } +Score* 
BLEUScorerBase::ScoreCCandidate(const vector<WordID>& hyp) const { + BLEUScore* bs = new BLEUScore(n_); + for (NGramCountMap::iterator i=ngrams_.begin(); i != ngrams_.end(); ++i) + i->second.second = 0; + bool clip = false; + ComputeNgramStats(hyp, &bs->correct_ngram_hit_counts, &bs->hyp_ngram_counts,clip); + bs->ref_len = ComputeRefLength(hyp); + bs->hyp_len = hyp.size(); + return bs; +} + + DocScorer::~DocScorer() { for (int i=0; i < scorers_.size(); ++i) delete scorers_[i]; |