diff options
Diffstat (limited to 'training/dtrain/sample.h')
-rw-r--r-- | training/dtrain/sample.h | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h new file mode 100644 index 00000000..25f02273 --- /dev/null +++ b/training/dtrain/sample.h @@ -0,0 +1,88 @@ +#ifndef _DTRAIN_KBESTGET_H_ +#define _DTRAIN_KBESTGET_H_ + +#include "kbest.h" + +namespace dtrain +{ + + +struct KBestGetter : public HypSampler +{ + const unsigned k_; + const string filter_type_; + vector<ScoredHyp> s_; + unsigned src_len_; + + KBestGetter(const unsigned k, const string filter_type) : + k_(k), filter_type_(filter_type) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) + { + src_len_ = smeta.GetSourceLength(); + KBestScored(*hg); + } + + vector<ScoredHyp>* GetSamples() { return &s_; } + + void + KBestScored(const Hypergraph& forest) + { + if (filter_type_ == "uniq") { + KBestUnique(forest); + } else if (filter_type_ == "not") { + KBestNoFilter(forest); + } + } + + void + KBestUnique(const Hypergraph& forest) + { + s_.clear(); sz_ = f_count_ = 0; + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_); + for (unsigned i = 0; i < k_; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, + prob_t, EdgeProb>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + ScoredHyp h; + h.w = d->yield; + h.f = d->feature_values; + h.model = log(d->score); + h.rank = i; + h.score = scorer_->Score(h.w, *refs_, i, src_len_); + s_.push_back(h); + sz_++; + f_count_ += h.f.size(); + } + } + + void + KBestNoFilter(const Hypergraph& forest) + { + s_.clear(); sz_ = f_count_ = 0; + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_); + for (unsigned i = 0; i < k_; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + ScoredHyp h; + h.w = d->yield; + h.f = d->feature_values; + h.model = log(d->score); + h.rank = i; + h.score = scorer_->Score(h.w, *refs_, i, src_len_); + s_.push_back(h); + sz_++; + f_count_ += h.f.size(); + } + } +}; + + +} // namespace + +#endif + |