summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2013-03-18 05:14:48 -0700
committerChris Dyer <redpony@gmail.com>2013-03-18 05:14:48 -0700
commit3aeab176d9068b13e3ca3394be4f9089f5952517 (patch)
treef0d458ee427a3dd3632c99ea7febe463dc571e07 /training/dtrain/dtrain.cc
parent4f452c5bf5cd0ed3cb50d31012f93a50366b3aac (diff)
parenta416615b81380d664246f11a8047098c59185838 (diff)
Merge pull request #17 from pks/master
dtrain
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc204
1 files changed, 50 insertions, 154 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 18286668..149f87d4 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -6,15 +6,14 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
{
po::options_description ini("Configuration File Options");
ini.add_options()
- ("input", po::value<string>()->default_value("-"), "input file")
+ ("input", po::value<string>()->default_value("-"), "input file (src)")
+ ("refs,r", po::value<string>(), "references")
("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
("decoder_config", po::value<string>(), "configuration file for cdec")
("print_weights", po::value<string>(), "weights to print on each iteration")
("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences")
- ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
("keep", po::value<bool>()->zero_tokens(), "keep weights files for each iteration")
- ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id")
("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)")
("k", po::value<unsigned>()->default_value(100), "how many translations to sample")
("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'")
@@ -28,16 +27,13 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("gamma", po::value<weight_t>()->default_value(0.), "gamma for SVM (0 for perceptron)")
("select_weights", po::value<string>()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)")
("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input")
- ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
+ ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010) UNTESTED")
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
("fselect", po::value<weight_t>()->default_value(-1), "select top x percent (or by threshold) of features after each epoch NOT IMPLEMENTED") // TODO
("approx_bleu_d", po::value<score_t>()->default_value(0.9), "discount for approx. BLEU")
("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair")
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
-#ifdef DTRAIN_LOCAL
- ("refs,r", po::value<string>(), "references in local mode")
-#endif
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -55,16 +51,6 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << cl << endl;
return false;
}
- if (cfg->count("hstreaming") && (*cfg)["output"].as<string>() != "-") {
- cerr << "When using 'hstreaming' the 'output' param should be '-'." << endl;
- return false;
- }
-#ifdef DTRAIN_LOCAL
- if ((*cfg)["input"].as<string>() == "-") {
- cerr << "Can't use stdin as input with this binary. Recompile without DTRAIN_LOCAL" << endl;
- return false;
- }
-#endif
if ((*cfg)["sample_from"].as<string>() != "kbest"
&& (*cfg)["sample_from"].as<string>() != "forest") {
cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
@@ -111,17 +97,8 @@ main(int argc, char** argv)
if (cfg.count("verbose")) verbose = true;
bool noup = false;
if (cfg.count("noup")) noup = true;
- bool hstreaming = false;
- string task_id;
- if (cfg.count("hstreaming")) {
- hstreaming = true;
- quiet = true;
- task_id = cfg["hstreaming"].as<string>();
- cerr.precision(17);
- }
bool rescale = false;
if (cfg.count("rescale")) rescale = true;
- HSReporter rep(task_id);
bool keep = false;
if (cfg.count("keep")) keep = true;
@@ -148,6 +125,7 @@ main(int argc, char** argv)
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
+
// setup decoder
register_feature_functions();
SetSilent(true);
@@ -163,6 +141,8 @@ main(int argc, char** argv)
scorer = dynamic_cast<BleuScorer*>(new BleuScorer);
} else if (scorer_str == "stupid_bleu") {
scorer = dynamic_cast<StupidBleuScorer*>(new StupidBleuScorer);
+ } else if (scorer_str == "fixed_stupid_bleu") {
+ scorer = dynamic_cast<FixedStupidBleuScorer*>(new FixedStupidBleuScorer);
} else if (scorer_str == "smooth_bleu") {
scorer = dynamic_cast<SmoothBleuScorer*>(new SmoothBleuScorer);
} else if (scorer_str == "sum_bleu") {
@@ -201,6 +181,11 @@ main(int argc, char** argv)
weight_t eta = cfg["learning_rate"].as<weight_t>();
weight_t gamma = cfg["gamma"].as<weight_t>();
+ // faster perceptron: consider only misranked pairs, see
+ // DO NOT ENABLE WITH SVM (gamma > 0) OR loss_margin!
+ bool faster_perceptron = false;
+ if (gamma==0 && loss_margin==0) faster_perceptron = true;
+
// l1 regularization
bool l1naive = false;
bool l1clip = false;
@@ -222,16 +207,8 @@ main(int argc, char** argv)
// buffer input for t > 0
vector<string> src_str_buf; // source strings (decoder takes only strings)
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
- // where temp files go
- string tmp_path = cfg["tmp"].as<string>();
-#ifdef DTRAIN_LOCAL
string refs_fn = cfg["refs"].as<string>();
ReadFile refs(refs_fn);
-#else
- string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars");
- ogzstream grammar_buf_out;
- grammar_buf_out.open(grammar_buf_fn.c_str());
-#endif
unsigned in_sz = std::numeric_limits<unsigned>::max(); // input index, input size
vector<pair<score_t, score_t> > all_scores;
@@ -246,7 +223,7 @@ main(int argc, char** argv)
cerr << setw(25) << "k " << k << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
- cerr << setw(25) << "scorer '" << scorer_str << "'" << endl;
+ cerr << setw(26) << "scorer '" << scorer_str << "'" << endl;
if (scorer_str == "approx_bleu")
cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
@@ -256,6 +233,7 @@ main(int argc, char** argv)
else cerr << setw(25) << "learning rate " << "bleu diff" << endl;
cerr << setw(25) << "gamma " << gamma << endl;
cerr << setw(25) << "loss margin " << loss_margin << endl;
+ cerr << setw(25) << "faster perceptron " << faster_perceptron << endl;
cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
if (pair_sampling == "XYX")
cerr << setw(25) << "hi lo " << hi_lo << endl;
@@ -268,9 +246,7 @@ main(int argc, char** argv)
cerr << setw(25) << "max pairs " << max_pairs << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
-#ifdef DTRAIN_LOCAL
cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
-#endif
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (cfg.count("input_weights"))
cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl;
@@ -283,14 +259,8 @@ main(int argc, char** argv)
for (unsigned t = 0; t < T; t++) // T epochs
{
- if (hstreaming) cerr << "reporter:status:Iteration #" << t+1 << " of " << T << endl;
-
time_t start, end;
time(&start);
-#ifndef DTRAIN_LOCAL
- igzstream grammar_buf_in;
- if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str());
-#endif
score_t score_sum = 0.;
score_t model_sum(0);
unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0;
@@ -338,52 +308,6 @@ main(int argc, char** argv)
// getting input
vector<WordID> ref_ids; // reference as vector<WordID>
-#ifndef DTRAIN_LOCAL
- vector<string> in_split; // input: sid\tsrc\tref\tpsg
- if (t == 0) {
- // handling input
- split_in(in, in_split);
- if (hstreaming && ii == 0) cerr << "reporter:counter:" << task_id << ",First ID," << in_split[0] << endl;
- // getting reference
- vector<string> ref_tok;
- boost::split(ref_tok, in_split[2], boost::is_any_of(" "));
- register_and_convert(ref_tok, ref_ids);
- ref_ids_buf.push_back(ref_ids);
- // process and set grammar
- bool broken_grammar = true; // ignore broken grammars
- for (string::iterator it = in.begin(); it != in.end(); it++) {
- if (!isspace(*it)) {
- broken_grammar = false;
- break;
- }
- }
- if (broken_grammar) {
- cerr << "Broken grammar for " << ii+1 << "! Ignoring this input." << endl;
- continue;
- }
- boost::replace_all(in, "\t", "\n");
- in += "\n";
- grammar_buf_out << in << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl;
- decoder.AddSupplementalGrammarFromString(in);
- src_str_buf.push_back(in_split[1]);
- // decode
- observer->SetRef(ref_ids);
- decoder.Decode(in_split[1], observer);
- } else {
- // get buffered grammar
- string grammar_str;
- while (true) {
- string rule;
- getline(grammar_buf_in, rule);
- if (boost::starts_with(rule, DTRAIN_GRAMMAR_DELIM)) break;
- grammar_str += rule + "\n";
- }
- decoder.AddSupplementalGrammarFromString(grammar_str);
- // decode
- observer->SetRef(ref_ids_buf[ii]);
- decoder.Decode(src_str_buf[ii], observer);
- }
-#else
if (t == 0) {
string r_;
getline(*refs, r_);
@@ -400,7 +324,6 @@ main(int argc, char** argv)
decoder.Decode(in, observer);
else
decoder.Decode(src_str_buf[ii], observer);
-#endif
// get (scored) samples
vector<ScoredHyp>* samples = observer->GetSamples();
@@ -430,25 +353,26 @@ main(int argc, char** argv)
// get pairs
vector<pair<ScoredHyp,ScoredHyp> > pairs;
if (pair_sampling == "all")
- all_pairs(samples, pairs, pair_threshold, max_pairs);
+ all_pairs(samples, pairs, pair_threshold, max_pairs, faster_perceptron);
if (pair_sampling == "XYX")
- partXYX(samples, pairs, pair_threshold, max_pairs, hi_lo);
+ partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
if (pair_sampling == "PRO")
PROsampling(samples, pairs, pair_threshold, max_pairs);
npairs += pairs.size();
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
-#ifdef DTRAIN_FASTER_PERCEPTRON
- bool rank_error = true; // pair sampling already did this for us
- rank_errors++;
- score_t margin = std::numeric_limits<float>::max();
-#else
- bool rank_error = it->first.model <= it->second.model;
+ bool rank_error;
+ score_t margin;
+ if (faster_perceptron) { // we only have considering misranked pairs
+ rank_error = true; // pair sampling already did this for us
+ margin = std::numeric_limits<float>::max();
+ } else {
+ rank_error = it->first.model <= it->second.model;
+ margin = fabs(fabs(it->first.model) - fabs(it->second.model));
+ if (!rank_error && margin < loss_margin) margin_violations++;
+ }
if (rank_error) rank_errors++;
- score_t margin = fabs(fabs(it->first.model) - fabs(it->second.model));
- if (!rank_error && margin < loss_margin) margin_violations++;
-#endif
if (scale_bleu_diff) eta = it->first.score - it->second.score;
if (rank_error || margin < loss_margin) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
@@ -459,35 +383,40 @@ main(int argc, char** argv)
}
// l1 regularization
+ // please note that this penalizes _all_ weights
+ // (contrary to only the ones changed by the last update)
+ // after a _sentence_ (not after each example/pair)
if (l1naive) {
- for (unsigned d = 0; d < lambdas.size(); d++) {
- weight_t v = lambdas.get(d);
- lambdas.set_value(d, v - sign(v) * l1_reg);
+ FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ for (; it != lambdas.end(); ++it) {
+ it->second -= sign(it->second) * l1_reg;
}
} else if (l1clip) {
- for (unsigned d = 0; d < lambdas.size(); d++) {
- if (lambdas.nonzero(d)) {
- weight_t v = lambdas.get(d);
+ FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ for (; it != lambdas.end(); ++it) {
+ if (it->second != 0) {
+ weight_t v = it->second;
if (v > 0) {
- lambdas.set_value(d, max(0., v - l1_reg));
+ it->second = max(0., v - l1_reg);
} else {
- lambdas.set_value(d, min(0., v + l1_reg));
+ it->second = min(0., v + l1_reg);
}
}
}
} else if (l1cumul) {
weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
- for (unsigned d = 0; d < lambdas.size(); d++) {
- if (lambdas.nonzero(d)) {
- weight_t v = lambdas.get(d);
- weight_t penalty = 0;
+ FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ for (; it != lambdas.end(); ++it) {
+ if (it->second != 0) {
+ weight_t v = it->second;
+ weight_t penalized = 0.;
if (v > 0) {
- penalty = max(0., v-(acc_penalty + cumulative_penalties.get(d)));
+ penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
} else {
- penalty = min(0., v+(acc_penalty - cumulative_penalties.get(d)));
+ penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
}
- lambdas.set_value(d, penalty);
- cumulative_penalties.set_value(d, cumulative_penalties.get(d)+penalty);
+ it->second = penalized;
+ cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
}
}
}
@@ -498,11 +427,6 @@ main(int argc, char** argv)
++ii;
- if (hstreaming) {
- rep.update_counter("Seen #"+boost::lexical_cast<string>(t+1), 1u);
- rep.update_counter("Seen", 1u);
- }
-
} // input loop
if (average) w_average += lambdas;
@@ -511,21 +435,8 @@ main(int argc, char** argv)
if (t == 0) {
in_sz = ii; // remember size of input (# lines)
- if (hstreaming) {
- rep.update_counter("|Input|", ii);
- rep.update_gcounter("|Input|", ii);
- rep.update_gcounter("Shards", 1u);
- }
}
-#ifndef DTRAIN_LOCAL
- if (t == 0) {
- grammar_buf_out.close();
- } else {
- grammar_buf_in.close();
- }
-#endif
-
// print some stats
score_t score_avg = score_sum/(score_t)in_sz;
score_t model_avg = model_sum/(score_t)in_sz;
@@ -539,7 +450,7 @@ main(int argc, char** argv)
}
unsigned nonz = 0;
- if (!quiet || hstreaming) nonz = (unsigned)lambdas.num_nonzero();
+ if (!quiet) nonz = (unsigned)lambdas.num_nonzero();
if (!quiet) {
cerr << _p5 << _p << "WEIGHTS" << endl;
@@ -552,28 +463,18 @@ main(int argc, char** argv)
cerr << _np << " 1best avg model score: " << model_avg;
cerr << _p << " (" << model_diff << ")" << endl;
cerr << " avg # pairs: ";
- cerr << _np << npairs/(float)in_sz << endl;
+ cerr << _np << npairs/(float)in_sz;
+ if (faster_perceptron) cerr << " (meaningless)";
+ cerr << endl;
cerr << " avg # rank err: ";
cerr << rank_errors/(float)in_sz << endl;
-#ifndef DTRAIN_FASTER_PERCEPTRON
cerr << " avg # margin viol: ";
cerr << margin_violations/(float)in_sz << endl;
-#endif
cerr << " non0 feature count: " << nonz << endl;
cerr << " avg list sz: " << list_sz/(float)in_sz << endl;
cerr << " avg f count: " << f_count/(float)list_sz << endl;
}
- if (hstreaming) {
- rep.update_counter("Score 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(score_avg*DTRAIN_SCALE));
- rep.update_counter("Model 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(model_avg*DTRAIN_SCALE));
- rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), (unsigned)((npairs/(weight_t)in_sz)*DTRAIN_SCALE));
- rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*DTRAIN_SCALE));
- rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*DTRAIN_SCALE));
- rep.update_counter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);
- rep.update_gcounter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);
- }
-
pair<score_t,score_t> remember;
remember.first = score_avg;
remember.second = model_avg;
@@ -604,10 +505,6 @@ main(int argc, char** argv)
if (average) w_average /= (weight_t)T;
-#ifndef DTRAIN_LOCAL
- unlink(grammar_buf_fn.c_str());
-#endif
-
if (!noup) {
if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl;
if (select_weights == "last" || average) { // last, average
@@ -644,7 +541,6 @@ main(int argc, char** argv)
}
}
}
- if (output_fn == "-" && hstreaming) cout << "__SHARD_COUNT__\t1" << endl;
if (!quiet) cerr << "done" << endl;
}