diff options
author | Patrick Simianer <p@simianer.de> | 2011-09-26 21:51:52 +0200 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2011-09-26 21:51:52 +0200 |
commit | 36de7283576dd22a91577ef175c62434f3d933b4 (patch) | |
tree | 544916f6305deb5c281153e7f4e208b6e3a8b568 /dtrain/dtrain.cc | |
parent | e16b311246f9f2c309b257debd5f50a28b04802b (diff) |
got rid of scoring loop
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 77 |
1 file changed, 32 insertions, 45 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 35e6cc46..622cd01e 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -95,38 +95,32 @@ main(int argc, char** argv) cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; Decoder decoder(ini_rf.stream()); - MT19937 rng; // random number generator - // setup decoder observer - HypSampler* observer; - if (sample_from == "kbest") { - observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type)); - } else { - observer = dynamic_cast<KSampler*>(new KSampler(k, &rng)); - } - // scoring metric/scorer string scorer_str = cfg["scorer"].as<string>(); - /*score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector<score_t>); + LocalScorer* scorer; if (scorer_str == "bleu") { - scorer = &bleu; } else if (scorer_str == "stupid_bleu") { - scorer = &stupid_bleu; + scorer = dynamic_cast<StupidBleuScorer*>(new StupidBleuScorer); } else if (scorer_str == "smooth_bleu") { - scorer = &smooth_bleu; + scorer = dynamic_cast<SmoothBleuScorer*>(new SmoothBleuScorer); } else if (scorer_str == "approx_bleu") { - scorer = &approx_bleu; + scorer = dynamic_cast<StupidBleuScorer*>(new StupidBleuScorer); // FIXME } else { cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." 
<< endl; exit(1); } - NgramCounts global_counts(N); // counts for 1 best translations - unsigned global_hyp_len = 0; // sum hypothesis lengths - unsigned global_ref_len = 0; // sum reference lengths - // ^^^ global_* for approx_bleu*/ - vector<score_t> bleu_weights; // we leave this empty -> 1/N - //if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl; - StupidBleuScorer scorer; - scorer.Init(N, bleu_weights); + vector<score_t> bleu_weights; + scorer->Init(N, bleu_weights); + if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl; + + // setup decoder observer + MT19937 rng; // random number generator + HypSampler* observer; + if (sample_from == "kbest") + observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type)); + else + observer = dynamic_cast<KSampler*>(new KSampler(k, &rng)); + observer->SetScorer(scorer); // init weights Weights weights; @@ -240,10 +234,10 @@ main(int argc, char** argv) vector<WordID> ref_ids; // reference as vector<WordID> if (t == 0) { // handling input - strsplit(in, in_split, '\t', 4); + boost::split(in_split, in, boost::is_any_of("\t")); // getting reference vector<string> ref_tok; - strsplit(in_split[2], ref_tok, ' '); + boost::split(ref_tok, in_split[2], boost::is_any_of(" ")); register_and_convert(ref_tok, ref_ids); ref_ids_buf.push_back(ref_ids); // process and set grammar @@ -259,8 +253,9 @@ main(int argc, char** argv) in_split[3] += "\n"; grammar_buf_out << in_split[3] << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl; decoder.SetSentenceGrammarFromString(in_split[3]); - // decode src_str_buf.push_back(in_split[1]); + // decode + observer->SetRef(ref_ids); decoder.Decode(in_split[1], observer); } else { // get buffered grammar @@ -273,32 +268,24 @@ main(int argc, char** argv) } decoder.SetSentenceGrammarFromString(grammar_str); // decode + observer->SetRef(ref_ids_buf[ii]); decoder.Decode(src_str_buf[ii], observer); } + // get (scored) samples vector<ScoredHyp>* 
samples = observer->GetSamples(); - // (local) scoring - if (t > 0) ref_ids = ref_ids_buf[ii]; - for (unsigned i = 0; i < samples->size(); i++) { - //cout << ii << " " << i << endl; - - cout << _p9; - (*samples)[i].score = scorer.Score((*samples)[i], ref_ids, ii); - if (i == 0) { - score_sum += (*samples)[i].score; - model_sum += (*samples)[i].model; - } - - if (verbose) { - if (i == 0) cerr << "'" << TD::GetString(ref_ids) << "' [ref]" << endl; - cerr << _p5 << _np << "[hyp " << i << "] " << "'" << TD::GetString((*samples)[i].w) << "'"; - cerr << " [SCORE=" << (*samples)[i].score << ",model="<< (*samples)[i].model << "]" << endl; - cerr << (*samples)[i].f << endl; - } + if (verbose) { + cout << "[ref: '"; + if (t > 0) cout << ref_ids_buf[ii]; + else cout << ref_ids; + cout << endl; + cout << _p5 << _np << "1best: " << "'" << (*samples)[0].w << "'" << endl; + cout << "SCORE=" << (*samples)[0].score << ",model="<< (*samples)[0].model << endl; + cout << "F{" << (*samples)[0].f << "} ]" << endl << endl; } - - if (verbose) cerr << endl; + score_sum += (*samples)[0].score; + model_sum += (*samples)[0].model; ////////////////////////////////////////////////////////// // UPDATE WEIGHTS |