summaryrefslogtreecommitdiff
path: root/training/mira/kbest_cut_mira.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/mira/kbest_cut_mira.cc')
-rw-r--r--training/mira/kbest_cut_mira.cc8
1 files changed, 4 insertions, 4 deletions
diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc
index d8c42db7..e4435abb 100644
--- a/training/mira/kbest_cut_mira.cc
+++ b/training/mira/kbest_cut_mira.cc
@@ -400,7 +400,7 @@ struct TrainingObserver : public DecoderObserver {
virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
- cur_sent = smeta.GetSentenceID();
+ cur_sent = stream ? 0 : smeta.GetSentenceID();
curr_src_length = (float) smeta.GetSourceLength();
if(unique_kbest)
@@ -422,7 +422,8 @@ struct TrainingObserver : public DecoderObserver {
template <class Filter>
void UpdateOracles(int sent_id, const Hypergraph& forest) {
- bool PRINT_LIST= false;
+ if (stream) sent_id = 0;
+ bool PRINT_LIST= false;
vector<shared_ptr<HypothesisInfo> >& cur_good = oracles[sent_id].good;
vector<shared_ptr<HypothesisInfo> >& cur_bad = oracles[sent_id].bad;
//TODO: look at keeping previous iterations hypothesis lists around
@@ -723,10 +724,10 @@ int main(int argc, char** argv) {
getline(*in, buf);
if (buf.empty()) continue;
if (stream) {
+ cur_sent = 0;
int delim = buf.find(" ||| ");
// Translate only
if (delim == -1) {
- cur_sent = 0;
decoder.SetId(cur_sent);
decoder.Decode(buf, &bobs);
vector<WordID> trans;
@@ -735,7 +736,6 @@ int main(int argc, char** argv) {
continue;
// Translate and update (normal MIRA)
} else {
- cur_sent = 1;
ds->update(buf.substr(delim + 5));
buf = buf.substr(0, delim);
}