summaryrefslogtreecommitdiff
path: root/dtrain/kbestget.h
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-10-13 23:50:28 +0200
committerPatrick Simianer <p@simianer.de>2011-10-13 23:50:28 +0200
commitb03d01f22df3c5e27014bf32748baacc10c7d360 (patch)
treeb3336a13331099e8372af6c51ad508c7772c3177 /dtrain/kbestget.h
parentb9641702ba7aa86e9cc7ed0d4fffa4dd6271cc8f (diff)
fixed approx bleu
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r--dtrain/kbestget.h21
1 files changed, 11 insertions, 10 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index c0fd3f47..d141da60 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -14,7 +14,7 @@ namespace dtrain
{
-typedef double score_t; // float
+typedef double score_t;
struct ScoredHyp
{
@@ -31,9 +31,11 @@ struct LocalScorer
vector<score_t> w_;
virtual score_t
- Score(vector<WordID>& hyp, vector<WordID>& ref)=0;
+ Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank)=0;
- void
+ void Reset() {} // only for approx bleu
+
+ inline void
Init(unsigned N, vector<score_t> weights)
{
assert(N > 0);
@@ -42,7 +44,7 @@ struct LocalScorer
else w_ = weights;
}
- score_t
+ inline score_t
brevity_penaly(const unsigned hyp_len, const unsigned ref_len)
{
if (hyp_len > ref_len) return 1;
@@ -55,11 +57,10 @@ struct HypSampler : public DecoderObserver
LocalScorer* scorer_;
vector<WordID>* ref_;
virtual vector<ScoredHyp>* GetSamples()=0;
- void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
- void SetRef(vector<WordID>& ref) { ref_ = &ref; }
+ inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
+ inline void SetRef(vector<WordID>& ref) { ref_ = &ref; }
};
-/////////////////////////////////////////////////////////////////////
-// wtf
+///////////////////////////////////////////////////////////////////////////////
@@ -107,7 +108,7 @@ struct KBestGetter : public HypSampler
h.f = d->feature_values;
h.model = log(d->score);
h.rank = i;
- h.score = scorer_->Score(h.w, *ref_);
+ h.score = scorer_->Score(h.w, *ref_, i);
s_.push_back(h);
}
}
@@ -126,7 +127,7 @@ struct KBestGetter : public HypSampler
h.f = d->feature_values;
h.model = log(d->score);
h.rank = i;
- h.score = scorer_->Score(h.w, *ref_);
+ h.score = scorer_->Score(h.w, *ref_, i);
s_.push_back(h);
}
}