summaryrefslogtreecommitdiff
path: root/dtrain/kbestget.h
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r--dtrain/kbestget.h46
1 files changed, 39 insertions, 7 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index 2a2c6073..c0fd3f47 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -1,12 +1,6 @@
#ifndef _DTRAIN_KBESTGET_H_
#define _DTRAIN_KBESTGET_H_
-
-#include <vector>
-#include <string>
-
-using namespace std;
-
#include "kbest.h" // cdec
#include "verbose.h"
#include "viterbi.h"
@@ -14,11 +8,13 @@ using namespace std;
#include "decoder.h"
#include "weights.h"
+using namespace std;
+
namespace dtrain
{
-typedef double score_t; // float
+typedef double score_t; // float
struct ScoredHyp
{
@@ -29,10 +25,44 @@ struct ScoredHyp
unsigned rank;
};
+// Abstract base for scorers that rate a single hypothesis against a single
+// reference. Subclasses implement Score(); Init() must be called first to set
+// N_ and the weight vector w_.
+struct LocalScorer
+{
+  unsigned N_;        // set by Init(); presumably the maximum n-gram order -- TODO confirm
+  vector<score_t> w_; // weights set by Init(); N_ uniform entries unless the caller supplies its own
+
+  // This is a polymorphic base (Score() is pure virtual), so destroying a
+  // subclass through a LocalScorer* needs a virtual destructor.
+  virtual ~LocalScorer() {}
+
+  // Returns the score of hyp with respect to ref; defined by subclasses.
+  virtual score_t
+  Score(vector<WordID>& hyp, vector<WordID>& ref)=0;
+
+  // Stores N and the weights. An empty `weights` selects uniform weights
+  // 1/N for each of the N entries; otherwise `weights` is copied unchanged
+  // (its size is not checked against N).
+  void
+  Init(unsigned N, vector<score_t> weights)
+  {
+    assert(N > 0);
+    N_ = N;
+    // assign() replaces the previous contents, so calling Init() more than
+    // once does not pile new weights on top of old ones.
+    if (weights.empty()) w_.assign(N_, 1./N_);
+    else w_ = weights;
+  }
+
+  // Returns 1 when the hypothesis is longer than the reference, otherwise
+  // exp(1 - ref_len/hyp_len). (The misspelled name is kept for callers.)
+  score_t
+  brevity_penaly(const unsigned hyp_len, const unsigned ref_len)
+  {
+    if (hyp_len > ref_len) return 1;
+    // An empty hypothesis would divide by zero below (0/0 yields NaN when
+    // the reference is empty too); give it the maximal penalty instead.
+    if (hyp_len == 0) return 0;
+    return exp(1 - (score_t)ref_len/hyp_len);
+  }
+};
+
+// Decoder observer that collects scored hypotheses. It keeps non-owning
+// pointers to the scorer and to the current reference; both must be set via
+// SetScorer()/SetRef() before samples are scored.
struct HypSampler : public DecoderObserver
{
+  LocalScorer* scorer_; // not owned; set by SetScorer()
+  vector<WordID>* ref_; // not owned; points at the vector passed to SetRef()
+
+  // Start with null pointers so that use before SetScorer()/SetRef() is a
+  // deterministic null dereference rather than a read through garbage.
+  HypSampler() : scorer_(0), ref_(0) {}
+
+  // Returns the samples gathered by the concrete sampler.
 virtual vector<ScoredHyp>* GetSamples()=0;
+  void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
+  // Stores the address of `ref`; the caller's vector must stay alive for as
+  // long as this sampler may dereference ref_.
+  void SetRef(vector<WordID>& ref) { ref_ = &ref; }
};
+/////////////////////////////////////////////////////////////////////
+// wtf
+
+
+
struct KBestGetter : public HypSampler
{
@@ -77,6 +107,7 @@ struct KBestGetter : public HypSampler
h.f = d->feature_values;
h.model = log(d->score);
h.rank = i;
+ h.score = scorer_->Score(h.w, *ref_);
s_.push_back(h);
}
}
@@ -95,6 +126,7 @@ struct KBestGetter : public HypSampler
h.f = d->feature_values;
h.model = log(d->score);
h.rank = i;
+ h.score = scorer_->Score(h.w, *ref_);
s_.push_back(h);
}
}