diff options
| author | Michael Denkowski <mdenkows@cs.cmu.edu> | 2013-08-19 08:24:48 -0700 | 
|---|---|---|
| committer | Michael Denkowski <mdenkows@cs.cmu.edu> | 2013-08-19 08:24:48 -0700 | 
| commit | 84a38f1b73c43b3cd22700404bf3882a082ae658 (patch) | |
| tree | ce12abca1071d429ffcad7005fc7f4fa5274ea2f /training/dtrain/dtrain.h | |
| parent | 3a6fa32ca16d0fbdc76e738449bf1b27d866acc6 (diff) | |
| parent | 14b6b5a397dad46080732e8345ba2b1e5593d4cb (diff) | |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'training/dtrain/dtrain.h')
| -rw-r--r-- | training/dtrain/dtrain.h | 74 | 
1 files changed, 59 insertions, 15 deletions
| diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index eb0b9f17..3981fb39 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -11,16 +11,19 @@  #include <boost/algorithm/string.hpp>  #include <boost/program_options.hpp> -#include "ksampler.h" -#include "pairsampling.h" - -#include "filelib.h" - +#include "decoder.h" +#include "ff_register.h" +#include "sentence_metadata.h" +#include "verbose.h" +#include "viterbi.h"  using namespace std; -using namespace dtrain;  namespace po = boost::program_options; +namespace dtrain +{ + +  inline void register_and_convert(const vector<string>& strs, vector<WordID>& ids)  {    vector<string>::const_iterator it; @@ -42,17 +45,55 @@ inline string gettmpf(const string path, const string infix)    return string(fn);  } -inline void split_in(string& s, vector<string>& parts) +typedef double score_t; + +struct ScoredHyp  { -  unsigned f = 0; -  for(unsigned i = 0; i < 3; i++) { -    unsigned e = f; -    f = s.find("\t", f+1); -    if (e != 0) parts.push_back(s.substr(e+1, f-e-1)); -    else parts.push_back(s.substr(0, f)); +  vector<WordID> w; +  SparseVector<double> f; +  score_t model; +  score_t score; +  unsigned rank; +}; + +struct LocalScorer +{ +  unsigned N_; +  vector<score_t> w_; + +  virtual score_t +  Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank, const unsigned src_len)=0; + +  virtual void Reset() {} // only for ApproxBleuScorer, LinearBleuScorer + +  inline void +  Init(unsigned N, vector<score_t> weights) +  { +    assert(N > 0); +    N_ = N; +    if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); +    else w_ = weights;    } -  s.erase(0, f+1); -} + +  inline score_t +  brevity_penalty(const unsigned hyp_len, const unsigned ref_len) +  { +    if (hyp_len > ref_len) return 1; +    return exp(1 - (score_t)ref_len/hyp_len); +  } +}; + +struct HypSampler : public DecoderObserver +{ +  LocalScorer* scorer_; +  vector<WordID>* ref_; +  unsigned f_count_, sz_; +  virtual vector<ScoredHyp>* GetSamples()=0; +  inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } +  inline void SetRef(vector<WordID>& ref) { ref_ = &ref; } +  inline unsigned get_f_count() { return f_count_; } +  inline unsigned get_sz() { return sz_; } +};  struct HSReporter  { @@ -88,5 +129,8 @@ inline T sign(T z)    return z < 0 ? -1 : +1;  } + +} // namespace +  #endif | 
