summary | refs | log | tree | commit | diff
path: root/training/dtrain/score.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/score.h')
-rw-r--r-- training/dtrain/score.h 51
1 files changed, 22 insertions, 29 deletions
diff --git a/training/dtrain/score.h b/training/dtrain/score.h
index d51aef82..06dbc5a4 100644
--- a/training/dtrain/score.h
+++ b/training/dtrain/score.h
@@ -34,15 +34,6 @@ struct NgramCounts
}
inline void
- operator*=(const weight_t rhs)
- {
- for (size_t i = 0; i < N_; i++) {
- this->clipped_[i] *= rhs;
- this->sum_[i] *= rhs;
- }
- }
-
- inline void
Add(const size_t count, const size_t ref_count, const size_t i)
{
assert(i < N_);
@@ -64,15 +55,7 @@ struct NgramCounts
}
inline void
- Print(ostream& os=cerr)
- {
- for (size_t i = 0; i < N_; i++) {
- os << i+1 << "grams (clipped):\t" << clipped_[i] << endl;
- os << i+1 << "grams:\t\t\t" << sum_[i] << endl;
- }
- }
-
- inline void Resize(size_t N)
+ Resize(size_t N)
{
if (N == N_) return;
else if (N > N_) {
@@ -158,16 +141,13 @@ struct PerSentenceBleuScorer
return exp(1 - (weight_t)rl/hl);
}
- weight_t
- Score(const vector<WordID>& hyp,
- const vector<Ngrams>& ref_ngs,
- const vector<size_t>& ref_ls)
+ inline size_t
+ BestMatchLength(const size_t hl,
+ const vector<size_t>& ref_ls)
{
- size_t hl = hyp.size(), rl = 0;
- if (hl == 0) return 0.;
- // best match reference length
+ size_t m;
if (ref_ls.size() == 1) {
- rl = ref_ls.front();
+ m = ref_ls.front();
} else {
size_t i = 0, best_idx = 0;
size_t best = numeric_limits<size_t>::max();
@@ -179,8 +159,20 @@ struct PerSentenceBleuScorer
}
i += 1;
}
- rl = ref_ls[best_idx];
+ m = ref_ls[best_idx];
}
+
+ return m;
+ }
+
+ weight_t
+ Score(const vector<WordID>& hyp,
+ const vector<Ngrams>& ref_ngs,
+ const vector<size_t>& ref_ls)
+ {
+ size_t hl = hyp.size(), rl = 0;
+ if (hl == 0) return 0.;
+ rl = BestMatchLength(hl, ref_ls);
if (rl == 0) return 0.;
NgramCounts counts = MakeNgramCounts(hyp, ref_ngs, N_);
size_t M = N_;
@@ -192,8 +184,9 @@ struct PerSentenceBleuScorer
weight_t sum = 0, add = 0;
for (size_t i = 0; i < M; i++) {
if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.;
- if (i == 1) add = 1;
- sum += v[i] * log(((weight_t)counts.clipped_[i] + add)/((counts.sum_[i] + add)));
+ if (i > 0) add = 1;
+ sum += v[i] * log(((weight_t)counts.clipped_[i] + add)
+ / ((counts.sum_[i] + add)));
}
return BrevityPenalty(hl, rl+1) * exp(sum);