diff options
Diffstat (limited to 'mteval')
| -rw-r--r-- | mteval/mbr_kbest.cc | 4 | ||||
| -rw-r--r-- | mteval/scorer.cc | 12 | 
2 files changed, 12 insertions, 4 deletions
diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc index 2867b36b..64a6a8bf 100644 --- a/mteval/mbr_kbest.cc +++ b/mteval/mbr_kbest.cc @@ -32,7 +32,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {  }  struct LossComparer { -  bool operator()(const pair<vector<WordID>, double>& a, const pair<vector<WordID>, double>& b) const { +  bool operator()(const pair<vector<WordID>, prob_t>& a, const pair<vector<WordID>, prob_t>& b) const {      return a.second < b.second;    }  }; @@ -108,7 +108,7 @@ int main(int argc, char** argv) {            ScoreP s = scorer->ScoreCandidate(list[j].first);            double loss = 1.0 - s->ComputeScore();            if (type == TER || type == AER) loss = 1.0 - loss; -          double weighted_loss = loss * (joints[j] / marginal); +          double weighted_loss = loss * (joints[j] / marginal).as_float();            wl_acc += weighted_loss;            if ((!output_list) && wl_acc > mbr_loss) break;          } diff --git a/mteval/scorer.cc b/mteval/scorer.cc index 2daa0daa..a83b9e2f 100644 --- a/mteval/scorer.cc +++ b/mteval/scorer.cc @@ -430,6 +430,7 @@ float BLEUScore::ComputeScore(vector<float>* precs, float* bp) const {    float log_bleu = 0;    if (precs) precs->clear();    int count = 0; +  vector<float> total_precs(N());    for (int i = 0; i < N(); ++i) {      if (hyp_ngram_counts[i] > 0) {        float cor_count = correct_ngram_hit_counts[i]; @@ -440,14 +441,21 @@ float BLEUScore::ComputeScore(vector<float>* precs, float* bp) const {        log_bleu += lprec;        ++count;      } +    total_precs[i] = log_bleu;    } -  log_bleu /= static_cast<float>(count); +  vector<float> bleus(N());    float lbp = 0.0;    if (hyp_len < ref_len)      lbp = (hyp_len - ref_len) / hyp_len;    log_bleu += lbp;    if (bp) *bp = exp(lbp); -  return exp(log_bleu); +  float wb = 0; +  for (int i = 0; i < N(); ++i) { +    bleus[i] = exp(total_precs[i] / (i+1) + lbp); +    wb += bleus[i] / pow(2.0, 4.0 - i); +  } +  //return wb; +  return bleus.back();  }  | 
