summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: mjdenkowski <michael.j.denkowski@gmail.com> 2013-08-29 02:27:51 -0400
committer: mjdenkowski <michael.j.denkowski@gmail.com> 2013-08-29 02:27:51 -0400
commit: 532899b5101250d4f187733860404ee932a8851c (patch)
tree: 36685dd7887987b8024affb7ba79f0e9b05e65c5
parent: 0bc21f0fbcf5e060c1a9b249527e094436a383d8 (diff)
Cleanup, fix id issue.
-rw-r--r-- mteval/scorer.cc 13
-rw-r--r-- mteval/scorer.h 13
-rw-r--r-- training/mira/kbest_cut_mira.cc 8
3 files changed, 18 insertions, 16 deletions
diff --git a/mteval/scorer.cc b/mteval/scorer.cc
index ced0cadf..de84e076 100644
--- a/mteval/scorer.cc
+++ b/mteval/scorer.cc
@@ -595,7 +595,6 @@ void DocScorer::Init(
const vector<string>& ref_files,
const string& src_file, bool verbose) {
scorers_.clear();
- this->type = type;
// TODO stop using valarray, start using ReadFile
cerr << "Loading references (" << ref_files.size() << " files)\n";
ReadFile srcrf;
@@ -645,6 +644,9 @@ void DocScorer::Init(
cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n";
}
+DocStreamScorer::~DocStreamScorer() {
+}
+
void DocStreamScorer::Init(
const ScoreType type,
const vector<string>& ref_files,
@@ -655,17 +657,14 @@ void DocStreamScorer::Init(
this->type = type;
vector<vector<WordID> > refs(1);
string src_line;
- // Empty reference 0
+ // Initialize empty reference
TD::ConvertSentence("", &refs[0]);
- scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line)));
- // Reference 1 starts empty, updated as needed
- scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line)));
+ scorer = ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line));
}
void DocStreamScorer::update(const std::string& ref) {
- scorers_.pop_back();
vector<vector<WordID> > refs(1);
string src_line;
TD::ConvertSentence(ref, &refs[0]);
- scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(this->type, refs, src_line)));
+ scorer = ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line));
}
diff --git a/mteval/scorer.h b/mteval/scorer.h
index 56c39a7d..bb1e89ae 100644
--- a/mteval/scorer.h
+++ b/mteval/scorer.h
@@ -101,16 +101,16 @@ class DocScorer {
Init(type,ref_files,src_file,verbose);
}
- int size() const { return scorers_.size(); }
- ScorerP operator[](size_t i) const { return scorers_[i]; }
+ virtual int size() const { return scorers_.size(); }
+ virtual ScorerP operator[](size_t i) const { return scorers_[i]; }
virtual void update(const std::string& ref) {}
private:
- ScoreType type;
std::vector<ScorerP> scorers_;
};
class DocStreamScorer : public DocScorer {
public:
+ ~DocStreamScorer();
void Init(const ScoreType type,
const std::vector<std::string>& ref_files,
const std::string& src_file = "",
@@ -124,9 +124,12 @@ class DocStreamScorer : public DocScorer {
{
Init(type,ref_files,src_file,verbose);
}
- ScorerP operator[](size_t i);
- int size();
+ ScorerP operator[](size_t i) const { return scorer; }
+ int size() const { return 1; }
void update(const std::string& ref);
+ private:
+ ScoreType type;
+ ScorerP scorer;
};
#endif
diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc
index d8c42db7..e4435abb 100644
--- a/training/mira/kbest_cut_mira.cc
+++ b/training/mira/kbest_cut_mira.cc
@@ -400,7 +400,7 @@ struct TrainingObserver : public DecoderObserver {
virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
- cur_sent = smeta.GetSentenceID();
+ cur_sent = stream ? 0 : smeta.GetSentenceID();
curr_src_length = (float) smeta.GetSourceLength();
if(unique_kbest)
@@ -422,7 +422,8 @@ struct TrainingObserver : public DecoderObserver {
template <class Filter>
void UpdateOracles(int sent_id, const Hypergraph& forest) {
- bool PRINT_LIST= false;
+ if (stream) sent_id = 0;
+ bool PRINT_LIST= false;
vector<shared_ptr<HypothesisInfo> >& cur_good = oracles[sent_id].good;
vector<shared_ptr<HypothesisInfo> >& cur_bad = oracles[sent_id].bad;
//TODO: look at keeping previous iterations hypothesis lists around
@@ -723,10 +724,10 @@ int main(int argc, char** argv) {
getline(*in, buf);
if (buf.empty()) continue;
if (stream) {
+ cur_sent = 0;
int delim = buf.find(" ||| ");
// Translate only
if (delim == -1) {
- cur_sent = 0;
decoder.SetId(cur_sent);
decoder.Decode(buf, &bobs);
vector<WordID> trans;
@@ -735,7 +736,6 @@ int main(int argc, char** argv) {
continue;
// Translate and update (normal MIRA)
} else {
- cur_sent = 1;
ds->update(buf.substr(delim + 5));
buf = buf.substr(0, delim);
}