diff options
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r-- | dtrain/kbestget.h | 46 |
1 files changed, 39 insertions, 7 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h index 2a2c6073..c0fd3f47 100644 --- a/dtrain/kbestget.h +++ b/dtrain/kbestget.h @@ -1,12 +1,6 @@ #ifndef _DTRAIN_KBESTGET_H_ #define _DTRAIN_KBESTGET_H_ - -#include <vector> -#include <string> - -using namespace std; - #include "kbest.h" // cdec #include "verbose.h" #include "viterbi.h" @@ -14,11 +8,13 @@ using namespace std; #include "decoder.h" #include "weights.h" +using namespace std; + namespace dtrain { -typedef double score_t; // float +typedef double score_t; // float struct ScoredHyp { @@ -29,10 +25,44 @@ struct ScoredHyp unsigned rank; }; +struct LocalScorer +{ + unsigned N_; + vector<score_t> w_; + + virtual score_t + Score(vector<WordID>& hyp, vector<WordID>& ref)=0; + + void + Init(unsigned N, vector<score_t> weights) + { + assert(N > 0); + N_ = N; + if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); + else w_ = weights; + } + + score_t + brevity_penaly(const unsigned hyp_len, const unsigned ref_len) + { + if (hyp_len > ref_len) return 1; + return exp(1 - (score_t)ref_len/hyp_len); + } +}; + struct HypSampler : public DecoderObserver { + LocalScorer* scorer_; + vector<WordID>* ref_; virtual vector<ScoredHyp>* GetSamples()=0; + void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } + void SetRef(vector<WordID>& ref) { ref_ = &ref; } }; +///////////////////////////////////////////////////////////////////// +// wtf + + + struct KBestGetter : public HypSampler { @@ -77,6 +107,7 @@ struct KBestGetter : public HypSampler h.f = d->feature_values; h.model = log(d->score); h.rank = i; + h.score = scorer_->Score(h.w, *ref_); s_.push_back(h); } } @@ -95,6 +126,7 @@ struct KBestGetter : public HypSampler h.f = d->feature_values; h.model = log(d->score); h.rank = i; + h.score = scorer_->Score(h.w, *ref_); s_.push_back(h); } } |