author     Patrick Simianer <p@simianer.de>    2011-09-25 21:43:57 +0200
committer  Patrick Simianer <p@simianer.de>    2011-09-25 21:43:57 +0200
commit     4d8c300734c441821141f4bff044c439e004ff84 (patch)
tree       5b9e2b7f9994d9a71e0e2d17f33ba2ff4a1145a1 /dtrain/dtrain.cc
parent     fe471bb707226052551d75b043295ca5f57261c0 (diff)
kbest, ksampler refactoring
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--  dtrain/dtrain.cc  48
1 file changed, 24 insertions, 24 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index a70ca2f1..ad1ab7b7 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -10,11 +10,11 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
     ("output", po::value<string>()->default_value("-"), "output weights file (or VOID)")
     ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
     ("decoder_config", po::value<string>(), "configuration file for cdec")
-    ("ksamples", po::value<size_t>()->default_value(100), "size of kbest or sample from forest")
+    ("k", po::value<size_t>()->default_value(100), "size of kbest or sample from forest")
     ("sample_from", po::value<string>()->default_value("kbest"), "where to get translations from")
     ("filter", po::value<string>()->default_value("unique"), "filter kbest list")
     ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, rand")
-    ("ngrams", po::value<size_t>()->default_value(3), "N for Ngrams")
+    ("N", po::value<size_t>()->default_value(3), "N for Ngrams")
     ("epochs", po::value<size_t>()->default_value(2), "# of iterations T")
     ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring metric")
     ("stop_after", po::value<size_t>()->default_value(0), "stop after X input sentences")
@@ -75,8 +75,8 @@ main(int argc, char** argv)
     hstreaming = true;
     quiet = true;
   }
-  const size_t k = cfg["ksamples"].as<size_t>();
-  const size_t N = cfg["ngrams"].as<size_t>();
+  const size_t k = cfg["k"].as<size_t>();
+  const size_t N = cfg["N"].as<size_t>();
   const size_t T = cfg["epochs"].as<size_t>();
   const size_t stop_after = cfg["stop_after"].as<size_t>();
   const string filter_type = cfg["filter"].as<string>();
@@ -96,7 +96,7 @@ main(int argc, char** argv)
   MT19937 rng; // random number generator
   // setup decoder observer
-  HypoSampler* observer;
+  HypSampler* observer;
   if (sample_from == "kbest") {
     observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type));
   } else {
@@ -274,45 +274,45 @@ main(int argc, char** argv)
       decoder.Decode(src_str_buf[ii], observer);
     }
-    Samples* samples = observer->GetSamples();
+    vector<ScoredHyp>* samples = observer->GetSamples();
     // (local) scoring
     if (t > 0) ref_ids = ref_ids_buf[ii];
     score_t score = 0.;
-    for (size_t i = 0; i < samples->GetSize(); i++) {
-      NgramCounts counts = make_ngram_counts(ref_ids, samples->sents[i], N);
+    for (size_t i = 0; i < samples->size(); i++) {
+      NgramCounts counts = make_ngram_counts(ref_ids, (*samples)[i].w, N);
       if (scorer_str == "approx_bleu") {
         size_t hyp_len = 0;
         if (i == 0) { // 'context of 1best translations'
           global_counts += counts;
-          global_hyp_len += samples->sents[i].size();
+          global_hyp_len += (*samples)[i].w.size();
           global_ref_len += ref_ids.size();
           counts.reset();
         } else {
-          hyp_len = samples->sents[i].size();
+          hyp_len = (*samples)[i].w.size();
         }
-        NgramCounts counts_tmp = global_counts + counts;
-        score = .9 * scorer(counts_tmp,
+        NgramCounts _c = global_counts + counts;
+        score = .9 * scorer(_c,
                             global_ref_len,
                             global_hyp_len + hyp_len, N, bleu_weights);
       } else {
         score = scorer(counts,
                        ref_ids.size(),
-                       samples->sents[i].size(), N, bleu_weights);
+                       (*samples)[i].w.size(), N, bleu_weights);
       }
-      samples->scores.push_back(score);
+      (*samples)[i].score = (score);
      if (i == 0) {
        score_sum += score;
-        model_sum += samples->model_scores[i];
+        model_sum += (*samples)[i].model;
      }
      if (verbose) {
        if (i == 0) cout << "'" << TD::GetString(ref_ids) << "' [ref]" << endl;
-        cout << _p5 << _np << "[hyp " << i << "] " << "'" << TD::GetString(samples->sents[i]) << "'";
-        cout << " [SCORE=" << score << ",model="<< samples->model_scores[i] << "]" << endl;
-        cout << samples->feats[i] << endl;
+        cout << _p5 << _np << "[hyp " << i << "] " << "'" << TD::GetString((*samples)[i].w) << "'";
+        cout << " [SCORE=" << score << ",model="<< (*samples)[i].model << "]" << endl;
+        cout << (*samples)[i].f << endl;
      }
    } // sample/scoring loop
@@ -321,18 +321,18 @@ main(int argc, char** argv)
    //////////////////////////////////////////////////////////
    // UPDATE WEIGHTS
    if (!noup) {
-      vector<Pair> pairs;
+      vector<pair<ScoredHyp,ScoredHyp> > pairs;
      if (pair_sampling == "all")
        sample_all_pairs(samples, pairs);
      if (pair_sampling == "rand")
        sample_rand_pairs(samples, pairs, &rng);
-      for (vector<Pair>::iterator ti = pairs.begin();
+      for (vector<pair<ScoredHyp,ScoredHyp> >::iterator ti = pairs.begin();
           ti != pairs.end(); ti++) {
        SparseVector<double> dv;
-        if (ti->first_score - ti->second_score < 0) {
-          dv = ti->second - ti->first;
+        if (ti->first.score - ti->second.score < 0) {
+          dv = ti->second.f - ti->first.f;
        //} else {
          //dv = ti->first - ti->second;
        //}
@@ -344,14 +344,14 @@ main(int argc, char** argv)
          lambdas += dv * eta;
          if (verbose) {
-            cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
+            /*cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
            cout << " i " << TD::GetString(samples->sents[ti->first_rank]) << endl;
            cout << " " << samples->feats[ti->first_rank] << endl;
            cout << " j " << TD::GetString(samples->sents[ti->second_rank]) << endl;
            cout << " " << samples->feats[ti->second_rank] << endl;
            cout << " diff vec: " << dv << endl;
            cout << " lambdas after update: " << lambdas << endl;
-            cout << "}}" << endl;
+            cout << "}}" << endl;*/
          }
        } else {
          //SparseVector<double> reg;