diff options
author | Patrick Simianer <p@simianer.de> | 2011-09-04 23:40:44 +0200 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2011-09-23 19:13:58 +0200 |
commit | f9b3b81d7b5a817f39238c1d1fcbaa981d729f4a (patch) | |
tree | 0f573ab4b1969f7b5520ed29b5f1731f2e351e1f /dtrain/kbestget.h | |
parent | aceb387526478e34e41db6c046f707234953e0b5 (diff) |
minor updates, fixes, kbest filtering switch
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r-- | dtrain/kbestget.h | 41 |
1 files changed, 35 insertions, 6 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h index 5247a2be..bb430b85 100644 --- a/dtrain/kbestget.h +++ b/dtrain/kbestget.h @@ -24,29 +24,58 @@ struct KBestList { */ struct KBestGetter : public DecoderObserver { - KBestGetter( const size_t k ) : k_(k) {} const size_t k_; + const string filter_type; KBestList kb; + KBestGetter( const size_t k, const string filter_type ) : + k_(k), filter_type(filter_type) {} + virtual void - NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) + NotifyTranslationForest( const SentenceMetadata& smeta, Hypergraph* hg ) { - GetKBest(smeta.GetSentenceID(), *hg); + KBest( *hg ); } KBestList* GetKBest() { return &kb; } void - GetKBest(int sid, const Hypergraph& forest) + KBest( const Hypergraph& forest ) + { + if ( filter_type == "unique" ) { + KBestUnique( forest ); + } else if ( filter_type == "no" ) { + KBestNoFilter( forest ); + } + } + + void + KBestUnique( const Hypergraph& forest ) { kb.scores.clear(); kb.sents.clear(); kb.feats.clear(); - // FIXME TODO FIXME TODO KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest( forest, k_ ); for ( size_t i = 0; i < k_; ++i ) { const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = - kbest.LazyKthBest( forest.nodes_.size() - 1, i ); + kbest.LazyKthBest( forest.nodes_.size() - 1, i ); + if (!d) break; + kb.sents.push_back( d->yield); + kb.feats.push_back( d->feature_values ); + kb.scores.push_back( d->score ); + } + } + + void + KBestNoFilter( const Hypergraph& forest ) + { + kb.scores.clear(); + kb.sents.clear(); + kb.feats.clear(); + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest( forest, k_ ); + for ( size_t i = 0; i < k_; ++i ) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest( forest.nodes_.size() - 1, i ); if (!d) break; kb.sents.push_back( d->yield); kb.feats.push_back( d->feature_values ); |