From b20c6665542cbfa1b4328b349d6912944a1483f2 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Thu, 26 Feb 2015 22:12:23 +0100 Subject: last tweaks & fixes --- training/dtrain/sample.h | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) (limited to 'training/dtrain/sample.h') diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h index c3586c58..03cc82c3 100644 --- a/training/dtrain/sample.h +++ b/training/dtrain/sample.h @@ -3,20 +3,19 @@ #include "kbest.h" +#include "score.h" + namespace dtrain { - struct ScoredKbest : public DecoderObserver { const size_t k_; - vector s_; - size_t src_len_; + size_t feature_count_, effective_sz_; + vector samples_; PerSentenceBleuScorer* scorer_; - vector >* refs_; vector* ref_ngs_; vector* ref_ls_; - size_t f_count_, sz_; ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) : k_(k), scorer_(scorer) {} @@ -24,14 +23,13 @@ struct ScoredKbest : public DecoderObserver virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - src_len_ = smeta.GetSourceLength(); - s_.clear(); sz_ = f_count_ = 0; + samples_.clear(); effective_sz_ = feature_count_ = 0; KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_); for (size_t i = 0; i < k_; ++i) { - const KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique, - prob_t, EdgeProb>::Derivation* d = - kbest.LazyKthBest(hg->nodes_.size() - 1, i); + const KBest::KBestDerivations, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); if (!d) break; ScoredHyp h; h.w = d->yield; @@ -39,23 +37,22 @@ struct ScoredKbest : public DecoderObserver h.model = log(d->score); h.rank = i; h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_); - s_.push_back(h); - sz_++; - f_count_ += h.f.size(); + samples_.push_back(h); + effective_sz_++; + feature_count_ += h.f.size(); } } - vector* GetSamples() { return &s_; } + vector* GetSamples() { return &samples_; } inline void SetReference(vector& ngs, vector& ls) { ref_ngs_ = &ngs; ref_ls_ = &ls; } - inline size_t GetFeatureCount() { return f_count_; } - inline size_t GetSize() { return sz_; } + inline size_t GetFeatureCount() { return feature_count_; } + inline size_t GetSize() { return effective_sz_; } }; - } // namespace #endif -- cgit v1.2.3