From 2f1e5c6106995993c7211c7871126421e60d4909 Mon Sep 17 00:00:00 2001
From: Patrick Simianer <p@simianer.de>
Date: Tue, 23 Jun 2015 17:16:34 +0200
Subject: sample_net_interace.h

---
 training/dtrain/sample_net_interface.h | 61 ++++++++++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)
 create mode 100644 training/dtrain/sample_net_interface.h

(limited to 'training/dtrain')

diff --git a/training/dtrain/sample_net_interface.h b/training/dtrain/sample_net_interface.h
new file mode 100644
index 00000000..497149d9
--- /dev/null
+++ b/training/dtrain/sample_net_interface.h
@@ -0,0 +1,61 @@
+#ifndef _DTRAIN_SAMPLE_NET_H_
+#define _DTRAIN_SAMPLE_NET_H_
+
+#include "kbest.h"
+
+#include "score.h"
+
+namespace dtrain
+{
+
+struct ScoredKbest : public DecoderObserver
+{
+  const size_t k_;
+  size_t feature_count_, effective_sz_;
+  vector<ScoredHyp> samples_;
+  PerSentenceBleuScorer* scorer_;
+  vector<Ngrams>* ref_ngs_;
+  vector<size_t>* ref_ls_;
+  bool dont_score;
+
+  ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) :
+    k_(k), scorer_(scorer), dont_score(false) {}
+
+  virtual void
+  NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
+  {
+    samples_.clear(); effective_sz_ = feature_count_ = 0;
+    KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
+      KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_);
+    for (size_t i = 0; i < k_; ++i) {
+      const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
+            KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
+              kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+      if (!d) break;
+      ScoredHyp h;
+      h.w = d->yield;
+      h.f = d->feature_values;
+      h.model = log(d->score);
+      h.rank = i;
+      if (!dont_score)
+        h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_);
+      samples_.push_back(h);
+      effective_sz_++;
+      feature_count_ += h.f.size();
+    }
+  }
+
+  vector<ScoredHyp>* GetSamples() { return &samples_; }
+  inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls)
+  {
+    ref_ngs_ = &ngs;
+    ref_ls_ = &ls;
+  }
+  inline size_t GetFeatureCount() { return feature_count_; }
+  inline size_t GetSize() { return effective_sz_; }
+};
+
+} // namespace
+
+#endif
+
-- 
cgit v1.2.3