diff options
author | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 |
---|---|---|
committer | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 |
commit | 3d8d656fa7911524e0e6885647173474524e0784 (patch) | |
tree | 81b1ee2fcb67980376d03f0aa48e42e53abff222 /training/dtrain/ksampler.h | |
parent | be7f57fdd484e063775d7abf083b9fa4c403b610 (diff) | |
parent | 96fedabebafe7a38a6d5928be8fff767e411d705 (diff) |
fixed conflicts
Diffstat (limited to 'training/dtrain/ksampler.h')
-rw-r--r-- | training/dtrain/ksampler.h | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/training/dtrain/ksampler.h b/training/dtrain/ksampler.h new file mode 100644 index 00000000..bc2f56cd --- /dev/null +++ b/training/dtrain/ksampler.h @@ -0,0 +1,61 @@ +#ifndef _DTRAIN_KSAMPLER_H_ +#define _DTRAIN_KSAMPLER_H_ + +#include "hg_sampler.h" // cdec +#include "kbestget.h" +#include "score.h" + +namespace dtrain +{ + +bool +cmp_hyp_by_model_d(ScoredHyp a, ScoredHyp b) +{ + return a.model > b.model; +} + +struct KSampler : public HypSampler +{ + const unsigned k_; + vector<ScoredHyp> s_; + MT19937* prng_; + score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector<score_t>); + unsigned src_len_; + + explicit KSampler(const unsigned k, MT19937* prng) : + k_(k), prng_(prng) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) + { + src_len_ = smeta.GetSourceLength(); + ScoredSamples(*hg); + } + + vector<ScoredHyp>* GetSamples() { return &s_; } + + void ScoredSamples(const Hypergraph& forest) { + s_.clear(); sz_ = f_count_ = 0; + std::vector<HypergraphSampler::Hypothesis> samples; + HypergraphSampler::sample_hypotheses(forest, k_, prng_, &samples); + for (unsigned i = 0; i < k_; ++i) { + ScoredHyp h; + h.w = samples[i].words; + h.f = samples[i].fmap; + h.model = log(samples[i].model_score); + h.rank = i; + h.score = scorer_->Score(h.w, *ref_, i, src_len_); + s_.push_back(h); + sz_++; + f_count_ += h.f.size(); + } + sort(s_.begin(), s_.end(), cmp_hyp_by_model_d); + for (unsigned i = 0; i < s_.size(); i++) s_[i].rank = i; + } +}; + + +} // namespace + +#endif + |