diff options
author | Patrick Simianer <p@simianer.de> | 2011-09-25 21:43:57 +0200 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2011-09-25 21:43:57 +0200 |
commit | 4d8c300734c441821141f4bff044c439e004ff84 (patch) | |
tree | 5b9e2b7f9994d9a71e0e2d17f33ba2ff4a1145a1 /dtrain/kbestget.h | |
parent | fe471bb707226052551d75b043295ca5f57261c0 (diff) |
kbest, ksampler refactoring
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r-- | dtrain/kbestget.h | 55 |
1 files changed, 26 insertions, 29 deletions
```diff
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index 79201182..403384de 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -7,28 +7,27 @@ namespace dtrain
 {

-struct Samples
+struct ScoredHyp
 {
-  vector<SparseVector<double> > feats;
-  vector<vector<WordID> > sents;
-  vector<double> model_scores;
-  vector<double> scores;
-  size_t GetSize() { return sents.size(); }
+  vector<WordID> w;
+  SparseVector<double> f;
+  score_t model;
+  score_t score;
 };

-struct HypoSampler : public DecoderObserver
+struct HypSampler : public DecoderObserver
 {
-  virtual Samples* GetSamples() {}
+  virtual vector<ScoredHyp>* GetSamples() {}
 };

-struct KBestGetter : public HypoSampler
+struct KBestGetter : public HypSampler
 {
   const size_t k_;
-  const string filter_type;
-  Samples s;
+  const string filter_type_;
+  vector<ScoredHyp> s_;

   KBestGetter(const size_t k, const string filter_type) :
-    k_(k), filter_type(filter_type) {}
+    k_(k), filter_type_(filter_type) {}

   virtual void
   NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
@@ -36,14 +35,14 @@ struct KBestGetter : public HypoSampler
     KBest(*hg);
   }

-  Samples* GetSamples() { return &s; }
+  vector<ScoredHyp>* GetSamples() { return &s_; }

   void
   KBest(const Hypergraph& forest)
   {
-    if (filter_type == "unique") {
+    if (filter_type_ == "unique") {
       KBestUnique(forest);
-    } else if (filter_type == "no") {
+    } else if (filter_type_ == "no") {
       KBestNoFilter(forest);
     }
   }
@@ -51,36 +50,34 @@ struct KBestGetter : public HypoSampler
   void
   KBestUnique(const Hypergraph& forest)
   {
-    s.sents.clear();
-    s.feats.clear();
-    s.model_scores.clear();
-    s.scores.clear();
+    s_.clear();
     KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_);
     for (size_t i = 0; i < k_; ++i) {
       const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = kbest.LazyKthBest(forest.nodes_.size() - 1, i);
       if (!d) break;
-      s.sents.push_back(d->yield);
-      s.feats.push_back(d->feature_values);
-      s.model_scores.push_back(log(d->score));
+      ScoredHyp h;
+      h.w = d->yield;
+      h.f = d->feature_values;
+      h.model = log(d->score);
+      s_.push_back(h);
     }
   }

   void
   KBestNoFilter(const Hypergraph& forest)
   {
-    s.sents.clear();
-    s.feats.clear();
-    s.model_scores.clear();
-    s.scores.clear();
+    s_.clear();
     KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_);
     for (size_t i = 0; i < k_; ++i) {
       const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = kbest.LazyKthBest(forest.nodes_.size() - 1, i);
       if (!d) break;
-      s.sents.push_back(d->yield);
-      s.feats.push_back(d->feature_values);
-      s.model_scores.push_back(log(d->score));
+      ScoredHyp h;
+      h.w = d->yield;
+      h.f = d->feature_values;
+      h.model = log(d->score);
+      s_.push_back(h);
     }
   }
 };
```