summaryrefslogtreecommitdiff
path: root/dtrain/kbestget.h
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-09-25 21:43:57 +0200
committerPatrick Simianer <p@simianer.de>2011-09-25 21:43:57 +0200
commit9fe6d978313f742477863ff42b9158cf2f55414f (patch)
tree468a1009193b6122eb3fb3941a026b53e3c201a0 /dtrain/kbestget.h
parent43e7ecdca09f4125346f64d45e44f440ac964421 (diff)
kbest, ksampler refactoring
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r--dtrain/kbestget.h55
1 files changed, 26 insertions, 29 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index 79201182..403384de 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -7,28 +7,27 @@ namespace dtrain
{
-struct Samples
+struct ScoredHyp
{
- vector<SparseVector<double> > feats;
- vector<vector<WordID> > sents;
- vector<double> model_scores;
- vector<double> scores;
- size_t GetSize() { return sents.size(); }
+ vector<WordID> w;
+ SparseVector<double> f;
+ score_t model;
+ score_t score;
};
-struct HypoSampler : public DecoderObserver
+struct HypSampler : public DecoderObserver
{
- virtual Samples* GetSamples() {}
+ virtual vector<ScoredHyp>* GetSamples() {}
};
-struct KBestGetter : public HypoSampler
+struct KBestGetter : public HypSampler
{
const size_t k_;
- const string filter_type;
- Samples s;
+ const string filter_type_;
+ vector<ScoredHyp> s_;
KBestGetter(const size_t k, const string filter_type) :
- k_(k), filter_type(filter_type) {}
+ k_(k), filter_type_(filter_type) {}
virtual void
NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
@@ -36,14 +35,14 @@ struct KBestGetter : public HypoSampler
KBest(*hg);
}
- Samples* GetSamples() { return &s; }
+ vector<ScoredHyp>* GetSamples() { return &s_; }
void
KBest(const Hypergraph& forest)
{
- if (filter_type == "unique") {
+ if (filter_type_ == "unique") {
KBestUnique(forest);
- } else if (filter_type == "no") {
+ } else if (filter_type_ == "no") {
KBestNoFilter(forest);
}
}
@@ -51,36 +50,34 @@ struct KBestGetter : public HypoSampler
void
KBestUnique(const Hypergraph& forest)
{
- s.sents.clear();
- s.feats.clear();
- s.model_scores.clear();
- s.scores.clear();
+ s_.clear();
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_);
for (size_t i = 0; i < k_; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- s.sents.push_back(d->yield);
- s.feats.push_back(d->feature_values);
- s.model_scores.push_back(log(d->score));
+ ScoredHyp h;
+ h.w = d->yield;
+ h.f = d->feature_values;
+ h.model = log(d->score);
+ s_.push_back(h);
}
}
void
KBestNoFilter(const Hypergraph& forest)
{
- s.sents.clear();
- s.feats.clear();
- s.model_scores.clear();
- s.scores.clear();
+ s_.clear();
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_);
for (size_t i = 0; i < k_; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- s.sents.push_back(d->yield);
- s.feats.push_back(d->feature_values);
- s.model_scores.push_back(log(d->score));
+ ScoredHyp h;
+ h.w = d->yield;
+ h.f = d->feature_values;
+ h.model = log(d->score);
+ s_.push_back(h);
}
}
};