diff options
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 166 |
1 files changed, 113 insertions, 53 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index b39fff3e..e563f541 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -13,45 +13,65 @@ main(int argc, char** argv) if (!dtrain_init(argc, argv, &conf)) return 1; const size_t k = conf["k"].as<size_t>(); + const bool unique_kbest = conf["unique_kbest"].as<bool>(); + const bool forest_sample = conf["forest_sample"].as<bool>(); const string score_name = conf["score"].as<string>(); + const weight_t nakov_fix = conf["nakov_fix"].as<weight_t>(); + const weight_t chiang_decay = conf["chiang_decay"].as<weight_t>(); const size_t N = conf["N"].as<size_t>(); const size_t T = conf["iterations"].as<size_t>(); const weight_t eta = conf["learning_rate"].as<weight_t>(); const weight_t margin = conf["margin"].as<weight_t>(); + const weight_t cut = conf["cut"].as<weight_t>(); + const bool adjust_cut = conf["adjust"].as<bool>(); + const bool all_pairs = cut==0; const bool average = conf["average"].as<bool>(); - const bool structured = conf["struct"].as<bool>(); + const bool pro = conf["pro_sampling"].as<bool>(); + const bool structured = conf["structured"].as<bool>(); + const weight_t threshold = conf["threshold"].as<weight_t>(); + const size_t max_up = conf["max_pairs"].as<size_t>(); const weight_t l1_reg = conf["l1_reg"].as<weight_t>(); const bool keep = conf["keep"].as<bool>(); const bool noup = conf["disable_learning"].as<bool>(); const string output_fn = conf["output"].as<string>(); - const string output_data_which = conf["output_data"].as<string>(); - const bool output_data = output_data_which!=""; vector<string> print_weights; boost::split(print_weights, conf["print_weights"].as<string>(), boost::is_any_of(" ")); + const string output_updates_fn = conf["output_updates"].as<string>(); + const bool output_updates = output_updates_fn!=""; + const string output_raw_fn = conf["output_raw"].as<string>(); + const bool output_raw = output_raw_fn!=""; - // setup decoder and scorer + // setup decoder register_feature_functions(); SetSilent(true); ReadFile f(conf["decoder_conf"].as<string>()); Decoder decoder(f.stream()); + + // setup scorer & observer Scorer* scorer; if (score_name == "nakov") { - scorer = static_cast<PerSentenceBleuScorer*>(new PerSentenceBleuScorer(N)); + scorer = static_cast<NakovBleuScorer*>(new NakovBleuScorer(N, nakov_fix)); } else if (score_name == "papineni") { - scorer = static_cast<BleuScorer*>(new BleuScorer(N)); + scorer = static_cast<PapineniBleuScorer*>(new PapineniBleuScorer(N)); } else if (score_name == "lin") { - scorer = static_cast<OriginalPerSentenceBleuScorer*>\ - (new OriginalPerSentenceBleuScorer(N)); + scorer = static_cast<LinBleuScorer*>(new LinBleuScorer(N)); } else if (score_name == "liang") { - scorer = static_cast<SmoothPerSentenceBleuScorer*>\ - (new SmoothPerSentenceBleuScorer(N)); + scorer = static_cast<LiangBleuScorer*>(new LiangBleuScorer(N)); } else if (score_name == "chiang") { - scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N)); + scorer = static_cast<ChiangBleuScorer*>(new ChiangBleuScorer(N)); + } else if (score_name == "sum") { + scorer = static_cast<SumBleuScorer*>(new SumBleuScorer(N)); } else { assert(false); } - ScoredKbest* observer = new ScoredKbest(k, scorer); + HypSampler* observer; + if (forest_sample) + observer = new KSampler(k, scorer); + else if (unique_kbest) + observer = new KBestSampler(k, scorer); + else + observer = new KBestNoFilterSampler(k, scorer); // weights vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); @@ -65,22 +85,46 @@ main(int argc, char** argv) string input_fn = conf["bitext"].as<string>(); ReadFile input(input_fn); vector<string> buf; // decoder only accepts strings as input - vector<vector<Ngrams> > buf_ngs; // compute ngrams and lengths of references - vector<vector<size_t> > buf_ls; // just once + vector<vector<Ngrams> > buffered_ngrams; // compute ngrams and lengths of references + vector<vector<size_t> > buffered_lengths; // (just once) size_t input_sz = 0; - cerr << _p4; + cerr << setprecision(4); // output configuration cerr << "Parameters:" << endl; cerr << setw(25) << "bitext " << "'" << input_fn << "'" << endl; cerr << setw(25) << "k " << k << endl; + if (unique_kbest && !forest_sample) + cerr << setw(25) << "unique k-best " << unique_kbest << endl; + if (forest_sample) + cerr << setw(25) << "forest " << forest_sample << endl; + if (all_pairs) + cerr << setw(25) << "all pairs " << all_pairs << endl; + else if (pro) + cerr << setw(25) << "PRO " << pro << endl; cerr << setw(25) << "score " << "'" << score_name << "'" << endl; + if (score_name == "nakov") + cerr << setw(25) << "nakov fix " << nakov_fix << endl; + if (score_name == "chiang") + cerr << setw(25) << "chiang decay " << chiang_decay << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; cerr << setw(25) << "learning rate " << eta << endl; cerr << setw(25) << "margin " << margin << endl; + if (!structured) { + cerr << setw(25) << "cut " << cut << endl; + cerr << setw(25) << "adjust " << adjust_cut << endl; + } else { + cerr << setw(25) << "struct. obj " << structured << endl; + } + if (threshold > 0) + cerr << setw(25) << "threshold " << threshold << endl; + if (max_up != numeric_limits<size_t>::max()) + cerr << setw(25) << "max up. " << max_up << endl; + if (noup) + cerr << setw(25) << "no up. " << noup << endl; cerr << setw(25) << "average " << average << endl; - cerr << setw(25) << "l1 reg " << l1_reg << endl; + cerr << setw(25) << "l1 reg. " << l1_reg << endl; cerr << setw(25) << "decoder conf " << "'" << conf["decoder_conf"].as<string>() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; @@ -89,6 +133,8 @@ main(int argc, char** argv) cerr << setw(25) << "weights in " << "'" << conf["input_weights"].as<string>() << "'" << endl; } + if (noup) + cerr << setw(25) << "no updates!" << endl; cerr << "(1 dot per processed input)" << endl; // meta @@ -96,6 +142,13 @@ main(int argc, char** argv) size_t best_iteration = 0; time_t total_time = 0.; + // output + WriteFile raw_out; + if (output_raw) raw_out.Init(output_raw_fn); + WriteFile updates_out; + if (output_updates) updates_out.Init(output_raw_fn); + + for (size_t t = 0; t < T; t++) // T iterations { @@ -120,16 +173,16 @@ main(int argc, char** argv) boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); buf.push_back(parts[0]); parts.erase(parts.begin()); - buf_ngs.push_back({}); - buf_ls.push_back({}); + buffered_ngrams.push_back({}); + buffered_lengths.push_back({}); for (auto s: parts) { vector<WordID> r; vector<string> toks; boost::split(toks, s, boost::is_any_of(" ")); for (auto tok: toks) r.push_back(TD::Convert(tok)); - buf_ngs.back().emplace_back(MakeNgrams(r, N)); - buf_ls.back().push_back(r.size()); + buffered_ngrams.back().emplace_back(ngrams(r, N)); + buffered_lengths.back().push_back(r.size()); } } } else { @@ -155,50 +208,54 @@ main(int argc, char** argv) // decode if (t > 0 || i > 0) lambdas.init_vector(&decoder_weights); - observer->SetReference(buf_ngs[i], buf_ls[i]); + observer->reference_ngrams = &buffered_ngrams[i]; + observer->reference_lengths = &buffered_lengths[i]; decoder.Decode(buf[i], observer); - vector<ScoredHyp>* samples = observer->GetSamples(); - - // stats for 1best - gold_sum += samples->front().gold; - model_sum += samples->front().model; - feature_count += observer->GetFeatureCount(); - list_sz += observer->GetSize(); - - if (output_data) { - if (output_data_which == "kbest") { - OutputKbest(samples); - } else if (output_data_which == "default") { - OutputMultipartitePairs(samples, margin); - } else if (output_data_which == "all") { - OutputAllPairs(samples); - } - } + vector<Hyp>* sample = &(observer->sample); + + // stats for 1-best + gold_sum += sample->front().gold; + model_sum += sample->front().model; + feature_count += observer->feature_count; + list_sz += observer->effective_size; + + if (output_raw) + output_sample(sample); - // get pairs and update + // update model if (!noup) { SparseVector<weight_t> updates; if (structured) - num_up += CollectUpdatesStruct(samples, updates); + num_up += update_structured(sample, updates, margin/*, + output_updates, updates_out.get()*/); // FIXME + else if (all_pairs) + num_up += updates_all(sample, updates, max_up, threshold/*, + output_updates, updates_out.get()*/); // FIXME + else if (pro) + num_up += updates_pro(sample, updates, cut, max_up, threshold/*, + output_updates, updates_out.get()*/); // FIXME else - num_up += CollectUpdates(samples, updates, margin); + num_up += updates_multipartite(sample, updates, cut, margin, + max_up, threshold, adjust_cut/*, + output_updates, updates_out.get()*/); // FIXME SparseVector<weight_t> lambdas_copy; if (l1_reg) lambdas_copy = lambdas; lambdas.plus_eq_v_times_s(updates, eta); - // update context for approx. BLEU + // update context for Chiang's approx. BLEU if (score_name == "chiang") { - for (auto it: *samples) { + for (auto it: *sample) { if (it.rank == 0) { - scorer->UpdateContext(it.w, buf_ngs[i], buf_ls[i], 0.9); + scorer->update_context(it.w, buffered_ngrams[i], + buffered_lengths[i], chiang_decay); break; } } } - // l1 regularization + // \ell_1 regularization // NB: regularization is done after each sentence, // not after every single pair! if (l1_reg) { @@ -234,19 +291,22 @@ main(int argc, char** argv) // stats weight_t gold_avg = gold_sum/(weight_t)input_sz; - cerr << _p << "WEIGHTS" << endl; + cerr << setiosflags(ios::showpos) << "WEIGHTS" << endl; for (auto name: print_weights) - cerr << setw(18) << name << " = " << lambdas.get(FD::Convert(name)) << endl; + cerr << setw(18) << name << " = " + << lambdas.get(FD::Convert(name)) << endl; cerr << " ---" << endl; - cerr << _np << " 1best avg score: " << gold_avg*100; - cerr << _p << " (" << (gold_avg-gold_prev)*100 << ")" << endl; + cerr << resetiosflags(ios::showpos) + << " 1best avg score: " << gold_avg*100; + cerr << setiosflags(ios::showpos) << " (" + << (gold_avg-gold_prev)*100 << ")" << endl; cerr << " 1best avg model score: " << model_sum/(weight_t)input_sz << endl; cerr << " avg # updates: "; - cerr << _np << num_up/(float)input_sz << endl; - cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl; - cerr << " avg f count: " << feature_count/(float)list_sz << endl; - cerr << " avg list sz: " << list_sz/(float)input_sz << endl; + cerr << resetiosflags(ios::showpos) << num_up/(float)input_sz << endl; + cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl; + cerr << " avg f count: " << feature_count/(float)list_sz << endl; + cerr << " avg list sz: " << list_sz/(float)input_sz << endl; if (gold_avg > best) { best = gold_avg; |