diff options
| author | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 | 
|---|---|---|
| committer | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 | 
| commit | 5b8253e0e1f1393a509fb9975ba8c1347af758ed (patch) | |
| tree | 1790470b1d07a0b4973ebce19192e896566ea60b /training/dtrain/ksampler.h | |
| parent | 2389a5a8a43dda87c355579838559515b0428421 (diff) | |
| parent | b203f8c5dc8cff1b9c9c2073832b248fcad0765a (diff) | |
fixed conflicts
Diffstat (limited to 'training/dtrain/ksampler.h')
| -rw-r--r-- | training/dtrain/ksampler.h | 61 | 
1 files changed, 61 insertions, 0 deletions
| diff --git a/training/dtrain/ksampler.h b/training/dtrain/ksampler.h new file mode 100644 index 00000000..bc2f56cd --- /dev/null +++ b/training/dtrain/ksampler.h @@ -0,0 +1,61 @@ +#ifndef _DTRAIN_KSAMPLER_H_ +#define _DTRAIN_KSAMPLER_H_ + +#include "hg_sampler.h" // cdec +#include "kbestget.h" +#include "score.h" + +namespace dtrain +{ + +bool +cmp_hyp_by_model_d(ScoredHyp a, ScoredHyp b) +{ +  return a.model > b.model; +} + +struct KSampler : public HypSampler +{ +  const unsigned k_; +  vector<ScoredHyp> s_; +  MT19937* prng_; +  score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector<score_t>); +  unsigned src_len_; + +  explicit KSampler(const unsigned k, MT19937* prng) : +    k_(k), prng_(prng) {} + +  virtual void +  NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) +  { +    src_len_ = smeta.GetSourceLength(); +    ScoredSamples(*hg); +  } + +  vector<ScoredHyp>* GetSamples() { return &s_; } + +  void ScoredSamples(const Hypergraph& forest) { +    s_.clear(); sz_ = f_count_ = 0; +    std::vector<HypergraphSampler::Hypothesis> samples; +    HypergraphSampler::sample_hypotheses(forest, k_, prng_, &samples); +    for (unsigned i = 0; i < k_; ++i) { +      ScoredHyp h; +      h.w = samples[i].words; +      h.f = samples[i].fmap; +      h.model = log(samples[i].model_score); +      h.rank = i; +      h.score = scorer_->Score(h.w, *ref_, i, src_len_); +      s_.push_back(h); +      sz_++; +      f_count_ += h.f.size(); +    } +    sort(s_.begin(), s_.end(), cmp_hyp_by_model_d); +    for (unsigned i = 0; i < s_.size(); i++) s_[i].rank = i; +  } +}; + + +} // namespace + +#endif + | 
