diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/dtrain/README.md | 18 | ||||
-rw-r--r-- | training/dtrain/dtrain.cc | 166 | ||||
-rw-r--r-- | training/dtrain/dtrain.h | 110 |
3 files changed, 197 insertions, 97 deletions
diff --git a/training/dtrain/README.md b/training/dtrain/README.md index 73a6a5a5..dc473568 100644 --- a/training/dtrain/README.md +++ b/training/dtrain/README.md @@ -16,6 +16,24 @@ Running ------- Download runnable examples for all use cases from [1] and extract here. +TODO +---- + * "stop_after" stop after X inputs + * "select_weights" average, best, last + * "rescale" rescale weight vector + * implement SVM objective? + * other variants of l1 regularization? + * l2 regularization? + * l1/l2 regularization? + * scale updates by bleu difference + * AdaGrad, per-coordinate learning rates + * batch update + * "repeat" iterate over k-best lists + * show k-best loss improvement + * "quiet" + * "verbose" + * fix output + Legal ----- Copyright (c) 2012-2015 by Patrick Simianer <p@simianer.de> diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index b39fff3e..e563f541 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -13,45 +13,65 @@ main(int argc, char** argv) if (!dtrain_init(argc, argv, &conf)) return 1; const size_t k = conf["k"].as<size_t>(); + const bool unique_kbest = conf["unique_kbest"].as<bool>(); + const bool forest_sample = conf["forest_sample"].as<bool>(); const string score_name = conf["score"].as<string>(); + const weight_t nakov_fix = conf["nakov_fix"].as<weight_t>(); + const weight_t chiang_decay = conf["chiang_decay"].as<weight_t>(); const size_t N = conf["N"].as<size_t>(); const size_t T = conf["iterations"].as<size_t>(); const weight_t eta = conf["learning_rate"].as<weight_t>(); const weight_t margin = conf["margin"].as<weight_t>(); + const weight_t cut = conf["cut"].as<weight_t>(); + const bool adjust_cut = conf["adjust"].as<bool>(); + const bool all_pairs = cut==0; const bool average = conf["average"].as<bool>(); - const bool structured = conf["struct"].as<bool>(); + const bool pro = conf["pro_sampling"].as<bool>(); + const bool structured = conf["structured"].as<bool>(); + const weight_t threshold = conf["threshold"].as<weight_t>(); + const size_t max_up = conf["max_pairs"].as<size_t>(); const weight_t l1_reg = conf["l1_reg"].as<weight_t>(); const bool keep = conf["keep"].as<bool>(); const bool noup = conf["disable_learning"].as<bool>(); const string output_fn = conf["output"].as<string>(); - const string output_data_which = conf["output_data"].as<string>(); - const bool output_data = output_data_which!=""; vector<string> print_weights; boost::split(print_weights, conf["print_weights"].as<string>(), boost::is_any_of(" ")); + const string output_updates_fn = conf["output_updates"].as<string>(); + const bool output_updates = output_updates_fn!=""; + const string output_raw_fn = conf["output_raw"].as<string>(); + const bool output_raw = output_raw_fn!=""; - // setup decoder and scorer + // setup decoder register_feature_functions(); SetSilent(true); ReadFile f(conf["decoder_conf"].as<string>()); Decoder decoder(f.stream()); + + // setup scorer & observer Scorer* scorer; if (score_name == "nakov") { - scorer = static_cast<PerSentenceBleuScorer*>(new PerSentenceBleuScorer(N)); + scorer = static_cast<NakovBleuScorer*>(new NakovBleuScorer(N, nakov_fix)); } else if (score_name == "papineni") { - scorer = static_cast<BleuScorer*>(new BleuScorer(N)); + scorer = static_cast<PapineniBleuScorer*>(new PapineniBleuScorer(N)); } else if (score_name == "lin") { - scorer = static_cast<OriginalPerSentenceBleuScorer*>\ - (new OriginalPerSentenceBleuScorer(N)); + scorer = static_cast<LinBleuScorer*>(new LinBleuScorer(N)); } else if (score_name == "liang") { - scorer = static_cast<SmoothPerSentenceBleuScorer*>\ - (new SmoothPerSentenceBleuScorer(N)); + scorer = static_cast<LiangBleuScorer*>(new LiangBleuScorer(N)); } else if (score_name == "chiang") { - scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N)); + scorer = static_cast<ChiangBleuScorer*>(new ChiangBleuScorer(N)); + } else if (score_name == "sum") { + scorer = static_cast<SumBleuScorer*>(new SumBleuScorer(N)); } else { assert(false); } - ScoredKbest* observer = new ScoredKbest(k, scorer); + HypSampler* observer; + if (forest_sample) + observer = new KSampler(k, scorer); + else if (unique_kbest) + observer = new KBestSampler(k, scorer); + else + observer = new KBestNoFilterSampler(k, scorer); // weights vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); @@ -65,22 +85,46 @@ main(int argc, char** argv) string input_fn = conf["bitext"].as<string>(); ReadFile input(input_fn); vector<string> buf; // decoder only accepts strings as input - vector<vector<Ngrams> > buf_ngs; // compute ngrams and lengths of references - vector<vector<size_t> > buf_ls; // just once + vector<vector<Ngrams> > buffered_ngrams; // compute ngrams and lengths of references + vector<vector<size_t> > buffered_lengths; // (just once) size_t input_sz = 0; - cerr << _p4; + cerr << setprecision(4); // output configuration cerr << "Parameters:" << endl; cerr << setw(25) << "bitext " << "'" << input_fn << "'" << endl; cerr << setw(25) << "k " << k << endl; + if (unique_kbest && !forest_sample) + cerr << setw(25) << "unique k-best " << unique_kbest << endl; + if (forest_sample) + cerr << setw(25) << "forest " << forest_sample << endl; + if (all_pairs) + cerr << setw(25) << "all pairs " << all_pairs << endl; + else if (pro) + cerr << setw(25) << "PRO " << pro << endl; cerr << setw(25) << "score " << "'" << score_name << "'" << endl; + if (score_name == "nakov") + cerr << setw(25) << "nakov fix " << nakov_fix << endl; + if (score_name == "chiang") + cerr << setw(25) << "chiang decay " << chiang_decay << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; cerr << setw(25) << "learning rate " << eta << endl; cerr << setw(25) << "margin " << margin << endl; + if (!structured) { + cerr << setw(25) << "cut " << cut << endl; + cerr << setw(25) << "adjust " << adjust_cut << endl; + } else { + cerr << setw(25) << "struct. obj " << structured << endl; + } + if (threshold > 0) + cerr << setw(25) << "threshold " << threshold << endl; + if (max_up != numeric_limits<size_t>::max()) + cerr << setw(25) << "max up. " << max_up << endl; + if (noup) + cerr << setw(25) << "no up. " << noup << endl; cerr << setw(25) << "average " << average << endl; - cerr << setw(25) << "l1 reg " << l1_reg << endl; + cerr << setw(25) << "l1 reg. " << l1_reg << endl; cerr << setw(25) << "decoder conf " << "'" << conf["decoder_conf"].as<string>() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; @@ -89,6 +133,8 @@ main(int argc, char** argv) cerr << setw(25) << "weights in " << "'" << conf["input_weights"].as<string>() << "'" << endl; } + if (noup) + cerr << setw(25) << "no updates!" << endl; cerr << "(1 dot per processed input)" << endl; // meta @@ -96,6 +142,13 @@ main(int argc, char** argv) size_t best_iteration = 0; time_t total_time = 0.; + // output + WriteFile raw_out; + if (output_raw) raw_out.Init(output_raw_fn); + WriteFile updates_out; + if (output_updates) updates_out.Init(output_raw_fn); + + for (size_t t = 0; t < T; t++) // T iterations { @@ -120,16 +173,16 @@ main(int argc, char** argv) boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); buf.push_back(parts[0]); parts.erase(parts.begin()); - buf_ngs.push_back({}); - buf_ls.push_back({}); + buffered_ngrams.push_back({}); + buffered_lengths.push_back({}); for (auto s: parts) { vector<WordID> r; vector<string> toks; boost::split(toks, s, boost::is_any_of(" ")); for (auto tok: toks) r.push_back(TD::Convert(tok)); - buf_ngs.back().emplace_back(MakeNgrams(r, N)); - buf_ls.back().push_back(r.size()); + buffered_ngrams.back().emplace_back(ngrams(r, N)); + buffered_lengths.back().push_back(r.size()); } } } else { @@ -155,50 +208,54 @@ main(int argc, char** argv) // decode if (t > 0 || i > 0) lambdas.init_vector(&decoder_weights); - observer->SetReference(buf_ngs[i], buf_ls[i]); + observer->reference_ngrams = &buffered_ngrams[i]; + observer->reference_lengths = &buffered_lengths[i]; decoder.Decode(buf[i], observer); - vector<ScoredHyp>* samples = observer->GetSamples(); - - // stats for 1best - gold_sum += samples->front().gold; - model_sum += samples->front().model; - feature_count += observer->GetFeatureCount(); - list_sz += observer->GetSize(); - - if (output_data) { - if (output_data_which == "kbest") { - OutputKbest(samples); - } else if (output_data_which == "default") { - OutputMultipartitePairs(samples, margin); - } else if (output_data_which == "all") { - OutputAllPairs(samples); - } - } + vector<Hyp>* sample = &(observer->sample); + + // stats for 1-best + gold_sum += sample->front().gold; + model_sum += sample->front().model; + feature_count += observer->feature_count; + list_sz += observer->effective_size; + + if (output_raw) + output_sample(sample); - // get pairs and update + // update model if (!noup) { SparseVector<weight_t> updates; if (structured) - num_up += CollectUpdatesStruct(samples, updates); + num_up += update_structured(sample, updates, margin/*, + output_updates, updates_out.get()*/); // FIXME + else if (all_pairs) + num_up += updates_all(sample, updates, max_up, threshold/*, + output_updates, updates_out.get()*/); // FIXME + else if (pro) + num_up += updates_pro(sample, updates, cut, max_up, threshold/*, + output_updates, updates_out.get()*/); // FIXME else - num_up += CollectUpdates(samples, updates, margin); + num_up += updates_multipartite(sample, updates, cut, margin, + max_up, threshold, adjust_cut/*, + output_updates, updates_out.get()*/); // FIXME SparseVector<weight_t> lambdas_copy; if (l1_reg) lambdas_copy = lambdas; lambdas.plus_eq_v_times_s(updates, eta); - // update context for approx. BLEU + // update context for Chiang's approx. BLEU if (score_name == "chiang") { - for (auto it: *samples) { + for (auto it: *sample) { if (it.rank == 0) { - scorer->UpdateContext(it.w, buf_ngs[i], buf_ls[i], 0.9); + scorer->update_context(it.w, buffered_ngrams[i], + buffered_lengths[i], chiang_decay); break; } } } - // l1 regularization + // \ell_1 regularization // NB: regularization is done after each sentence, // not after every single pair! if (l1_reg) { @@ -234,19 +291,22 @@ main(int argc, char** argv) // stats weight_t gold_avg = gold_sum/(weight_t)input_sz; - cerr << _p << "WEIGHTS" << endl; + cerr << setiosflags(ios::showpos) << "WEIGHTS" << endl; for (auto name: print_weights) - cerr << setw(18) << name << " = " << lambdas.get(FD::Convert(name)) << endl; + cerr << setw(18) << name << " = " + << lambdas.get(FD::Convert(name)) << endl; cerr << " ---" << endl; - cerr << _np << " 1best avg score: " << gold_avg*100; - cerr << _p << " (" << (gold_avg-gold_prev)*100 << ")" << endl; + cerr << resetiosflags(ios::showpos) + << " 1best avg score: " << gold_avg*100; + cerr << setiosflags(ios::showpos) << " (" + << (gold_avg-gold_prev)*100 << ")" << endl; cerr << " 1best avg model score: " << model_sum/(weight_t)input_sz << endl; cerr << " avg # updates: "; - cerr << _np << num_up/(float)input_sz << endl; - cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl; - cerr << " avg f count: " << feature_count/(float)list_sz << endl; - cerr << " avg list sz: " << list_sz/(float)input_sz << endl; + cerr << resetiosflags(ios::showpos) << num_up/(float)input_sz << endl; + cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl; + cerr << " avg f count: " << feature_count/(float)list_sz << endl; + cerr << " avg list sz: " << list_sz/(float)input_sz << endl; if (gold_avg > best) { best = gold_avg; diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 0bbb5c9b..18a7dbdc 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -22,59 +22,90 @@ namespace po = boost::program_options; namespace dtrain { -struct ScoredHyp +struct Hyp { + Hyp() {} + Hyp(vector<WordID> w, SparseVector<weight_t> f, weight_t model, weight_t gold, + size_t rank) : w(w), f(f), model(model), gold(gold), rank(rank) {} + vector<WordID> w; SparseVector<weight_t> f; weight_t model, gold; size_t rank; }; -inline void -PrintWordIDVec(vector<WordID>& v, ostream& os=cerr) -{ - for (size_t i = 0; i < v.size(); i++) { - os << TD::Convert(v[i]); - if (i < v.size()-1) os << " "; - } -} - -inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); } -inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); } -inline ostream& _p4(ostream& out) { return out << setprecision(4); } - bool -dtrain_init(int argc, char** argv, po::variables_map* conf) +dtrain_init(int argc, + char** argv, + po::variables_map* conf) { po::options_description opts("Configuration File Options"); opts.add_options() - ("bitext,b", po::value<string>(), "bitext") - ("decoder_conf,C", po::value<string>(), "configuration file for decoder") - ("iterations,T", po::value<size_t>()->default_value(15), "number of iterations T (per shard)") - ("k", po::value<size_t>()->default_value(100), "size of kbest list") - ("learning_rate,l", po::value<weight_t>()->default_value(0.00001), "learning rate") - ("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength") - ("margin,m", po::value<weight_t>()->default_value(1.0), "margin for margin perceptron") - ("score,s", po::value<string>()->default_value("chiang"), "per-sentence BLEU approx.") - ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") - ("input_weights,w", po::value<string>(), "input weights file") - ("average,a", po::bool_switch()->default_value(true), "output average weights") - ("keep,K", po::bool_switch()->default_value(false), "output a weight file per iteration") - ("struct,S", po::bool_switch()->default_value(false), "structured SGD with hope/fear") - ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") - ("disable_learning,X", po::bool_switch()->default_value(false), "disable learning") - ("output_data,D", po::value<string>()->default_value(""), "output data to STDOUT; arg. is 'kbest', 'default' or 'all'") + ("bitext,b", po::value<string>(), + "bitext, source and references in a single file [e ||| f]") + ("decoder_conf,C", po::value<string>(), + "decoder configuration file") + ("iterations,T", po::value<size_t>()->default_value(15), + "number of iterations T") + ("k", po::value<size_t>()->default_value(100), + "sample size per input (e.g. size of k-best lists)") + ("unique_kbest", po::bool_switch()->default_value(true), + "unique k-best lists") + ("forest_sample", po::bool_switch()->default_value(false), + "sample k hyptheses from forest instead of using k-best list") + ("learning_rate,l", po::value<weight_t>()->default_value(0.00001), + "learning rate [only meaningful if margin>0 or input weights are given]") + ("l1_reg,r", po::value<weight_t>()->default_value(0.), + "l1 regularization strength [see Tsuruoka, Tsujii and Ananiadou (2009)]") + ("margin,m", po::value<weight_t>()->default_value(1.0), + "margin for margin perceptron [set =0 for standard perceptron]") + ("cut,u", po::value<weight_t>()->default_value(0.1), + "use top/bottom 10% (default) of k-best as 'good' and 'bad' for \ +pair sampling, 0 to use all pairs TODO") + ("adjust,A", po::bool_switch()->default_value(false), + "adjust cut for optimal pos. in k-best to cut") + ("score,s", po::value<string>()->default_value("chiang"), + "per-sentence BLEU (approx.)") + ("nakov_fix", po::value<weight_t>()->default_value(1.0), + "add to reference length [see score.h]") + ("chiang_decay", po::value<weight_t>()->default_value(0.9), + "decaying factor for Chiang's approx. BLEU") + ("N", po::value<size_t>()->default_value(4), + "N for BLEU approximation") + ("input_weights,w", po::value<string>(), + "weights to initialize model") + ("average,a", po::bool_switch()->default_value(true), + "output average weights") + ("keep,K", po::bool_switch()->default_value(false), + "output a weight file per iteration [as weights.T.gz]") + ("structured,S", po::bool_switch()->default_value(false), + "structured prediction objective [hope/fear] w/ SGD") + ("pro_sampling", po::bool_switch()->default_value(false), + "updates from pairs selected as shown in Fig.4 of (Hopkins and May, 2011) [Gamma=max_pairs (default 5000), Xi=cut (default 50); threshold default 0.05]") + ("threshold", po::value<weight_t>()->default_value(0.), + "(min.) threshold in terms of gold score for pair selection") + ("max_pairs", + po::value<size_t>()->default_value(numeric_limits<size_t>::max()), + "max. number of updates/pairs") + ("output,o", po::value<string>()->default_value("-"), + "output weights file, '-' for STDOUT") + ("disable_learning,X", po::bool_switch()->default_value(false), + "fix model") + ("output_updates,U", po::value<string>()->default_value(""), + "output updates (diff. vectors) [to filename]") + ("output_raw,R", po::value<string>()->default_value(""), + "output raw data (e.g. k-best lists) [to filename]") ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), - "list of weights to print after each iteration"); + "list of weights to print after each iteration"); po::options_description clopts("Command Line Options"); clopts.add_options() ("conf,c", po::value<string>(), "dtrain configuration file") ("help,h", po::bool_switch(), "display options"); opts.add(clopts); po::store(parse_command_line(argc, argv, opts), *conf); - cerr << "dtrain" << endl << endl; + cerr << "*dtrain*" << endl << endl; if ((*conf)["help"].as<bool>()) { - cerr << opts << endl; + cerr << setprecision(3) << opts << endl; return false; } @@ -90,20 +121,11 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) return false; } if (!conf->count("bitext")) { - cerr << "No input given." << endl; + cerr << "No input bitext." << endl; cerr << opts << endl; return false; } - if ((*conf)["output_data"].as<string>() != "") { - if ((*conf)["output_data"].as<string>() != "kbest" && - (*conf)["output_data"].as<string>() != "default" && - (*conf)["output_data"].as<string>() != "all") { - cerr << "Wrong 'output_data' argument: "; - cerr << (*conf)["output_data"].as<string>() << endl; - return false; - } - } return true; } |