diff options
author | mjdenkowski <michael.j.denkowski@gmail.com> | 2013-08-29 02:27:51 -0400 |
---|---|---|
committer | mjdenkowski <michael.j.denkowski@gmail.com> | 2013-08-29 02:27:51 -0400 |
commit | e078ac74f3499298742174a06f915b58f2d4cbdb (patch) | |
tree | def7762832ed4a7bc6adbb8b1b6515829544604d | |
parent | 6f462d23384b6e42a944feedaf6f37ae7a5b7921 (diff) |
Cleanup, fix id issue.
-rw-r--r-- | mteval/scorer.cc | 13 | ||||
-rw-r--r-- | mteval/scorer.h | 13 | ||||
-rw-r--r-- | training/mira/kbest_cut_mira.cc | 8 |
3 files changed, 18 insertions, 16 deletions
diff --git a/mteval/scorer.cc b/mteval/scorer.cc index ced0cadf..de84e076 100644 --- a/mteval/scorer.cc +++ b/mteval/scorer.cc @@ -595,7 +595,6 @@ void DocScorer::Init( const vector<string>& ref_files, const string& src_file, bool verbose) { scorers_.clear(); - this->type = type; // TODO stop using valarray, start using ReadFile cerr << "Loading references (" << ref_files.size() << " files)\n"; ReadFile srcrf; @@ -645,6 +644,9 @@ void DocScorer::Init( cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n"; } +DocStreamScorer::~DocStreamScorer() { +} + void DocStreamScorer::Init( const ScoreType type, const vector<string>& ref_files, @@ -655,17 +657,14 @@ void DocStreamScorer::Init( this->type = type; vector<vector<WordID> > refs(1); string src_line; - // Empty reference 0 + // Initialize empty reference TD::ConvertSentence("", &refs[0]); - scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line))); - // Reference 1 starts empty, updated as needed - scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line))); + scorer = ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line)); } void DocStreamScorer::update(const std::string& ref) { - scorers_.pop_back(); vector<vector<WordID> > refs(1); string src_line; TD::ConvertSentence(ref, &refs[0]); - scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(this->type, refs, src_line))); + scorer = ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line)); } diff --git a/mteval/scorer.h b/mteval/scorer.h index 56c39a7d..bb1e89ae 100644 --- a/mteval/scorer.h +++ b/mteval/scorer.h @@ -101,16 +101,16 @@ class DocScorer { Init(type,ref_files,src_file,verbose); } - int size() const { return scorers_.size(); } - ScorerP operator[](size_t i) const { return scorers_[i]; } + virtual int size() const { return scorers_.size(); } + virtual ScorerP operator[](size_t i) const { return scorers_[i]; } virtual void update(const std::string& ref) {} private: - ScoreType type; std::vector<ScorerP> scorers_; }; class DocStreamScorer : public DocScorer { public: + ~DocStreamScorer(); void Init(const ScoreType type, const std::vector<std::string>& ref_files, const std::string& src_file = "", @@ -124,9 +124,12 @@ class DocStreamScorer : public DocScorer { { Init(type,ref_files,src_file,verbose); } - ScorerP operator[](size_t i); - int size(); + ScorerP operator[](size_t i) const { return scorer; } + int size() const { return 1; } void update(const std::string& ref); + private: + ScoreType type; + ScorerP scorer; }; #endif diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc index d8c42db7..e4435abb 100644 --- a/training/mira/kbest_cut_mira.cc +++ b/training/mira/kbest_cut_mira.cc @@ -400,7 +400,7 @@ struct TrainingObserver : public DecoderObserver { virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - cur_sent = smeta.GetSentenceID(); + cur_sent = stream ? 0 : smeta.GetSentenceID(); curr_src_length = (float) smeta.GetSourceLength(); if(unique_kbest) @@ -422,7 +422,8 @@ struct TrainingObserver : public DecoderObserver { template <class Filter> void UpdateOracles(int sent_id, const Hypergraph& forest) { - bool PRINT_LIST= false; + if (stream) sent_id = 0; + bool PRINT_LIST= false; vector<shared_ptr<HypothesisInfo> >& cur_good = oracles[sent_id].good; vector<shared_ptr<HypothesisInfo> >& cur_bad = oracles[sent_id].bad; //TODO: look at keeping previous iterations hypothesis lists around @@ -723,10 +724,10 @@ int main(int argc, char** argv) { getline(*in, buf); if (buf.empty()) continue; if (stream) { + cur_sent = 0; int delim = buf.find(" ||| "); // Translate only if (delim == -1) { - cur_sent = 0; decoder.SetId(cur_sent); decoder.Decode(buf, &bobs); vector<WordID> trans; @@ -735,7 +736,6 @@ int main(int argc, char** argv) { continue; // Translate and update (normal MIRA) } else { - cur_sent = 1; ds->update(buf.substr(delim + 5)); buf = buf.substr(0, delim); } |