From c1a892e3fc38de076b4254e9993e9493a0d12b6c Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Fri, 16 Oct 2015 10:26:47 +0200 Subject: dtrain sample.h: forest sampling, non-unique k-best sampling --- training/dtrain/sample.h | 132 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 102 insertions(+), 30 deletions(-) (limited to 'training/dtrain') diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h index 1249e372..bccc29b7 100644 --- a/training/dtrain/sample.h +++ b/training/dtrain/sample.h @@ -2,55 +2,127 @@ #define _DTRAIN_SAMPLE_H_ #include "kbest.h" +#include "hg_sampler.h" #include "score.h" namespace dtrain { -struct ScoredKbest : public DecoderObserver +struct HypSampler : public DecoderObserver { - const size_t k_; - size_t feature_count_, effective_sz_; - vector samples_; - Scorer* scorer_; - vector* ref_ngs_; - vector* ref_ls_; + size_t feature_count, effective_size; + vector sample; + vector* reference_ngrams; + vector* reference_lengths; - ScoredKbest(const size_t k, Scorer* scorer) : - k_(k), scorer_(scorer) {} + void + reset() + { + sample.clear(); + effective_size = feature_count = 0; + } +}; + +struct KBestSampler : public HypSampler +{ + size_t k; + bool unique; + Scorer* scorer; + + KBestSampler() {} + KBestSampler(const size_t k, Scorer* scorer) : + k(k), scorer(scorer) {} virtual void NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg) { - samples_.clear(); effective_sz_ = feature_count_ = 0; + reset(); KBest::KBestDerivations, ESentenceTraversal, - KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_); - for (size_t i = 0; i < k_; ++i) { - const KBest::KBestDerivations, ESentenceTraversal, - KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = - kbest.LazyKthBest(hg->nodes_.size() - 1, i); + KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k); + for (size_t i=0; i, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); if (!d) break; - ScoredHyp h; - h.w = d->yield; - h.f = d->feature_values; - h.model = log(d->score); - h.rank = i; - h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_); - samples_.push_back(h); - effective_sz_++; - feature_count_ += h.f.size(); + sample.emplace_back( + d->yield, + d->feature_values, + log(d->score), + scorer->score(d->yield, *reference_ngrams, *reference_lengths), + i + ); + effective_size++; + feature_count += sample.back().f.size(); } } +}; - vector* GetSamples() { return &samples_; } - inline void SetReference(vector& ngs, vector& ls) +struct KBestNoFilterSampler : public KBestSampler +{ + size_t k; + bool unique; + Scorer* scorer; + + KBestNoFilterSampler(const size_t k, Scorer* scorer) : + k(k), scorer(scorer) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg) { - ref_ngs_ = &ngs; - ref_ls_ = &ls; + reset(); + KBest::KBestDerivations, ESentenceTraversal> kbest(*hg, k); + for (size_t i=0; i, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); + if (!d) break; + sample.emplace_back( + d->yield, + d->feature_values, + log(d->score), + scorer->score(d->yield, *reference_ngrams, *reference_lengths), + i + ); + effective_size++; + feature_count += sample.back().f.size(); + } + } +}; + +struct KSampler : public HypSampler +{ + const size_t k; + Scorer* scorer; + MT19937 rng; + + explicit KSampler(const unsigned k, Scorer* scorer) : + k(k), scorer(scorer) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg) + { + reset(); + std::vector hs; + HypergraphSampler::sample_hypotheses(*hg, k, &rng, &hs); + for (size_t i=0; i second.model; + }); + for (unsigned i=0; iscore(sample[i].w, *reference_ngrams, *reference_lengths); + } } - inline size_t GetFeatureCount() { return feature_count_; } - inline size_t GetSize() { return effective_sz_; } }; } // namespace -- cgit v1.2.3