diff options
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r-- | dtrain/kbestget.h | 84 |
1 files changed, 39 insertions, 45 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h index cf466fe4..79201182 100644 --- a/dtrain/kbestget.h +++ b/dtrain/kbestget.h @@ -1,19 +1,14 @@ #ifndef _DTRAIN_KBESTGET_H_ #define _DTRAIN_KBESTGET_H_ - #include "kbest.h" - namespace dtrain { -/* - * KBestList - * - */ -struct KBestList { +struct Samples +{ vector<SparseVector<double> > feats; vector<vector<WordID> > sents; vector<double> model_scores; @@ -21,71 +16,71 @@ struct KBestList { size_t GetSize() { return sents.size(); } }; +struct HypoSampler : public DecoderObserver +{ + virtual Samples* GetSamples() {} +}; -/* - * KBestGetter - * - */ -struct KBestGetter : public DecoderObserver +struct KBestGetter : public HypoSampler { const size_t k_; const string filter_type; - KBestList kb; + Samples s; - KBestGetter( const size_t k, const string filter_type ) : + KBestGetter(const size_t k, const string filter_type) : k_(k), filter_type(filter_type) {} virtual void - NotifyTranslationForest( const SentenceMetadata& smeta, Hypergraph* hg ) + NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - KBest( *hg ); + KBest(*hg); } - KBestList* GetKBest() { return &kb; } + Samples* GetSamples() { return &s; } void - KBest( const Hypergraph& forest ) + KBest(const Hypergraph& forest) { - if ( filter_type == "unique" ) { - KBestUnique( forest ); - } else if ( filter_type == "no" ) { - KBestNoFilter( forest ); + if (filter_type == "unique") { + KBestUnique(forest); + } else if (filter_type == "no") { + KBestNoFilter(forest); } } void - KBestUnique( const Hypergraph& forest ) + KBestUnique(const Hypergraph& forest) { - kb.sents.clear(); - kb.feats.clear(); - kb.model_scores.clear(); - kb.scores.clear(); - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest( forest, k_ ); - for ( size_t i = 0; i < k_; ++i ) { + s.sents.clear(); + s.feats.clear(); + s.model_scores.clear(); + s.scores.clear(); + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_); + for (size_t i = 0; i < k_; ++i) { const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = - kbest.LazyKthBest( forest.nodes_.size() - 1, i ); + kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; - kb.sents.push_back( d->yield); - kb.feats.push_back( d->feature_values ); - kb.model_scores.push_back( log(d->score) ); + s.sents.push_back(d->yield); + s.feats.push_back(d->feature_values); + s.model_scores.push_back(log(d->score)); } } void - KBestNoFilter( const Hypergraph& forest ) + KBestNoFilter(const Hypergraph& forest) { - kb.sents.clear(); - kb.feats.clear(); - kb.model_scores.clear(); - kb.scores.clear(); - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest( forest, k_ ); - for ( size_t i = 0; i < k_; ++i ) { + s.sents.clear(); + s.feats.clear(); + s.model_scores.clear(); + s.scores.clear(); + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_); + for (size_t i = 0; i < k_; ++i) { const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest( forest.nodes_.size() - 1, i ); + kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; - kb.sents.push_back( d->yield); - kb.feats.push_back( d->feature_values ); - kb.model_scores.push_back( log(d->score) ); + s.sents.push_back(d->yield); + s.feats.push_back(d->feature_values); + s.model_scores.push_back(log(d->score)); } } }; @@ -93,6 +88,5 @@ struct KBestGetter : public DecoderObserver } // namespace - #endif |