diff options
-rw-r--r-- | training/dtrain/sample.h | 132 |
1 files changed, 102 insertions, 30 deletions
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h index 1249e372..bccc29b7 100644 --- a/training/dtrain/sample.h +++ b/training/dtrain/sample.h @@ -2,55 +2,127 @@ #define _DTRAIN_SAMPLE_H_ #include "kbest.h" +#include "hg_sampler.h" #include "score.h" namespace dtrain { -struct ScoredKbest : public DecoderObserver +struct HypSampler : public DecoderObserver { - const size_t k_; - size_t feature_count_, effective_sz_; - vector<ScoredHyp> samples_; - Scorer* scorer_; - vector<Ngrams>* ref_ngs_; - vector<size_t>* ref_ls_; + size_t feature_count, effective_size; + vector<Hyp> sample; + vector<Ngrams>* reference_ngrams; + vector<size_t>* reference_lengths; - ScoredKbest(const size_t k, Scorer* scorer) : - k_(k), scorer_(scorer) {} + void + reset() + { + sample.clear(); + effective_size = feature_count = 0; + } +}; + +struct KBestSampler : public HypSampler +{ + size_t k; + bool unique; + Scorer* scorer; + + KBestSampler() {} + KBestSampler(const size_t k, Scorer* scorer) : + k(k), scorer(scorer) {} virtual void NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg) { - samples_.clear(); effective_sz_ = feature_count_ = 0; + reset(); KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, - KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_); - for (size_t i = 0; i < k_; ++i) { - const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, - KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = - kbest.LazyKthBest(hg->nodes_.size() - 1, i); + KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k); + for (size_t i=0; i<k; ++i) { + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); if (!d) break; - ScoredHyp h; - h.w = d->yield; - h.f = d->feature_values; - h.model = log(d->score); - h.rank = i; - h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_); - samples_.push_back(h); - effective_sz_++; - feature_count_ += h.f.size(); + sample.emplace_back( + d->yield, + d->feature_values, + log(d->score), + scorer->score(d->yield, *reference_ngrams, *reference_lengths), + i + ); + effective_size++; + feature_count += sample.back().f.size(); } } +}; - vector<ScoredHyp>* GetSamples() { return &samples_; } - inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls) +struct KBestNoFilterSampler : public KBestSampler +{ + size_t k; + bool unique; + Scorer* scorer; + + KBestNoFilterSampler(const size_t k, Scorer* scorer) : + k(k), scorer(scorer) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg) { - ref_ngs_ = &ngs; - ref_ls_ = &ls; + reset(); + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(*hg, k); + for (size_t i=0; i<k; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); + if (!d) break; + sample.emplace_back( + d->yield, + d->feature_values, + log(d->score), + scorer->score(d->yield, *reference_ngrams, *reference_lengths), + i + ); + effective_size++; + feature_count += sample.back().f.size(); + } + } +}; + +struct KSampler : public HypSampler +{ + const size_t k; + Scorer* scorer; + MT19937 rng; + + explicit KSampler(const unsigned k, Scorer* scorer) : + k(k), scorer(scorer) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg) + { + reset(); + std::vector<HypergraphSampler::Hypothesis> hs; + HypergraphSampler::sample_hypotheses(*hg, k, &rng, &hs); + for (size_t i=0; i<k; ++i) { + sample.emplace_back( + hs[i].words, + hs[i].fmap, + log(hs[i].model_score), + 0, + 0 + ); + effective_size++; + feature_count += sample.back().f.size(); + } + sort(sample.begin(), sample.end(), [](Hyp& first, Hyp& second) { + return first.model > second.model; + }); + for (unsigned i=0; i<sample.size(); i++) { + sample[i].rank=i; + scorer->score(sample[i].w, *reference_ngrams, *reference_lengths); + } } - inline size_t GetFeatureCount() { return feature_count_; } - inline size_t GetSize() { return effective_sz_; } }; } // namespace |