diff options
author | Chris Dyer <redpony@gmail.com> | 2013-08-08 13:32:44 -0700 |
---|---|---|
committer | Chris Dyer <redpony@gmail.com> | 2013-08-08 13:32:44 -0700 |
commit | 951e7daa9539ffe640f9421897c374f786af53e7 (patch) | |
tree | 321898257090cc623fa7ea10d81b8e83126a5a0b /training/dtrain/dtrain.h | |
parent | f4a3a2547316ca5d31366e6808858fe94981415c (diff) | |
parent | af2b10dd036aa0088cfef108c1c9713b7e2d9f8f (diff) |
Merge pull request #24 from pks/master
current dtrain version
Diffstat (limited to 'training/dtrain/dtrain.h')
-rw-r--r-- | training/dtrain/dtrain.h | 74 |
1 files changed, 59 insertions, 15 deletions
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index eb0b9f17..3981fb39 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -11,16 +11,19 @@ #include <boost/algorithm/string.hpp> #include <boost/program_options.hpp> -#include "ksampler.h" -#include "pairsampling.h" - -#include "filelib.h" - +#include "decoder.h" +#include "ff_register.h" +#include "sentence_metadata.h" +#include "verbose.h" +#include "viterbi.h" using namespace std; -using namespace dtrain; namespace po = boost::program_options; +namespace dtrain +{ + + inline void register_and_convert(const vector<string>& strs, vector<WordID>& ids) { vector<string>::const_iterator it; @@ -42,17 +45,55 @@ inline string gettmpf(const string path, const string infix) return string(fn); } -inline void split_in(string& s, vector<string>& parts) +typedef double score_t; + +struct ScoredHyp { - unsigned f = 0; - for(unsigned i = 0; i < 3; i++) { - unsigned e = f; - f = s.find("\t", f+1); - if (e != 0) parts.push_back(s.substr(e+1, f-e-1)); - else parts.push_back(s.substr(0, f)); + vector<WordID> w; + SparseVector<double> f; + score_t model; + score_t score; + unsigned rank; +}; + +struct LocalScorer +{ + unsigned N_; + vector<score_t> w_; + + virtual score_t + Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank, const unsigned src_len)=0; + + virtual void Reset() {} // only for ApproxBleuScorer, LinearBleuScorer + + inline void + Init(unsigned N, vector<score_t> weights) + { + assert(N > 0); + N_ = N; + if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); + else w_ = weights; } - s.erase(0, f+1); -} + + inline score_t + brevity_penalty(const unsigned hyp_len, const unsigned ref_len) + { + if (hyp_len > ref_len) return 1; + return exp(1 - (score_t)ref_len/hyp_len); + } +}; + +struct HypSampler : public DecoderObserver +{ + LocalScorer* scorer_; + vector<WordID>* ref_; + unsigned f_count_, sz_; + virtual vector<ScoredHyp>* GetSamples()=0; + inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } + inline void SetRef(vector<WordID>& ref) { ref_ = &ref; } + inline unsigned get_f_count() { return f_count_; } + inline unsigned get_sz() { return sz_; } +}; struct HSReporter { @@ -88,5 +129,8 @@ inline T sign(T z) return z < 0 ? -1 : +1; } + +} // namespace + #endif |