diff options
author | pks <pks@users.noreply.github.com> | 2019-05-12 20:10:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-05-12 20:10:37 +0200 |
commit | 4a13b41700f34c15c30b551f98dbea9cb41f67c3 (patch) | |
tree | 0218f41c350a626f5af9909d77406309fa873fdf /training/dtrain/sample_net_interface.h | |
parent | e9268eb3dcd867f3baf67a7bb3d2aad56196ecde (diff) | |
parent | f64746ac87fc7338629b19de9fa2da0f03fa2790 (diff) |
Merge branch 'net' into origin/net
Diffstat (limited to 'training/dtrain/sample_net_interface.h')
-rw-r--r-- | training/dtrain/sample_net_interface.h | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/training/dtrain/sample_net_interface.h b/training/dtrain/sample_net_interface.h new file mode 100644 index 00000000..6d00e5d5 --- /dev/null +++ b/training/dtrain/sample_net_interface.h @@ -0,0 +1,68 @@ +#ifndef _DTRAIN_SAMPLE_NET_H_ +#define _DTRAIN_SAMPLE_NET_H_ + +#include "kbest.h" + +#include "score_net_interface.h" + +namespace dtrain +{ + +struct ScoredKbest : public DecoderObserver +{ + const size_t k_; + size_t feature_count_, effective_sz_; + vector<ScoredHyp> samples_; + PerSentenceBleuScorer* scorer_; + vector<Ngrams>* ref_ngs_; + vector<size_t>* ref_ls_; + bool dont_score; + string viterbiTreeStr_, viterbiRules_; + + ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) : + k_(k), scorer_(scorer), dont_score(false) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg) + { + samples_.clear(); effective_sz_ = feature_count_ = 0; + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_); + for (size_t i = 0; i < k_; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); + if (!d) break; + ScoredHyp h; + h.w = d->yield; + h.f = d->feature_values; + h.model = log(d->score); + h.rank = i; + if (!dont_score) + h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_); + samples_.push_back(h); + effective_sz_++; + feature_count_ += h.f.size(); + viterbiTreeStr_ = hg->show_viterbi_tree(false); + ostringstream ss; + ViterbiRules(*hg, &ss); + viterbiRules_ = ss.str(); + } + } + + vector<ScoredHyp>* GetSamples() { return &samples_; } + inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls) + { + ref_ngs_ = &ngs; + ref_ls_ = &ls; + } + inline size_t GetFeatureCount() { return feature_count_; } + inline size_t GetSize() { return effective_sz_; } + inline string GetViterbiTreeStr() { return viterbiTreeStr_; } + inline string GetViterbiRules() { return viterbiRules_; } +}; + +} // namespace + +#endif + |