From 8aa29810bb77611cc20b7a384897ff6703783ea1 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 18 Nov 2012 13:35:42 -0500 Subject: major restructure of the training code --- dtrain/kbestget.h | 152 ------------------------------------------------------ 1 file changed, 152 deletions(-) delete mode 100644 dtrain/kbestget.h (limited to 'dtrain/kbestget.h') diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h deleted file mode 100644 index dd8882e1..00000000 --- a/dtrain/kbestget.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef _DTRAIN_KBESTGET_H_ -#define _DTRAIN_KBESTGET_H_ - -#include "kbest.h" // cdec -#include "sentence_metadata.h" - -#include "verbose.h" -#include "viterbi.h" -#include "ff_register.h" -#include "decoder.h" -#include "weights.h" -#include "logval.h" - -using namespace std; - -namespace dtrain -{ - - -typedef double score_t; - -struct ScoredHyp -{ - vector w; - SparseVector f; - score_t model; - score_t score; - unsigned rank; -}; - -struct LocalScorer -{ - unsigned N_; - vector w_; - - virtual score_t - Score(vector& hyp, vector& ref, const unsigned rank, const unsigned src_len)=0; - - void Reset() {} // only for approx bleu - - inline void - Init(unsigned N, vector weights) - { - assert(N > 0); - N_ = N; - if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); - else w_ = weights; - } - - inline score_t - brevity_penalty(const unsigned hyp_len, const unsigned ref_len) - { - if (hyp_len > ref_len) return 1; - return exp(1 - (score_t)ref_len/hyp_len); - } -}; - -struct HypSampler : public DecoderObserver -{ - LocalScorer* scorer_; - vector* ref_; - unsigned f_count_, sz_; - virtual vector* GetSamples()=0; - inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } - inline void SetRef(vector& ref) { ref_ = &ref; } - inline unsigned get_f_count() { return f_count_; } - inline unsigned get_sz() { return sz_; } -}; -//////////////////////////////////////////////////////////////////////////////// - - - - -struct KBestGetter : public HypSampler -{ - const unsigned k_; - const string filter_type_; - vector s_; - unsigned src_len_; - - KBestGetter(const unsigned k, const string filter_type) : - k_(k), filter_type_(filter_type) {} - - virtual void - NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) - { - src_len_ = smeta.GetSourceLength(); - KBestScored(*hg); - } - - vector* GetSamples() { return &s_; } - - void - KBestScored(const Hypergraph& forest) - { - if (filter_type_ == "uniq") { - KBestUnique(forest); - } else if (filter_type_ == "not") { - KBestNoFilter(forest); - } - } - - void - KBestUnique(const Hypergraph& forest) - { - s_.clear(); sz_ = f_count_ = 0; - KBest::KBestDerivations, ESentenceTraversal, - KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_); - for (unsigned i = 0; i < k_; ++i) { - const KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique, - prob_t, EdgeProb>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - ScoredHyp h; - h.w = d->yield; - h.f = d->feature_values; - h.model = log(d->score); - h.rank = i; - h.score = scorer_->Score(h.w, *ref_, i, src_len_); - s_.push_back(h); - sz_++; - f_count_ += h.f.size(); - } - } - - void - KBestNoFilter(const Hypergraph& forest) - { - s_.clear(); sz_ = f_count_ = 0; - KBest::KBestDerivations, ESentenceTraversal> kbest(forest, k_); - for (unsigned i = 0; i < k_; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - ScoredHyp h; - h.w = d->yield; - h.f = d->feature_values; - h.model = log(d->score); - h.rank = i; - h.score = scorer_->Score(h.w, *ref_, i, src_len_); - s_.push_back(h); - sz_++; - f_count_ += h.f.size(); - } - } -}; - - -} // namespace - -#endif - -- cgit v1.2.3