summaryrefslogtreecommitdiff
path: root/training/dtrain/sample.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/sample.h')
-rw-r--r--training/dtrain/sample.h31
1 files changed, 14 insertions, 17 deletions
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h
index c3586c58..03cc82c3 100644
--- a/training/dtrain/sample.h
+++ b/training/dtrain/sample.h
@@ -3,20 +3,19 @@
#include "kbest.h"
+#include "score.h"
+
namespace dtrain
{
-
struct ScoredKbest : public DecoderObserver
{
const size_t k_;
- vector<ScoredHyp> s_;
- size_t src_len_;
+ size_t feature_count_, effective_sz_;
+ vector<ScoredHyp> samples_;
PerSentenceBleuScorer* scorer_;
- vector<vector<WordID> >* refs_;
vector<Ngrams>* ref_ngs_;
vector<size_t>* ref_ls_;
- size_t f_count_, sz_;
ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) :
k_(k), scorer_(scorer) {}
@@ -24,14 +23,13 @@ struct ScoredKbest : public DecoderObserver
virtual void
NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
{
- src_len_ = smeta.GetSourceLength();
- s_.clear(); sz_ = f_count_ = 0;
+ samples_.clear(); effective_sz_ = feature_count_ = 0;
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_);
for (size_t i = 0; i < k_; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique,
- prob_t, EdgeProb>::Derivation* d =
- kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
+ KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i);
if (!d) break;
ScoredHyp h;
h.w = d->yield;
@@ -39,23 +37,22 @@ struct ScoredKbest : public DecoderObserver
h.model = log(d->score);
h.rank = i;
h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_);
- s_.push_back(h);
- sz_++;
- f_count_ += h.f.size();
+ samples_.push_back(h);
+ effective_sz_++;
+ feature_count_ += h.f.size();
}
}
- vector<ScoredHyp>* GetSamples() { return &s_; }
+ vector<ScoredHyp>* GetSamples() { return &samples_; }
inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls)
{
ref_ngs_ = &ngs;
ref_ls_ = &ls;
}
- inline size_t GetFeatureCount() { return f_count_; }
- inline size_t GetSize() { return sz_; }
+ inline size_t GetFeatureCount() { return feature_count_; }
+ inline size_t GetSize() { return effective_sz_; }
};
-
} // namespace
#endif