summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/dtrain/sample.h132
1 files changed, 102 insertions, 30 deletions
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h
index 1249e372..bccc29b7 100644
--- a/training/dtrain/sample.h
+++ b/training/dtrain/sample.h
@@ -2,55 +2,127 @@
#define _DTRAIN_SAMPLE_H_
#include "kbest.h"
+#include "hg_sampler.h"
#include "score.h"
namespace dtrain
{
-struct ScoredKbest : public DecoderObserver
+struct HypSampler : public DecoderObserver
{
- const size_t k_;
- size_t feature_count_, effective_sz_;
- vector<ScoredHyp> samples_;
- Scorer* scorer_;
- vector<Ngrams>* ref_ngs_;
- vector<size_t>* ref_ls_;
+ size_t feature_count, effective_size;
+ vector<Hyp> sample;
+ vector<Ngrams>* reference_ngrams;
+ vector<size_t>* reference_lengths;
- ScoredKbest(const size_t k, Scorer* scorer) :
- k_(k), scorer_(scorer) {}
+ void
+ reset()
+ {
+ sample.clear();
+ effective_size = feature_count = 0;
+ }
+};
+
+struct KBestSampler : public HypSampler
+{
+ size_t k;
+ bool unique;
+ Scorer* scorer;
+
+ KBestSampler() {}
+ KBestSampler(const size_t k, Scorer* scorer) :
+ k(k), scorer(scorer) {}
virtual void
NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg)
{
- samples_.clear(); effective_sz_ = feature_count_ = 0;
+ reset();
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
- KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_);
- for (size_t i = 0; i < k_; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
- KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
- kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+ KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k);
+ for (size_t i=0; i<k; ++i) {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
+ KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i);
if (!d) break;
- ScoredHyp h;
- h.w = d->yield;
- h.f = d->feature_values;
- h.model = log(d->score);
- h.rank = i;
- h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_);
- samples_.push_back(h);
- effective_sz_++;
- feature_count_ += h.f.size();
+ sample.emplace_back(
+ d->yield,
+ d->feature_values,
+ log(d->score),
+ scorer->score(d->yield, *reference_ngrams, *reference_lengths),
+ i
+ );
+ effective_size++;
+ feature_count += sample.back().f.size();
}
}
+};
- vector<ScoredHyp>* GetSamples() { return &samples_; }
- inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls)
+struct KBestNoFilterSampler : public KBestSampler
+{
+ size_t k;
+ bool unique;
+ Scorer* scorer;
+
+ KBestNoFilterSampler(const size_t k, Scorer* scorer) :
+ k(k), scorer(scorer) {}
+
+ virtual void
+ NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg)
{
- ref_ngs_ = &ngs;
- ref_ls_ = &ls;
+ reset();
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(*hg, k);
+ for (size_t i=0; i<k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+ if (!d) break;
+ sample.emplace_back(
+ d->yield,
+ d->feature_values,
+ log(d->score),
+ scorer->score(d->yield, *reference_ngrams, *reference_lengths),
+ i
+ );
+ effective_size++;
+ feature_count += sample.back().f.size();
+ }
+ }
+};
+
+struct KSampler : public HypSampler
+{
+ const size_t k;
+ Scorer* scorer;
+ MT19937 rng;
+
+ explicit KSampler(const unsigned k, Scorer* scorer) :
+ k(k), scorer(scorer) {}
+
+ virtual void
+ NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg)
+ {
+ reset();
+ std::vector<HypergraphSampler::Hypothesis> hs;
+ HypergraphSampler::sample_hypotheses(*hg, k, &rng, &hs);
+ for (size_t i=0; i<k; ++i) {
+ sample.emplace_back(
+ hs[i].words,
+ hs[i].fmap,
+ log(hs[i].model_score),
+ 0,
+ 0
+ );
+ effective_size++;
+ feature_count += sample.back().f.size();
+ }
+ sort(sample.begin(), sample.end(), [](Hyp& first, Hyp& second) {
+ return first.model > second.model;
+ });
+ for (unsigned i=0; i<sample.size(); i++) {
+ sample[i].rank=i;
+ scorer->score(sample[i].w, *reference_ngrams, *reference_lengths);
+ }
}
- inline size_t GetFeatureCount() { return feature_count_; }
- inline size_t GetSize() { return effective_sz_; }
};
} // namespace