summary refs log tree commit diff
path: root/training/dtrain/sample.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/sample.h')
-rw-r--r-- training/dtrain/sample.h 62
1 files changed, 15 insertions, 47 deletions
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h
index 25f02273..64d93cb0 100644
--- a/training/dtrain/sample.h
+++ b/training/dtrain/sample.h
@@ -1,5 +1,5 @@
-#ifndef _DTRAIN_KBESTGET_H_
-#define _DTRAIN_KBESTGET_H_
+#ifndef _DTRAIN_SAMPLE_H_
+#define _DTRAIN_SAMPLE_H_
#include "kbest.h"
@@ -7,78 +7,46 @@ namespace dtrain
{
// Patch hunk: the KBestGetter struct (a HypSampler) is replaced by ScoredKbest
// (a DecoderObserver). Lines starting with '-' are removed, lines starting
// with '+' are added, all other lines are unchanged context.
// After the patch, ScoredKbest collects up to k_ derivations from the forest
// via KBest::KBestDerivations with the KBest::FilterUnique filter and scores
// each one with scorer_ against *refs_.
-struct KBestGetter : public HypSampler
+struct ScoredKbest : public DecoderObserver
{
// Upper bound on the number of derivations requested per forest.
const unsigned k_;
- const string filter_type_;
// Hypotheses collected by the most recent NotifyTranslationForest() call.
vector<ScoredHyp> s_;
// Source length taken from the sentence metadata of the most recent call.
unsigned src_len_;
+ PerSentenceBleuScorer* scorer_;
+ vector<vector<WordID> >* refs_;
+ unsigned f_count_, sz_;
- KBestGetter(const unsigned k, const string filter_type) :
- k_(k), filter_type_(filter_type) {}
// The new constructor initializes only k_ and scorer_; refs_, src_len_,
// f_count_ and sz_ are assigned later (SetReference / NotifyTranslationForest).
+ ScoredKbest(const unsigned k, PerSentenceBleuScorer* scorer) :
+ k_(k), scorer_(scorer) {}
// Observer callback: stores the source length, then (after the patch) rebuilds
// s_ directly here instead of dispatching through KBestScored().
virtual void
NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
{
src_len_ = smeta.GetSourceLength();
- KBestScored(*hg);
- }
-
- vector<ScoredHyp>* GetSamples() { return &s_; }
-
- void
- KBestScored(const Hypergraph& forest)
- {
- if (filter_type_ == "uniq") {
- KBestUnique(forest);
- } else if (filter_type_ == "not") {
- KBestNoFilter(forest);
- }
- }
-
- void
- KBestUnique(const Hypergraph& forest)
- {
// Drop the previous samples and reset both counters before collecting anew.
s_.clear(); sz_ = f_count_ = 0;
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
- KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_);
+ KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_);
for (unsigned i = 0; i < k_; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique,
prob_t, EdgeProb>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i);
// A null derivation ends the loop early (presumably fewer than k_ are available).
if (!d) break;
ScoredHyp h;
h.w = d->yield;
h.f = d->feature_values;
// The model score is stored as the log of the derivation score.
h.model = log(d->score);
h.rank = i;
// The patched call no longer passes the rank and source length to the scorer.
// NOTE(review): refs_ is dereferenced here but is not set by the constructor;
// SetReference() must have been called before this callback runs.
- h.score = scorer_->Score(h.w, *refs_, i, src_len_);
+ h.score = scorer_->Score(h.w, *refs_);
s_.push_back(h);
sz_++;
// Running total of feature-vector sizes over all collected hypotheses.
f_count_ += h.f.size();
}
}
- void
- KBestNoFilter(const Hypergraph& forest)
- {
- s_.clear(); sz_ = f_count_ = 0;
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_);
- for (unsigned i = 0; i < k_; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- ScoredHyp h;
- h.w = d->yield;
- h.f = d->feature_values;
- h.model = log(d->score);
- h.rank = i;
- h.score = scorer_->Score(h.w, *refs_, i, src_len_);
- s_.push_back(h);
- sz_++;
- f_count_ += h.f.size();
- }
- }
// Returns a pointer to the internal sample vector; its contents are cleared
// and refilled on every NotifyTranslationForest() call.
+ vector<ScoredHyp>* GetSamples() { return &s_; }
// Stores the address of the caller's reference list; no copy is made.
+ inline void SetReference(vector<vector<WordID> >& refs) { refs_ = &refs; }
// Sum of h.f.size() over the hypotheses collected by the last callback.
+ inline unsigned GetFeatureCount() { return f_count_; }
// Number of hypotheses collected by the last callback.
+ inline unsigned GetSize() { return sz_; }
};