summaryrefslogtreecommitdiff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--dtrain/dtrain.cc77
1 files changed, 32 insertions, 45 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 35e6cc46..622cd01e 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -95,38 +95,32 @@ main(int argc, char** argv)
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
Decoder decoder(ini_rf.stream());
- MT19937 rng; // random number generator
- // setup decoder observer
- HypSampler* observer;
- if (sample_from == "kbest") {
- observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type));
- } else {
- observer = dynamic_cast<KSampler*>(new KSampler(k, &rng));
- }
-
// scoring metric/scorer
string scorer_str = cfg["scorer"].as<string>();
- /*score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector<score_t>);
+ LocalScorer* scorer;
if (scorer_str == "bleu") {
- scorer = &bleu;
} else if (scorer_str == "stupid_bleu") {
- scorer = &stupid_bleu;
+ scorer = dynamic_cast<StupidBleuScorer*>(new StupidBleuScorer);
} else if (scorer_str == "smooth_bleu") {
- scorer = &smooth_bleu;
+ scorer = dynamic_cast<SmoothBleuScorer*>(new SmoothBleuScorer);
} else if (scorer_str == "approx_bleu") {
- scorer = &approx_bleu;
+ scorer = dynamic_cast<StupidBleuScorer*>(new StupidBleuScorer); // FIXME
} else {
cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
exit(1);
}
- NgramCounts global_counts(N); // counts for 1 best translations
- unsigned global_hyp_len = 0; // sum hypothesis lengths
- unsigned global_ref_len = 0; // sum reference lengths
- // ^^^ global_* for approx_bleu*/
- vector<score_t> bleu_weights; // we leave this empty -> 1/N
- //if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl;
- StupidBleuScorer scorer;
- scorer.Init(N, bleu_weights);
+ vector<score_t> bleu_weights;
+ scorer->Init(N, bleu_weights);
+ if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl;
+
+ // setup decoder observer
+ MT19937 rng; // random number generator
+ HypSampler* observer;
+ if (sample_from == "kbest")
+ observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type));
+ else
+ observer = dynamic_cast<KSampler*>(new KSampler(k, &rng));
+ observer->SetScorer(scorer);
// init weights
Weights weights;
@@ -240,10 +234,10 @@ main(int argc, char** argv)
vector<WordID> ref_ids; // reference as vector<WordID>
if (t == 0) {
// handling input
- strsplit(in, in_split, '\t', 4);
+ boost::split(in_split, in, boost::is_any_of("\t"));
// getting reference
vector<string> ref_tok;
- strsplit(in_split[2], ref_tok, ' ');
+ boost::split(ref_tok, in_split[2], boost::is_any_of(" "));
register_and_convert(ref_tok, ref_ids);
ref_ids_buf.push_back(ref_ids);
// process and set grammar
@@ -259,8 +253,9 @@ main(int argc, char** argv)
in_split[3] += "\n";
grammar_buf_out << in_split[3] << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl;
decoder.SetSentenceGrammarFromString(in_split[3]);
- // decode
src_str_buf.push_back(in_split[1]);
+ // decode
+ observer->SetRef(ref_ids);
decoder.Decode(in_split[1], observer);
} else {
// get buffered grammar
@@ -273,32 +268,24 @@ main(int argc, char** argv)
}
decoder.SetSentenceGrammarFromString(grammar_str);
// decode
+ observer->SetRef(ref_ids_buf[ii]);
decoder.Decode(src_str_buf[ii], observer);
}
+ // get (scored) samples
vector<ScoredHyp>* samples = observer->GetSamples();
- // (local) scoring
- if (t > 0) ref_ids = ref_ids_buf[ii];
- for (unsigned i = 0; i < samples->size(); i++) {
- //cout << ii << " " << i << endl;
-
- cout << _p9;
- (*samples)[i].score = scorer.Score((*samples)[i], ref_ids, ii);
- if (i == 0) {
- score_sum += (*samples)[i].score;
- model_sum += (*samples)[i].model;
- }
-
- if (verbose) {
- if (i == 0) cerr << "'" << TD::GetString(ref_ids) << "' [ref]" << endl;
- cerr << _p5 << _np << "[hyp " << i << "] " << "'" << TD::GetString((*samples)[i].w) << "'";
- cerr << " [SCORE=" << (*samples)[i].score << ",model="<< (*samples)[i].model << "]" << endl;
- cerr << (*samples)[i].f << endl;
- }
+ if (verbose) {
+ cout << "[ref: '";
+ if (t > 0) cout << ref_ids_buf[ii];
+ else cout << ref_ids;
+ cout << endl;
+ cout << _p5 << _np << "1best: " << "'" << (*samples)[0].w << "'" << endl;
+ cout << "SCORE=" << (*samples)[0].score << ",model="<< (*samples)[0].model << endl;
+ cout << "F{" << (*samples)[0].f << "} ]" << endl << endl;
}
-
- if (verbose) cerr << endl;
+ score_sum += (*samples)[0].score;
+ model_sum += (*samples)[0].model;
//////////////////////////////////////////////////////////
// UPDATE WEIGHTS