diff options
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r-- | dtrain/kbestget.h | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h index bcd82610..77d4a139 100644 --- a/dtrain/kbestget.h +++ b/dtrain/kbestget.h @@ -2,6 +2,8 @@ #define _DTRAIN_KBESTGET_H_ #include "kbest.h" // cdec +#include "sentence_metadata.h" + #include "verbose.h" #include "viterbi.h" #include "ff_register.h" @@ -32,7 +34,7 @@ struct LocalScorer vector<score_t> w_; virtual score_t - Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank)=0; + Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank, const unsigned src_len)=0; void Reset() {} // only for approx bleu @@ -71,13 +73,15 @@ struct KBestGetter : public HypSampler const unsigned k_; const string filter_type_; vector<ScoredHyp> s_; + unsigned src_len_; KBestGetter(const unsigned k, const string filter_type) : k_(k), filter_type_(filter_type) {} virtual void - NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg) + NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + src_len_ = smeta.GetSourceLength(); KBestScored(*hg); } @@ -109,7 +113,7 @@ struct KBestGetter : public HypSampler h.f = d->feature_values; h.model = log(d->score); h.rank = i; - h.score = scorer_->Score(h.w, *ref_, i); + h.score = scorer_->Score(h.w, *ref_, i, src_len_); s_.push_back(h); } } @@ -128,7 +132,7 @@ struct KBestGetter : public HypSampler h.f = d->feature_values; h.model = log(d->score); h.rank = i; - h.score = scorer_->Score(h.w, *ref_, i); + h.score = scorer_->Score(h.w, *ref_, i, src_len_); s_.push_back(h); } } |