summaryrefslogtreecommitdiff
path: root/training/dtrain/ksampler.h
diff options
context:
space:
mode:
authorAvneesh Saluja <asaluja@gmail.com>2013-03-28 18:28:16 -0700
committerAvneesh Saluja <asaluja@gmail.com>2013-03-28 18:28:16 -0700
commit3d8d656fa7911524e0e6885647173474524e0784 (patch)
tree81b1ee2fcb67980376d03f0aa48e42e53abff222 /training/dtrain/ksampler.h
parentbe7f57fdd484e063775d7abf083b9fa4c403b610 (diff)
parent96fedabebafe7a38a6d5928be8fff767e411d705 (diff)
fixed conflicts
Diffstat (limited to 'training/dtrain/ksampler.h')
-rw-r--r--training/dtrain/ksampler.h61
1 files changed, 61 insertions, 0 deletions
diff --git a/training/dtrain/ksampler.h b/training/dtrain/ksampler.h
new file mode 100644
index 00000000..bc2f56cd
--- /dev/null
+++ b/training/dtrain/ksampler.h
@@ -0,0 +1,61 @@
+#ifndef _DTRAIN_KSAMPLER_H_
+#define _DTRAIN_KSAMPLER_H_
+
+#include "hg_sampler.h" // cdec
+#include "kbestget.h"
+#include "score.h"
+
+namespace dtrain
+{
+
+bool
+cmp_hyp_by_model_d(ScoredHyp a, ScoredHyp b)
+{
+ return a.model > b.model;
+}
+
+struct KSampler : public HypSampler
+{
+ const unsigned k_;
+ vector<ScoredHyp> s_;
+ MT19937* prng_;
+ score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector<score_t>);
+ unsigned src_len_;
+
+ explicit KSampler(const unsigned k, MT19937* prng) :
+ k_(k), prng_(prng) {}
+
+ virtual void
+ NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
+ {
+ src_len_ = smeta.GetSourceLength();
+ ScoredSamples(*hg);
+ }
+
+ vector<ScoredHyp>* GetSamples() { return &s_; }
+
+ void ScoredSamples(const Hypergraph& forest) {
+ s_.clear(); sz_ = f_count_ = 0;
+ std::vector<HypergraphSampler::Hypothesis> samples;
+ HypergraphSampler::sample_hypotheses(forest, k_, prng_, &samples);
+ for (unsigned i = 0; i < k_; ++i) {
+ ScoredHyp h;
+ h.w = samples[i].words;
+ h.f = samples[i].fmap;
+ h.model = log(samples[i].model_score);
+ h.rank = i;
+ h.score = scorer_->Score(h.w, *ref_, i, src_len_);
+ s_.push_back(h);
+ sz_++;
+ f_count_ += h.f.size();
+ }
+ sort(s_.begin(), s_.end(), cmp_hyp_by_model_d);
+ for (unsigned i = 0; i < s_.size(); i++) s_[i].rank = i;
+ }
+};
+
+
+} // namespace
+
+#endif
+