summaryrefslogtreecommitdiff
path: root/training/dtrain/sample.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/sample.h')
-rw-r--r--training/dtrain/sample.h24
1 files changed, 15 insertions, 9 deletions
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h
index 64d93cb0..c3586c58 100644
--- a/training/dtrain/sample.h
+++ b/training/dtrain/sample.h
@@ -9,14 +9,16 @@ namespace dtrain
struct ScoredKbest : public DecoderObserver
{
- const unsigned k_;
+ const size_t k_;
vector<ScoredHyp> s_;
- unsigned src_len_;
+ size_t src_len_;
PerSentenceBleuScorer* scorer_;
vector<vector<WordID> >* refs_;
- unsigned f_count_, sz_;
+ vector<Ngrams>* ref_ngs_;
+ vector<size_t>* ref_ls_;
+ size_t f_count_, sz_;
- ScoredKbest(const unsigned k, PerSentenceBleuScorer* scorer) :
+ ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) :
k_(k), scorer_(scorer) {}
virtual void
@@ -26,7 +28,7 @@ struct ScoredKbest : public DecoderObserver
s_.clear(); sz_ = f_count_ = 0;
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_);
- for (unsigned i = 0; i < k_; ++i) {
+ for (size_t i = 0; i < k_; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique,
prob_t, EdgeProb>::Derivation* d =
kbest.LazyKthBest(hg->nodes_.size() - 1, i);
@@ -36,7 +38,7 @@ struct ScoredKbest : public DecoderObserver
h.f = d->feature_values;
h.model = log(d->score);
h.rank = i;
- h.score = scorer_->Score(h.w, *refs_);
+ h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_);
s_.push_back(h);
sz_++;
f_count_ += h.f.size();
@@ -44,9 +46,13 @@ struct ScoredKbest : public DecoderObserver
}
vector<ScoredHyp>* GetSamples() { return &s_; }
- inline void SetReference(vector<vector<WordID> >& refs) { refs_ = &refs; }
- inline unsigned GetFeatureCount() { return f_count_; }
- inline unsigned GetSize() { return sz_; }
+ inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls)
+ {
+ ref_ngs_ = &ngs;
+ ref_ls_ = &ls;
+ }
+ inline size_t GetFeatureCount() { return f_count_; }
+ inline size_t GetSize() { return sz_; }
};