From c1a892e3fc38de076b4254e9993e9493a0d12b6c Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Fri, 16 Oct 2015 10:26:47 +0200
Subject: dtrain sample.h: forest sampling, non-unique k-best sampling
---
training/dtrain/sample.h | 132 ++++++++++++++++++++++++++++++++++++-----------
1 file changed, 102 insertions(+), 30 deletions(-)
(limited to 'training/dtrain')
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h
index 1249e372..bccc29b7 100644
--- a/training/dtrain/sample.h
+++ b/training/dtrain/sample.h
@@ -2,55 +2,127 @@
#define _DTRAIN_SAMPLE_H_
#include "kbest.h"
+#include "hg_sampler.h"
#include "score.h"
namespace dtrain
{
-struct ScoredKbest : public DecoderObserver
+struct HypSampler : public DecoderObserver // state shared by the sampler structs that derive from it in this file
{
- const size_t k_;
- size_t feature_count_, effective_sz_;
- vector samples_;
- Scorer* scorer_;
- vector* ref_ngs_;
- vector* ref_ls_;
+ size_t feature_count, effective_size; // running totals maintained by the derived samplers; both zeroed by reset()
+ vector sample; // hypotheses collected for the current forest; NOTE(review): element type lost in this rendering -- presumably ScoredHyp, confirm
+ vector* reference_ngrams; // raw pointer, never assigned in the visible code; the samplers dereference it when scoring
+ vector* reference_lengths; // raw pointer, never assigned in the visible code; the samplers dereference it when scoring
- ScoredKbest(const size_t k, Scorer* scorer) :
- k_(k), scorer_(scorer) {}
+ void
+ reset() // empties `sample` and sets both counters back to 0
+ {
+ sample.clear();
+ effective_size = feature_count = 0;
+ }
+};
+
+struct KBestSampler : public HypSampler // fills `sample` from a k-best extractor that uses the KBest::FilterUnique policy
+{
+ size_t k; // passed to the k-best extractor's constructor as its size argument
+ bool unique; // NOTE(review): never initialized or read anywhere in the visible code
+ Scorer* scorer; // raw pointer supplied by the caller; used to score each derivation's yield
+
+ KBestSampler() {} // NOTE(review): leaves k, unique and scorer uninitialized
+ KBestSampler(const size_t k, Scorer* scorer) :
+ k(k), scorer(scorer) {} // `unique` is not set here either
virtual void
NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg)
{
- samples_.clear(); effective_sz_ = feature_count_ = 0;
+ reset(); // start from an empty sample and zeroed counters for this forest
KBest::KBestDerivations, ESentenceTraversal,
- KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_);
- for (size_t i = 0; i < k_; ++i) {
- const KBest::KBestDerivations, ESentenceTraversal,
- KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
- kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+ KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k);
+ for (size_t i=0; i, ESentenceTraversal, // NOTE(review): truncated in this rendering -- the loop condition and the start of the Derivation type are missing
+ KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i); // i-th derivation at the last node of hg->nodes_ (presumably the goal node -- confirm)
if (!d) break;
- ScoredHyp h;
- h.w = d->yield;
- h.f = d->feature_values;
- h.model = log(d->score);
- h.rank = i;
- h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_);
- samples_.push_back(h);
- effective_sz_++;
- feature_count_ += h.f.size();
+ sample.emplace_back( // one entry per derivation; argument order presumably matches a ScoredHyp constructor not shown here
+ d->yield,
+ d->feature_values,
+ log(d->score),
+ scorer->score(d->yield, *reference_ngrams, *reference_lengths), // dereferences both reference pointers -- they must be valid by now
+ i
+ );
+ effective_size++; // counts entries actually added; can stay below k because of the break above
+ feature_count += sample.back().f.size(); // adds the size of the `f` member of the entry just appended
}
}
+};
- vector* GetSamples() { return &samples_; }
- inline void SetReference(vector& ngs, vector& ls)
+struct KBestNoFilterSampler : public KBestSampler // k-best sampler whose extractor is instantiated without the KBest::FilterUnique policy
+{
+ // k, unique and scorer are inherited from KBestSampler. Redeclaring them here
+ // would hide the base members and leave KBestSampler::k / KBestSampler::scorer
+ // uninitialized (the base would be built by its empty default constructor).
+ KBestNoFilterSampler(const size_t k, Scorer* scorer) :
+ KBestSampler(k, scorer) {} // initialize the inherited members through the base constructor
+
+ virtual void
+ NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg)
{
- ref_ngs_ = &ngs;
- ref_ls_ = &ls;
+ reset(); // start from an empty sample and zeroed counters for this forest
+ KBest::KBestDerivations, ESentenceTraversal> kbest(*hg, k);
+ for (size_t i=0; i, ESentenceTraversal>::Derivation* d = // NOTE(review): truncated in this rendering -- the loop condition and the start of the Derivation type are missing
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i); // i-th derivation at the last node of hg->nodes_ (presumably the goal node -- confirm)
+ if (!d) break; // extractor returned no i-th derivation: stop early
+ sample.emplace_back( // one entry per derivation; argument order presumably matches a ScoredHyp constructor not shown here
+ d->yield,
+ d->feature_values,
+ log(d->score),
+ scorer->score(d->yield, *reference_ngrams, *reference_lengths), // dereferences both reference pointers -- they must be valid by now
+ i
+ );
+ effective_size++; // counts entries actually added; can stay below k because of the break above
+ feature_count += sample.back().f.size(); // adds the size of the `f` member of the entry just appended
+ }
+ }
+};
+
+struct KSampler : public HypSampler // fills `sample` from HypergraphSampler::sample_hypotheses
+{
+ const size_t k; // count argument passed to HypergraphSampler::sample_hypotheses
+ Scorer* scorer; // raw pointer supplied by the caller
+ MT19937 rng; // random source handed to the sampler by address; default-constructed here (seeding not visible)
+
+ explicit KSampler(const unsigned k, Scorer* scorer) :
+ k(k), scorer(scorer) {} // note: parameter is unsigned while the member is size_t
+
+ virtual void
+ NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg)
+ {
+ reset(); // start from an empty sample and zeroed counters for this forest
+ std::vector hs; // output buffer for the sampler call below; NOTE(review): element type lost in this rendering
+ HypergraphSampler::sample_hypotheses(*hg, k, &rng, &hs);
+ for (size_t i=0; i second.model; // NOTE(review): truncated in this rendering -- a loop body and the start of a lambda are missing; the remnant suggests an ordering by `model`, confirm against the real file
+ });
+ for (unsigned i=0; iscore(sample[i].w, *reference_ngrams, *reference_lengths); // NOTE(review): truncated in this rendering -- loop condition and start of the body are missing
+ }
}
- inline size_t GetFeatureCount() { return feature_count_; }
- inline size_t GetSize() { return effective_sz_; }
};
} // namespace
--
cgit v1.2.3