summaryrefslogtreecommitdiff
path: root/dtrain/kbestget.h
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r--dtrain/kbestget.h12
1 files changed, 8 insertions, 4 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index bcd82610..77d4a139 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -2,6 +2,8 @@
#define _DTRAIN_KBESTGET_H_
#include "kbest.h" // cdec
+#include "sentence_metadata.h"
+
#include "verbose.h"
#include "viterbi.h"
#include "ff_register.h"
@@ -32,7 +34,7 @@ struct LocalScorer
vector<score_t> w_;
virtual score_t
- Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank)=0;
+ Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank, const unsigned src_len)=0;
void Reset() {} // only for approx bleu
@@ -71,13 +73,15 @@ struct KBestGetter : public HypSampler
const unsigned k_;
const string filter_type_;
vector<ScoredHyp> s_;
+ unsigned src_len_;
KBestGetter(const unsigned k, const string filter_type) :
k_(k), filter_type_(filter_type) {}
virtual void
- NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg)
+ NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
{
+ src_len_ = smeta.GetSourceLength();
KBestScored(*hg);
}
@@ -109,7 +113,7 @@ struct KBestGetter : public HypSampler
h.f = d->feature_values;
h.model = log(d->score);
h.rank = i;
- h.score = scorer_->Score(h.w, *ref_, i);
+ h.score = scorer_->Score(h.w, *ref_, i, src_len_);
s_.push_back(h);
}
}
@@ -128,7 +132,7 @@ struct KBestGetter : public HypSampler
h.f = d->feature_values;
h.model = log(d->score);
h.rank = i;
- h.score = scorer_->Score(h.w, *ref_, i);
+ h.score = scorer_->Score(h.w, *ref_, i, src_len_);
s_.push_back(h);
}
}