summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.h')
-rw-r--r--training/dtrain/dtrain.h74
1 files changed, 59 insertions, 15 deletions
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index eb0b9f17..3981fb39 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -11,16 +11,19 @@
#include <boost/algorithm/string.hpp>
#include <boost/program_options.hpp>
-#include "ksampler.h"
-#include "pairsampling.h"
-
-#include "filelib.h"
-
+#include "decoder.h"
+#include "ff_register.h"
+#include "sentence_metadata.h"
+#include "verbose.h"
+#include "viterbi.h"
using namespace std;
-using namespace dtrain;
namespace po = boost::program_options;
+namespace dtrain
+{
+
+
inline void register_and_convert(const vector<string>& strs, vector<WordID>& ids)
{
vector<string>::const_iterator it;
@@ -42,17 +45,55 @@ inline string gettmpf(const string path, const string infix)
return string(fn);
}
-inline void split_in(string& s, vector<string>& parts)
+typedef double score_t;
+
+struct ScoredHyp
{
- unsigned f = 0;
- for(unsigned i = 0; i < 3; i++) {
- unsigned e = f;
- f = s.find("\t", f+1);
- if (e != 0) parts.push_back(s.substr(e+1, f-e-1));
- else parts.push_back(s.substr(0, f));
+ vector<WordID> w;
+ SparseVector<double> f;
+ score_t model;
+ score_t score;
+ unsigned rank;
+};
+
+struct LocalScorer
+{
+ unsigned N_;
+ vector<score_t> w_;
+
+ virtual score_t
+ Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank, const unsigned src_len)=0;
+
+ virtual void Reset() {} // only for ApproxBleuScorer, LinearBleuScorer
+
+ inline void
+ Init(unsigned N, vector<score_t> weights)
+ {
+ assert(N > 0);
+ N_ = N;
+ if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_);
+ else w_ = weights;
}
- s.erase(0, f+1);
-}
+
+ inline score_t
+ brevity_penalty(const unsigned hyp_len, const unsigned ref_len)
+ {
+ if (hyp_len > ref_len) return 1;
+ return exp(1 - (score_t)ref_len/hyp_len);
+ }
+};
+
+struct HypSampler : public DecoderObserver
+{
+ LocalScorer* scorer_;
+ vector<WordID>* ref_;
+ unsigned f_count_, sz_;
+ virtual vector<ScoredHyp>* GetSamples()=0;
+ inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
+ inline void SetRef(vector<WordID>& ref) { ref_ = &ref; }
+ inline unsigned get_f_count() { return f_count_; }
+ inline unsigned get_sz() { return sz_; }
+};
struct HSReporter
{
@@ -88,5 +129,8 @@ inline T sign(T z)
return z < 0 ? -1 : +1;
}
+
+} // namespace
+
#endif