diff options
| -rw-r--r-- | mteval/scorer.cc | 13 | ||||
| -rw-r--r-- | mteval/scorer.h | 13 | ||||
| -rw-r--r-- | training/mira/kbest_cut_mira.cc | 8 | 
3 files changed, 18 insertions, 16 deletions
diff --git a/mteval/scorer.cc b/mteval/scorer.cc index ced0cadf..de84e076 100644 --- a/mteval/scorer.cc +++ b/mteval/scorer.cc @@ -595,7 +595,6 @@ void DocScorer::Init(        const vector<string>& ref_files,        const string& src_file, bool verbose) {    scorers_.clear(); -  this->type = type;    // TODO stop using valarray, start using ReadFile    cerr << "Loading references (" << ref_files.size() << " files)\n";    ReadFile srcrf; @@ -645,6 +644,9 @@ void DocScorer::Init(    cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n";  } +DocStreamScorer::~DocStreamScorer() { +} +  void DocStreamScorer::Init(        const ScoreType type,        const vector<string>& ref_files, @@ -655,17 +657,14 @@ void DocStreamScorer::Init(    this->type = type;    vector<vector<WordID> > refs(1);    string src_line; -  // Empty reference 0 +  // Initialize empty reference    TD::ConvertSentence("", &refs[0]); -  scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line))); -  // Reference 1 starts empty, updated as needed -  scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line))); +  scorer = ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line));  }  void DocStreamScorer::update(const std::string& ref) { -	scorers_.pop_back();  	vector<vector<WordID> > refs(1);  	string src_line;  	TD::ConvertSentence(ref, &refs[0]); -	scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(this->type, refs, src_line))); +	scorer = ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line));  } diff --git a/mteval/scorer.h b/mteval/scorer.h index 56c39a7d..bb1e89ae 100644 --- a/mteval/scorer.h +++ b/mteval/scorer.h @@ -101,16 +101,16 @@ class DocScorer {      Init(type,ref_files,src_file,verbose);    } -  int size() const { return scorers_.size(); } -  ScorerP operator[](size_t i) const { return scorers_[i]; } +  virtual int size() const { return scorers_.size(); } +  virtual ScorerP operator[](size_t i) const { return scorers_[i]; }    virtual void update(const std::string& ref) {}   private: -  ScoreType type;    std::vector<ScorerP> scorers_;  };  class DocStreamScorer : public DocScorer {  	public: +		~DocStreamScorer();  		void Init(const ScoreType type,  					const std::vector<std::string>& ref_files,  					const std::string& src_file = "", @@ -124,9 +124,12 @@ class DocStreamScorer : public DocScorer {  		{  			Init(type,ref_files,src_file,verbose);  		} -		ScorerP operator[](size_t i); -		int size(); +		ScorerP operator[](size_t i) const { return scorer; } +		int size() const { return 1; }  		void update(const std::string& ref); +	private: +		ScoreType type; +		ScorerP scorer;  };  #endif diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc index d8c42db7..e4435abb 100644 --- a/training/mira/kbest_cut_mira.cc +++ b/training/mira/kbest_cut_mira.cc @@ -400,7 +400,7 @@ struct TrainingObserver : public DecoderObserver {    virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { -    cur_sent = smeta.GetSentenceID(); +    cur_sent = stream ? 0 : smeta.GetSentenceID();      curr_src_length = (float) smeta.GetSourceLength();      if(unique_kbest) @@ -422,7 +422,8 @@ struct TrainingObserver : public DecoderObserver {    template <class Filter>      void UpdateOracles(int sent_id, const Hypergraph& forest) { -    bool PRINT_LIST= false;     +	if (stream) sent_id = 0; +    bool PRINT_LIST= false;      vector<shared_ptr<HypothesisInfo> >& cur_good = oracles[sent_id].good;      vector<shared_ptr<HypothesisInfo> >& cur_bad = oracles[sent_id].bad;      //TODO: look at keeping previous iterations hypothesis lists around @@ -723,10 +724,10 @@ int main(int argc, char** argv) {        getline(*in, buf);        if (buf.empty()) continue;        if (stream) { +    	  cur_sent = 0;      	  int delim = buf.find(" ||| ");      	  // Translate only      	  if (delim == -1) { -    		  cur_sent = 0;      		  decoder.SetId(cur_sent);      		  decoder.Decode(buf, &bobs);      		  vector<WordID> trans; @@ -735,7 +736,6 @@ int main(int argc, char** argv) {      		  continue;      	  // Translate and update (normal MIRA)      	  } else { -    		  cur_sent = 1;      		  ds->update(buf.substr(delim + 5));      		  buf = buf.substr(0, delim);      	  }  | 
