summaryrefslogtreecommitdiff
path: root/dtrain/ksampler.h
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/ksampler.h')
-rw-r--r--dtrain/ksampler.h52
1 files changed, 52 insertions, 0 deletions
diff --git a/dtrain/ksampler.h b/dtrain/ksampler.h
new file mode 100644
index 00000000..a28b69c9
--- /dev/null
+++ b/dtrain/ksampler.h
@@ -0,0 +1,52 @@
+#ifndef _DTRAIN_KSAMPLER_H_
+#define _DTRAIN_KSAMPLER_H_
+
+#include "kbest.h"
+#include "sample_hg.h"
+#include "sampler.h"
+
+namespace dtrain
+{
+
+/*
+ * KSampler
+ *
+ */
+struct KSampler : public DecoderObserver
+{
+ const size_t k_;
+ KBestList kb;
+ MT19937* rng;
+
+ explicit KSampler( const size_t k, MT19937* prng ) :
+ k_(k), rng(prng) {}
+
+ virtual void
+ NotifyTranslationForest( const SentenceMetadata& smeta, Hypergraph* hg )
+ {
+ Sample( *hg );
+ }
+
+ KBestList* GetKBest() { return &kb; }
+
+ void Sample( const Hypergraph& forest ) {
+ kb.sents.clear();
+ kb.feats.clear();
+ kb.model_scores.clear();
+ kb.scores.clear();
+ std::vector<HypergraphSampler::Hypothesis> samples;
+ HypergraphSampler::sample_hypotheses(forest, k_, rng, &samples);
+ for ( size_t i = 0; i < k_; ++i ) {
+ kb.sents.push_back( samples[i].words );
+ kb.feats.push_back( samples[i].fmap );
+ kb.model_scores.push_back( log(samples[i].model_score) );
+ }
+ }
+};
+
+
+} // namespace
+
+
+#endif
+