summaryrefslogtreecommitdiff
path: root/dtrain/kbestget.h
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2012-11-18 13:35:42 -0500
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2012-11-18 13:35:42 -0500
commit8aa29810bb77611cc20b7a384897ff6703783ea1 (patch)
tree8635daa8fffb3f2cd90e30b41e27f4f9e0909447 /dtrain/kbestget.h
parentfbdacabc85bea65d735f2cb7f92b98e08ce72d04 (diff)
major restructure of the training code
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r--dtrain/kbestget.h152
1 files changed, 0 insertions, 152 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
deleted file mode 100644
index dd8882e1..00000000
--- a/dtrain/kbestget.h
+++ /dev/null
@@ -1,152 +0,0 @@
-#ifndef _DTRAIN_KBESTGET_H_
-#define _DTRAIN_KBESTGET_H_
-
-#include "kbest.h" // cdec
-#include "sentence_metadata.h"
-
-#include "verbose.h"
-#include "viterbi.h"
-#include "ff_register.h"
-#include "decoder.h"
-#include "weights.h"
-#include "logval.h"
-
-using namespace std;
-
-namespace dtrain
-{
-
-
-typedef double score_t;
-
-struct ScoredHyp
-{
- vector<WordID> w;
- SparseVector<double> f;
- score_t model;
- score_t score;
- unsigned rank;
-};
-
-struct LocalScorer
-{
- unsigned N_;
- vector<score_t> w_;
-
- virtual score_t
- Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank, const unsigned src_len)=0;
-
- void Reset() {} // only for approx bleu
-
- inline void
- Init(unsigned N, vector<score_t> weights)
- {
- assert(N > 0);
- N_ = N;
- if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_);
- else w_ = weights;
- }
-
- inline score_t
- brevity_penalty(const unsigned hyp_len, const unsigned ref_len)
- {
- if (hyp_len > ref_len) return 1;
- return exp(1 - (score_t)ref_len/hyp_len);
- }
-};
-
-struct HypSampler : public DecoderObserver
-{
- LocalScorer* scorer_;
- vector<WordID>* ref_;
- unsigned f_count_, sz_;
- virtual vector<ScoredHyp>* GetSamples()=0;
- inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
- inline void SetRef(vector<WordID>& ref) { ref_ = &ref; }
- inline unsigned get_f_count() { return f_count_; }
- inline unsigned get_sz() { return sz_; }
-};
-////////////////////////////////////////////////////////////////////////////////
-
-
-
-
-struct KBestGetter : public HypSampler
-{
- const unsigned k_;
- const string filter_type_;
- vector<ScoredHyp> s_;
- unsigned src_len_;
-
- KBestGetter(const unsigned k, const string filter_type) :
- k_(k), filter_type_(filter_type) {}
-
- virtual void
- NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
- {
- src_len_ = smeta.GetSourceLength();
- KBestScored(*hg);
- }
-
- vector<ScoredHyp>* GetSamples() { return &s_; }
-
- void
- KBestScored(const Hypergraph& forest)
- {
- if (filter_type_ == "uniq") {
- KBestUnique(forest);
- } else if (filter_type_ == "not") {
- KBestNoFilter(forest);
- }
- }
-
- void
- KBestUnique(const Hypergraph& forest)
- {
- s_.clear(); sz_ = f_count_ = 0;
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
- KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_);
- for (unsigned i = 0; i < k_; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique,
- prob_t, EdgeProb>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- ScoredHyp h;
- h.w = d->yield;
- h.f = d->feature_values;
- h.model = log(d->score);
- h.rank = i;
- h.score = scorer_->Score(h.w, *ref_, i, src_len_);
- s_.push_back(h);
- sz_++;
- f_count_ += h.f.size();
- }
- }
-
- void
- KBestNoFilter(const Hypergraph& forest)
- {
- s_.clear(); sz_ = f_count_ = 0;
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_);
- for (unsigned i = 0; i < k_; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- ScoredHyp h;
- h.w = d->yield;
- h.f = d->feature_values;
- h.model = log(d->score);
- h.rank = i;
- h.score = scorer_->Score(h.w, *ref_, i, src_len_);
- s_.push_back(h);
- sz_++;
- f_count_ += h.f.size();
- }
- }
-};
-
-
-} // namespace
-
-#endif
-