summaryrefslogtreecommitdiff
path: root/dtrain/kbestget.h
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r--dtrain/kbestget.h84
1 files changed, 39 insertions, 45 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index cf466fe4..79201182 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -1,19 +1,14 @@
#ifndef _DTRAIN_KBESTGET_H_
#define _DTRAIN_KBESTGET_H_
-
#include "kbest.h"
-
namespace dtrain
{
-/*
- * KBestList
- *
- */
-struct KBestList {
+struct Samples
+{
vector<SparseVector<double> > feats;
vector<vector<WordID> > sents;
vector<double> model_scores;
@@ -21,71 +16,71 @@ struct KBestList {
size_t GetSize() { return sents.size(); }
};
+struct HypoSampler : public DecoderObserver
+{
+ virtual Samples* GetSamples() {}
+};
-/*
- * KBestGetter
- *
- */
-struct KBestGetter : public DecoderObserver
+struct KBestGetter : public HypoSampler
{
const size_t k_;
const string filter_type;
- KBestList kb;
+ Samples s;
- KBestGetter( const size_t k, const string filter_type ) :
+ KBestGetter(const size_t k, const string filter_type) :
k_(k), filter_type(filter_type) {}
virtual void
- NotifyTranslationForest( const SentenceMetadata& smeta, Hypergraph* hg )
+ NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
{
- KBest( *hg );
+ KBest(*hg);
}
- KBestList* GetKBest() { return &kb; }
+ Samples* GetSamples() { return &s; }
void
- KBest( const Hypergraph& forest )
+ KBest(const Hypergraph& forest)
{
- if ( filter_type == "unique" ) {
- KBestUnique( forest );
- } else if ( filter_type == "no" ) {
- KBestNoFilter( forest );
+ if (filter_type == "unique") {
+ KBestUnique(forest);
+ } else if (filter_type == "no") {
+ KBestNoFilter(forest);
}
}
void
- KBestUnique( const Hypergraph& forest )
+ KBestUnique(const Hypergraph& forest)
{
- kb.sents.clear();
- kb.feats.clear();
- kb.model_scores.clear();
- kb.scores.clear();
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest( forest, k_ );
- for ( size_t i = 0; i < k_; ++i ) {
+ s.sents.clear();
+ s.feats.clear();
+ s.model_scores.clear();
+ s.scores.clear();
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_);
+ for (size_t i = 0; i < k_; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
- kbest.LazyKthBest( forest.nodes_.size() - 1, i );
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- kb.sents.push_back( d->yield);
- kb.feats.push_back( d->feature_values );
- kb.model_scores.push_back( log(d->score) );
+ s.sents.push_back(d->yield);
+ s.feats.push_back(d->feature_values);
+ s.model_scores.push_back(log(d->score));
}
}
void
- KBestNoFilter( const Hypergraph& forest )
+ KBestNoFilter(const Hypergraph& forest)
{
- kb.sents.clear();
- kb.feats.clear();
- kb.model_scores.clear();
- kb.scores.clear();
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest( forest, k_ );
- for ( size_t i = 0; i < k_; ++i ) {
+ s.sents.clear();
+ s.feats.clear();
+ s.model_scores.clear();
+ s.scores.clear();
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_);
+ for (size_t i = 0; i < k_; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
- kbest.LazyKthBest( forest.nodes_.size() - 1, i );
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- kb.sents.push_back( d->yield);
- kb.feats.push_back( d->feature_values );
- kb.model_scores.push_back( log(d->score) );
+ s.sents.push_back(d->yield);
+ s.feats.push_back(d->feature_values);
+ s.model_scores.push_back(log(d->score));
}
}
};
@@ -93,6 +88,5 @@ struct KBestGetter : public DecoderObserver
} // namespace
-
#endif