summary | refs | log | tree | commit | diff
path: root/training/dtrain/score.h
diff options
context:
space:
mode:
author: Patrick Simianer <p@simianer.de> 2015-01-23 15:50:27 +0100
committer: Patrick Simianer <p@simianer.de> 2015-01-23 15:50:27 +0100
commit: 32dea3f24e56ac7c17343457c48f750f16838742 (patch)
tree: 79177b58cbff08c14991a0da8e851912b1c06309 /training/dtrain/score.h
parent: 556dc935c7a2d8df78a35447d20d71b4bf6e391a (diff)
dtrain: multi-reference BLEU
Diffstat (limited to 'training/dtrain/score.h')
-rw-r--r-- training/dtrain/score.h 37
1 file changed, 21 insertions, 16 deletions
diff --git a/training/dtrain/score.h b/training/dtrain/score.h
index 1cdd3fa9..7d88cb61 100644
--- a/training/dtrain/score.h
+++ b/training/dtrain/score.h
@@ -117,20 +117,25 @@ make_ngrams(const vector<WordID>& s, const unsigned N)
}
inline NgramCounts
-make_ngram_counts(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned N)
+make_ngram_counts(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned N)
{
Ngrams hyp_ngrams = make_ngrams(hyp, N);
- Ngrams ref_ngrams = make_ngrams(ref, N);
+ vector<Ngrams> refs_ngrams;
+ for (auto r: refs) {
+ Ngrams r_ng = make_ngrams(r, N);
+ refs_ngrams.push_back(r_ng);
+ }
NgramCounts counts(N);
Ngrams::iterator it;
Ngrams::iterator ti;
for (it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++) {
- ti = ref_ngrams.find(it->first);
- if (ti != ref_ngrams.end()) {
- counts.Add(it->second, ti->second, it->first.size() - 1);
- } else {
- counts.Add(it->second, 0, it->first.size() - 1);
+ unsigned max_ref_count = 0;
+ for (auto ref_ngrams: refs_ngrams) {
+ ti = ref_ngrams.find(it->first);
+ if (ti != ref_ngrams.end())
+ max_ref_count = max(max_ref_count, ti->second);
}
+ counts.Add(it->second, max_ref_count, it->first.size() - 1);
}
return counts;
}
@@ -138,43 +143,43 @@ make_ngram_counts(const vector<WordID>& hyp, const vector<WordID>& ref, const un
struct BleuScorer : public LocalScorer
{
score_t Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len);
- score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/);
+ score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/);
void Reset() {}
};
struct StupidBleuScorer : public LocalScorer
{
- score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/);
+ score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/);
void Reset() {}
};
struct FixedStupidBleuScorer : public LocalScorer
{
- score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/);
+ score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/);
void Reset() {}
};
struct SmoothBleuScorer : public LocalScorer
{
- score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/);
+ score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/);
void Reset() {}
};
struct SumBleuScorer : public LocalScorer
{
- score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/);
+ score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/);
void Reset() {}
};
struct SumExpBleuScorer : public LocalScorer
{
- score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/);
+ score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/);
void Reset() {}
};
struct SumWhateverBleuScorer : public LocalScorer
{
- score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned /*rank*/, const unsigned /*src_len*/);
+ score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned /*rank*/, const unsigned /*src_len*/);
void Reset() {};
};
@@ -194,7 +199,7 @@ struct ApproxBleuScorer : public BleuScorer
glob_hyp_len_ = glob_ref_len_ = glob_src_len_ = 0.;
}
- score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned rank, const unsigned src_len);
+ score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned rank, const unsigned src_len);
};
struct LinearBleuScorer : public BleuScorer
@@ -207,7 +212,7 @@ struct LinearBleuScorer : public BleuScorer
onebest_counts_.One();
}
- score_t Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned rank, const unsigned /*src_len*/);
+ score_t Score(const vector<WordID>& hyp, const vector<vector<WordID> >& refs, const unsigned rank, const unsigned /*src_len*/);
inline void Reset() {
onebest_len_ = 1;