summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2015-06-23 17:16:34 +0200
committerPatrick Simianer <p@simianer.de>2015-06-23 17:16:34 +0200
commit2f1e5c6106995993c7211c7871126421e60d4909 (patch)
tree67af3447d9e788eed438eee848c0ef23152020bc
parent754d7d23aef09dceab5acd5688f81dab32f5a695 (diff)
sample_net_interace.h
-rw-r--r--training/dtrain/sample_net_interface.h61
1 files changed, 61 insertions, 0 deletions
diff --git a/training/dtrain/sample_net_interface.h b/training/dtrain/sample_net_interface.h
new file mode 100644
index 00000000..497149d9
--- /dev/null
+++ b/training/dtrain/sample_net_interface.h
@@ -0,0 +1,61 @@
+#ifndef _DTRAIN_SAMPLE_NET_H_
+#define _DTRAIN_SAMPLE_NET_H_
+
+#include "kbest.h"
+
+#include "score.h"
+
+namespace dtrain
+{
+
+struct ScoredKbest : public DecoderObserver
+{
+ const size_t k_;
+ size_t feature_count_, effective_sz_;
+ vector<ScoredHyp> samples_;
+ PerSentenceBleuScorer* scorer_;
+ vector<Ngrams>* ref_ngs_;
+ vector<size_t>* ref_ls_;
+ bool dont_score;
+
+ ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) :
+ k_(k), scorer_(scorer), dont_score(false) {}
+
+ virtual void
+ NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
+ {
+ samples_.clear(); effective_sz_ = feature_count_ = 0;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
+ KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_);
+ for (size_t i = 0; i < k_; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
+ KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+ if (!d) break;
+ ScoredHyp h;
+ h.w = d->yield;
+ h.f = d->feature_values;
+ h.model = log(d->score);
+ h.rank = i;
+ if (!dont_score)
+ h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_);
+ samples_.push_back(h);
+ effective_sz_++;
+ feature_count_ += h.f.size();
+ }
+ }
+
+ vector<ScoredHyp>* GetSamples() { return &samples_; }
+ inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls)
+ {
+ ref_ngs_ = &ngs;
+ ref_ls_ = &ls;
+ }
+ inline size_t GetFeatureCount() { return feature_count_; }
+ inline size_t GetSize() { return effective_sz_; }
+};
+
+} // namespace
+
+#endif
+