diff options
Diffstat (limited to 'training/dtrain')
| -rw-r--r-- | training/dtrain/sample_net_interface.h | 61 | 
1 files changed, 61 insertions, 0 deletions
diff --git a/training/dtrain/sample_net_interface.h b/training/dtrain/sample_net_interface.h new file mode 100644 index 00000000..497149d9 --- /dev/null +++ b/training/dtrain/sample_net_interface.h @@ -0,0 +1,61 @@ +#ifndef _DTRAIN_SAMPLE_NET_H_ +#define _DTRAIN_SAMPLE_NET_H_ + +#include "kbest.h" + +#include "score.h" + +namespace dtrain +{ + +struct ScoredKbest : public DecoderObserver +{ +  const size_t k_; +  size_t feature_count_, effective_sz_; +  vector<ScoredHyp> samples_; +  PerSentenceBleuScorer* scorer_; +  vector<Ngrams>* ref_ngs_; +  vector<size_t>* ref_ls_; +  bool dont_score; + +  ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) : +    k_(k), scorer_(scorer), dont_score(false) {} + +  virtual void +  NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) +  { +    samples_.clear(); effective_sz_ = feature_count_ = 0; +    KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, +      KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_); +    for (size_t i = 0; i < k_; ++i) { +      const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, +            KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = +              kbest.LazyKthBest(hg->nodes_.size() - 1, i); +      if (!d) break; +      ScoredHyp h; +      h.w = d->yield; +      h.f = d->feature_values; +      h.model = log(d->score); +      h.rank = i; +      if (!dont_score) +        h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_); +      samples_.push_back(h); +      effective_sz_++; +      feature_count_ += h.f.size(); +    } +  } + +  vector<ScoredHyp>* GetSamples() { return &samples_; } +  inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls) +  { +    ref_ngs_ = &ngs; +    ref_ls_ = &ls; +  } +  inline size_t GetFeatureCount() { return feature_count_; } +  inline size_t GetSize() { return effective_sz_; } +}; + +} // namespace + +#endif +  | 
