From d18024a41cbc1b54db88d499571349a6234b6db8 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Mon, 29 Apr 2013 15:24:39 +0200
Subject: fix, cleaned up headers
---
training/dtrain/dtrain.h | 74 ++++++++++++++++++++++++++++++++++++++----------
1 file changed, 59 insertions(+), 15 deletions(-)
(limited to 'training/dtrain/dtrain.h')
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index eb0b9f17..3981fb39 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -11,16 +11,19 @@
#include
#include
-#include "ksampler.h"
-#include "pairsampling.h"
-
-#include "filelib.h"
-
+#include "decoder.h"
+#include "ff_register.h"
+#include "sentence_metadata.h"
+#include "verbose.h"
+#include "viterbi.h"
using namespace std;
-using namespace dtrain;
namespace po = boost::program_options;
+namespace dtrain
+{
+
+
inline void register_and_convert(const vector& strs, vector& ids)
{
vector::const_iterator it;
@@ -42,17 +45,55 @@ inline string gettmpf(const string path, const string infix)
return string(fn);
}
-inline void split_in(string& s, vector& parts)
+typedef double score_t; // floating-point type used for the score fields and scorer return values declared below
+
+struct ScoredHyp // aggregate of five data members, no member functions
{
- unsigned f = 0;
- for(unsigned i = 0; i < 3; i++) {
- unsigned e = f;
- f = s.find("\t", f+1);
- if (e != 0) parts.push_back(s.substr(e+1, f-e-1));
- else parts.push_back(s.substr(0, f));
+ vector w; // element type not visible in this view; presumably the hypothesis' word ids -- TODO confirm
+ SparseVector f; // sparse vector (template argument not visible here); presumably feature values -- TODO confirm
+ score_t model; // presumably the decoder's model score for this hypothesis -- TODO confirm
+ score_t score; // presumably the value assigned by a LocalScorer -- TODO confirm
+ unsigned rank; // presumably the position in the k-best list -- TODO confirm
+};
+
+struct LocalScorer // abstract scorer interface: Score() is pure virtual, so only subclasses can be instantiated
+{
+ unsigned N_; // set by Init(); also the number of default weights generated there
+ vector w_; // weights: either the ones passed to Init() or N_ entries of 1/N_
+
+ virtual score_t
+ Score(vector& hyp, vector& ref, const unsigned rank, const unsigned src_len)=0; // must be implemented by subclasses
+
+ virtual void Reset() {} // no-op by default; only for ApproxBleuScorer, LinearBleuScorer
+
+ inline void
+ Init(unsigned N, vector weights) // stores N and the weights; N must be > 0 (checked with assert only)
+ {
+ assert(N > 0);
+ N_ = N;
+ if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); // default: uniform weights 1/N_; NOTE(review): appends without clearing w_, so a repeated Init() with empty weights would grow w_
+ else w_ = weights;
 }
- s.erase(0, f+1);
-}
+
+ inline score_t
+ brevity_penalty(const unsigned hyp_len, const unsigned ref_len) // returns 1 when hyp_len > ref_len, otherwise exp(1 - ref_len/hyp_len)
+ {
+ if (hyp_len > ref_len) return 1;
+ return exp(1 - (score_t)ref_len/hyp_len); // cast makes this a floating-point division; NOTE(review): hyp_len == 0 is not guarded here
+ }
+}; // NOTE(review): no virtual destructor is declared although the struct has virtual functions
+
+struct HypSampler : public DecoderObserver // abstract: GetSamples() is pure virtual
+{
+ LocalScorer* scorer_; // set via SetScorer(); presumably not owned by the sampler -- TODO confirm
+ vector* ref_; // address of the vector passed to SetRef(); that vector must outlive its use here
+ unsigned f_count_, sz_; // returned by get_f_count() / get_sz(); not assigned anywhere inside this struct
+ virtual vector* GetSamples()=0; // must be implemented by subclasses
+ inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } // stores the pointer as-is
+ inline void SetRef(vector& ref) { ref_ = &ref; } // stores the address of ref, not a copy
+ inline unsigned get_f_count() { return f_count_; } // plain accessor
+ inline unsigned get_sz() { return sz_; } // plain accessor
+};
struct HSReporter
{
@@ -88,5 +129,8 @@ inline T sign(T z)
return z < 0 ? -1 : +1;
}
+
+} // namespace dtrain
+
#endif
--
cgit v1.2.3