From 8bb00a2a2775442418f1cb7c041f7cba5d6e0d42 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Mon, 26 Sep 2011 21:51:52 +0200 Subject: got rid of scoring loop --- dtrain/kbestget.h | 46 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 7 deletions(-) (limited to 'dtrain/kbestget.h') diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h index 2a2c6073..c0fd3f47 100644 --- a/dtrain/kbestget.h +++ b/dtrain/kbestget.h @@ -1,12 +1,6 @@ #ifndef _DTRAIN_KBESTGET_H_ #define _DTRAIN_KBESTGET_H_ - -#include -#include - -using namespace std; - #include "kbest.h" // cdec #include "verbose.h" #include "viterbi.h" @@ -14,11 +8,13 @@ using namespace std; #include "decoder.h" #include "weights.h" +using namespace std; + namespace dtrain { -typedef double score_t; // float +typedef double score_t; // float struct ScoredHyp { @@ -29,10 +25,44 @@ struct ScoredHyp unsigned rank; }; +struct LocalScorer +{ + unsigned N_; + vector w_; + + virtual score_t + Score(vector& hyp, vector& ref)=0; + + void + Init(unsigned N, vector weights) + { + assert(N > 0); + N_ = N; + if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); + else w_ = weights; + } + + score_t + brevity_penaly(const unsigned hyp_len, const unsigned ref_len) + { + if (hyp_len > ref_len) return 1; + return exp(1 - (score_t)ref_len/hyp_len); + } +}; + struct HypSampler : public DecoderObserver { + LocalScorer* scorer_; + vector* ref_; virtual vector* GetSamples()=0; + void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } + void SetRef(vector& ref) { ref_ = &ref; } }; +///////////////////////////////////////////////////////////////////// +// wtf + + + struct KBestGetter : public HypSampler { @@ -77,6 +107,7 @@ struct KBestGetter : public HypSampler h.f = d->feature_values; h.model = log(d->score); h.rank = i; + h.score = scorer_->Score(h.w, *ref_); s_.push_back(h); } } @@ -95,6 +126,7 @@ struct KBestGetter : public HypSampler h.f = d->feature_values; h.model = log(d->score); h.rank = i; + h.score = scorer_->Score(h.w, *ref_); s_.push_back(h); } } -- cgit v1.2.3