From 8bb00a2a2775442418f1cb7c041f7cba5d6e0d42 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Mon, 26 Sep 2011 21:51:52 +0200 Subject: got rid of scoring loop --- dtrain/dtrain.cc | 77 +++++++++++++++++++++++--------------------------------- 1 file changed, 32 insertions(+), 45 deletions(-) (limited to 'dtrain/dtrain.cc') diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 35e6cc46..622cd01e 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -95,38 +95,32 @@ main(int argc, char** argv) cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl; Decoder decoder(ini_rf.stream()); - MT19937 rng; // random number generator - // setup decoder observer - HypSampler* observer; - if (sample_from == "kbest") { - observer = dynamic_cast(new KBestGetter(k, filter_type)); - } else { - observer = dynamic_cast(new KSampler(k, &rng)); - } - // scoring metric/scorer string scorer_str = cfg["scorer"].as(); - /*score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector); + LocalScorer* scorer; if (scorer_str == "bleu") { - scorer = &bleu; } else if (scorer_str == "stupid_bleu") { - scorer = &stupid_bleu; + scorer = dynamic_cast(new StupidBleuScorer); } else if (scorer_str == "smooth_bleu") { - scorer = &smooth_bleu; + scorer = dynamic_cast(new SmoothBleuScorer); } else if (scorer_str == "approx_bleu") { - scorer = &approx_bleu; + scorer = dynamic_cast(new StupidBleuScorer); // FIXME } else { cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl; exit(1); } - NgramCounts global_counts(N); // counts for 1 best translations - unsigned global_hyp_len = 0; // sum hypothesis lengths - unsigned global_ref_len = 0; // sum reference lengths - // ^^^ global_* for approx_bleu*/ - vector bleu_weights; // we leave this empty -> 1/N - //if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl; - StupidBleuScorer scorer; - scorer.Init(N, bleu_weights); + vector bleu_weights; + scorer->Init(N, bleu_weights); + if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl; + + // setup decoder observer + MT19937 rng; // random number generator + HypSampler* observer; + if (sample_from == "kbest") + observer = dynamic_cast(new KBestGetter(k, filter_type)); + else + observer = dynamic_cast(new KSampler(k, &rng)); + observer->SetScorer(scorer); // init weights Weights weights; @@ -240,10 +234,10 @@ main(int argc, char** argv) vector ref_ids; // reference as vector if (t == 0) { // handling input - strsplit(in, in_split, '\t', 4); + boost::split(in_split, in, boost::is_any_of("\t")); // getting reference vector ref_tok; - strsplit(in_split[2], ref_tok, ' '); + boost::split(ref_tok, in_split[2], boost::is_any_of(" ")); register_and_convert(ref_tok, ref_ids); ref_ids_buf.push_back(ref_ids); // process and set grammar @@ -259,8 +253,9 @@ main(int argc, char** argv) in_split[3] += "\n"; grammar_buf_out << in_split[3] << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl; decoder.SetSentenceGrammarFromString(in_split[3]); - // decode src_str_buf.push_back(in_split[1]); + // decode + observer->SetRef(ref_ids); decoder.Decode(in_split[1], observer); } else { // get buffered grammar @@ -273,32 +268,24 @@ main(int argc, char** argv) } decoder.SetSentenceGrammarFromString(grammar_str); // decode + observer->SetRef(ref_ids_buf[ii]); decoder.Decode(src_str_buf[ii], observer); } + // get (scored) samples vector* samples = observer->GetSamples(); - // (local) scoring - if (t > 0) ref_ids = ref_ids_buf[ii]; - for (unsigned i = 0; i < samples->size(); i++) { - //cout << ii << " " << i << endl; - - cout << _p9; - (*samples)[i].score = scorer.Score((*samples)[i], ref_ids, ii); - if (i == 0) { - score_sum += (*samples)[i].score; - model_sum += (*samples)[i].model; - } - - if (verbose) { - if (i == 0) cerr << "'" << TD::GetString(ref_ids) << "' [ref]" << endl; - cerr << _p5 << _np << "[hyp " << i << "] " << "'" << TD::GetString((*samples)[i].w) << "'"; - cerr << " [SCORE=" << (*samples)[i].score << ",model="<< (*samples)[i].model << "]" << endl; - cerr << (*samples)[i].f << endl; - } + if (verbose) { + cout << "[ref: '"; + if (t > 0) cout << ref_ids_buf[ii]; + else cout << ref_ids; + cout << endl; + cout << _p5 << _np << "1best: " << "'" << (*samples)[0].w << "'" << endl; + cout << "SCORE=" << (*samples)[0].score << ",model="<< (*samples)[0].model << endl; + cout << "F{" << (*samples)[0].f << "} ]" << endl << endl; } - - if (verbose) cerr << endl; + score_sum += (*samples)[0].score; + model_sum += (*samples)[0].model; ////////////////////////////////////////////////////////// // UPDATE WEIGHTS -- cgit v1.2.3