diff options
Diffstat (limited to 'training/dtrain/sample.h')
-rw-r--r-- | training/dtrain/sample.h | 62 |
1 files changed, 15 insertions, 47 deletions
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h index 25f02273..64d93cb0 100644 --- a/training/dtrain/sample.h +++ b/training/dtrain/sample.h @@ -1,5 +1,5 @@ -#ifndef _DTRAIN_KBESTGET_H_ -#define _DTRAIN_KBESTGET_H_ +#ifndef _DTRAIN_SAMPLE_H_ +#define _DTRAIN_SAMPLE_H_ #include "kbest.h" @@ -7,78 +7,46 @@ namespace dtrain { -struct KBestGetter : public HypSampler +struct ScoredKbest : public DecoderObserver { const unsigned k_; - const string filter_type_; vector<ScoredHyp> s_; unsigned src_len_; + PerSentenceBleuScorer* scorer_; + vector<vector<WordID> >* refs_; + unsigned f_count_, sz_; - KBestGetter(const unsigned k, const string filter_type) : - k_(k), filter_type_(filter_type) {} + ScoredKbest(const unsigned k, PerSentenceBleuScorer* scorer) : + k_(k), scorer_(scorer) {} virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { src_len_ = smeta.GetSourceLength(); - KBestScored(*hg); - } - - vector<ScoredHyp>* GetSamples() { return &s_; } - - void - KBestScored(const Hypergraph& forest) - { - if (filter_type_ == "uniq") { - KBestUnique(forest); - } else if (filter_type_ == "not") { - KBestNoFilter(forest); - } - } - - void - KBestUnique(const Hypergraph& forest) - { s_.clear(); sz_ = f_count_ = 0; KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, - KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_); + KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_); for (unsigned i = 0; i < k_; ++i) { const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); + kbest.LazyKthBest(hg->nodes_.size() - 1, i); if (!d) break; ScoredHyp h; h.w = d->yield; h.f = d->feature_values; h.model = log(d->score); h.rank = i; - h.score = scorer_->Score(h.w, *refs_, i, src_len_); + h.score = scorer_->Score(h.w, *refs_); s_.push_back(h); sz_++; f_count_ += h.f.size(); } } - void - KBestNoFilter(const Hypergraph& forest) - { - s_.clear(); sz_ = f_count_ = 0; - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_); - for (unsigned i = 0; i < k_; ++i) { - const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - ScoredHyp h; - h.w = d->yield; - h.f = d->feature_values; - h.model = log(d->score); - h.rank = i; - h.score = scorer_->Score(h.w, *refs_, i, src_len_); - s_.push_back(h); - sz_++; - f_count_ += h.f.size(); - } - } + vector<ScoredHyp>* GetSamples() { return &s_; } + inline void SetReference(vector<vector<WordID> >& refs) { refs_ = &refs; } + inline unsigned GetFeatureCount() { return f_count_; } + inline unsigned GetSize() { return sz_; } }; |