diff options
Diffstat (limited to 'dtrain/ksampler.h')
-rw-r--r-- | dtrain/ksampler.h | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/dtrain/ksampler.h b/dtrain/ksampler.h new file mode 100644 index 00000000..a28b69c9 --- /dev/null +++ b/dtrain/ksampler.h @@ -0,0 +1,52 @@ +#ifndef _DTRAIN_KSAMPLER_H_ +#define _DTRAIN_KSAMPLER_H_ + +#include "kbest.h" +#include "sample_hg.h" +#include "sampler.h" + +namespace dtrain +{ + +/* + * KSampler + * + */ +struct KSampler : public DecoderObserver +{ + const size_t k_; + KBestList kb; + MT19937* rng; + + explicit KSampler( const size_t k, MT19937* prng ) : + k_(k), rng(prng) {} + + virtual void + NotifyTranslationForest( const SentenceMetadata& smeta, Hypergraph* hg ) + { + Sample( *hg ); + } + + KBestList* GetKBest() { return &kb; } + + void Sample( const Hypergraph& forest ) { + kb.sents.clear(); + kb.feats.clear(); + kb.model_scores.clear(); + kb.scores.clear(); + std::vector<HypergraphSampler::Hypothesis> samples; + HypergraphSampler::sample_hypotheses(forest, k_, rng, &samples); + for ( size_t i = 0; i < k_; ++i ) { + kb.sents.push_back( samples[i].words ); + kb.feats.push_back( samples[i].fmap ); + kb.model_scores.push_back( log(samples[i].model_score) ); + } + } +}; + + +} // namespace + + +#endif + |