summary | refs | log | tree | commit | diff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
author Patrick Simianer <p@simianer.de> 2011-11-21 12:21:08 +0100
committer Patrick Simianer <p@simianer.de> 2011-11-21 12:21:08 +0100
commit a0a109329c942ddc956205cc66ccac872fb8f222 (patch)
tree 2788f8f14afb951bcb80538947459d676cc298bc /dtrain/dtrain.cc
parent 1ce8f44fe74cb6fb9223c4a2a4050019fed3be49 (diff)
added pro stuff,clean up
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- dtrain/dtrain.cc 107
1 files changed, 64 insertions, 43 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 0853173f..3d3aa2d3 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -6,32 +6,33 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
{
po::options_description ini("Configuration File Options");
ini.add_options()
- ("input", po::value<string>()->default_value("-"), "input file")
- ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
- ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
- ("decoder_config", po::value<string>(), "configuration file for cdec")
- ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: kbest, forest")
- ("k", po::value<unsigned>()->default_value(100), "how many translations to sample")
- ("filter", po::value<string>()->default_value("unique"), "filter kbest list: no, unique")
- ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, rand, 108010")
- ("N", po::value<unsigned>()->default_value(3), "N for Ngrams (BLEU)")
- ("epochs", po::value<unsigned>()->default_value(2), "# of iterations T (per shard)")
- ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_*, smooth_*, approx_*")
- ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences")
- ("print_weights", po::value<string>(), "weights to print on each iteration")
- ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id")
- ("learning_rate", po::value<weight_t>()->default_value(0.0005), "learning rate")
- ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
- ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
- ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
- ("keep_w", po::value<bool>()->zero_tokens(), "protocol weights for each iteration")
- ("unit_weight_vector", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input")
- ("l1_reg", po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010")
- ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
+ ("input", po::value<string>()->default_value("-"), "input file")
+ ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
+ ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
+ ("decoder_config", po::value<string>(), "configuration file for cdec")
+ ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: kbest, forest")
+ ("k", po::value<unsigned>()->default_value(100), "how many translations to sample")
+ ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: no, uniq")
+ ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, 5050, 108010, PRO")
+ ("N", po::value<unsigned>()->default_value(3), "N for Ngrams (BLEU)")
+ ("epochs", po::value<unsigned>()->default_value(2), "# of iterations T (per shard)")
+ ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_*, smooth_*, approx_*")
+ ("learning_rate", po::value<weight_t>()->default_value(0.0005), "learning rate")
+ ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
+ ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
+ ("unit_wv", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input")
+ ("l1_reg", po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010")
+ ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
+ ("update_ok", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
+ ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences")
+ ("keep_w", po::value<bool>()->zero_tokens(), "keep weights files for each iteration")
+ ("print_weights", po::value<string>(), "weights to print on each iteration")
+ ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id")
+ ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
#ifdef DTRAIN_LOCAL
- ("refs,r", po::value<string>(), "references in local mode")
+ ("refs,r", po::value<string>(), "references in local mode")
#endif
- ("noup", po::value<bool>()->zero_tokens(), "do not update weights");
+ ("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
("config,c", po::value<string>(), "dtrain config file")
@@ -63,13 +64,14 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
return false;
}
- if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "unique"
+ if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "uniq"
&& (*cfg)["filter"].as<string>() != "no") {
- cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'unique' or 'no'." << endl;
+ cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'no'." << endl;
return false;
}
if ((*cfg)["pair_sampling"].as<string>() != "all"
- && (*cfg)["pair_sampling"].as<string>() != "rand" && (*cfg)["pair_sampling"].as<string>() != "108010") {
+ && (*cfg)["pair_sampling"].as<string>() != "5050" && (*cfg)["pair_sampling"].as<string>() != "108010"
+ && (*cfg)["pair_sampling"].as<string>() != "PRO") {
cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "', use 'all' or 'rand'." << endl;
return false;
}
@@ -101,11 +103,14 @@ main(int argc, char** argv)
task_id = cfg["hstreaming"].as<string>();
cerr.precision(17);
}
- bool unit_weight_vector = false;
- if (cfg.count("unit_weight_vector")) unit_weight_vector = true;
+ bool unit_wv = false;
+ if (cfg.count("unit_wv")) unit_wv = true;
HSReporter rep(task_id);
bool keep_w = false;
if (cfg.count("keep_w")) keep_w = true;
+ bool update_ok = false;
+ if (cfg.count("update_ok"))
+ update_ok = true;
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
@@ -118,7 +123,7 @@ main(int argc, char** argv)
vector<string> print_weights;
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
-
+
// setup decoder
register_feature_functions();
SetSilent(true);
@@ -187,7 +192,7 @@ main(int argc, char** argv)
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
// where temp files go
string tmp_path = cfg["tmp"].as<string>();
- vector<string> w_tmp_files; // used for protocol_w
+ vector<string> w_tmp_files; // used for keep_w
#ifdef DTRAIN_LOCAL
string refs_fn = cfg["refs"].as<string>();
ReadFile refs(refs_fn);
@@ -226,6 +231,12 @@ main(int argc, char** argv)
cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
+ if (cfg.count("l1_reg"))
+ cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
+ if (update_ok)
+ cerr << setw(25) << "up ok " << update_ok << endl;
+ if (unit_wv)
+ cerr << setw(25) << "unit weight vec " << unit_wv << endl;
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;
}
@@ -320,7 +331,7 @@ main(int argc, char** argv)
// get buffered grammar
string grammar_str;
while (true) {
- string rule;
+ string rule;
getline(grammar_buf_in, rule);
if (boost::starts_with(rule, DTRAIN_GRAMMAR_DELIM)) break;
grammar_str += rule + "\n";
@@ -372,13 +383,15 @@ main(int argc, char** argv)
if (!noup) {
vector<pair<ScoredHyp,ScoredHyp> > pairs;
if (pair_sampling == "all")
- sample_all_pairs(samples, pairs);
- if (pair_sampling == "rand")
- sample_rand_pairs(samples, pairs, &rng);
+ all_pairs(samples, pairs);
+ if (pair_sampling == "5050")
+ rand_pairs_5050(samples, pairs, &rng);
if (pair_sampling == "108010")
- sample108010(samples, pairs);
+ multpart108010(samples, pairs);
+ if (pair_sampling == "PRO")
+ PROsampling(samples, pairs);
npairs += pairs.size();
-
+
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
score_t rank_error = it->second.score - it->first.score;
@@ -388,6 +401,11 @@ main(int argc, char** argv)
SparseVector<weight_t> diff_vec = it->second.f - it->first.f;
lambdas.plus_eq_v_times_s(diff_vec, eta);
rank_errors++;
+ } else {
+ if (update_ok) {
+ SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
+ lambdas.plus_eq_v_times_s(diff_vec, eta);
+ }
}
if (it->first.model - it->second.model < 1) margin_violations++;
} else {
@@ -404,6 +422,8 @@ main(int argc, char** argv)
}
}
+ ////////
+ // TEST THIS
// reset cumulative_penalties after 1 iter?
// do this only once per INPUT (not per pair)
if (l1naive) {
@@ -439,8 +459,9 @@ main(int argc, char** argv)
}
}
}
+ ////////
- if (unit_weight_vector && sample_from == "forest") lambdas /= lambdas.l2norm();
+ if (unit_wv && sample_from == "forest") lambdas /= lambdas.l2norm();
++ii;
@@ -501,11 +522,11 @@ main(int argc, char** argv)
}
if (hstreaming) {
- rep.update_counter("Score 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(score_avg*_SCALE));
- rep.update_counter("Model 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(model_avg*_SCALE));
- rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), (unsigned)((npairs/(weight_t)in_sz)*_SCALE));
- rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*_SCALE));
- rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*_SCALE));
+ rep.update_counter("Score 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(score_avg*DTRAIN_SCALE));
+ rep.update_counter("Model 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(model_avg*DTRAIN_SCALE));
+ rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), (unsigned)((npairs/(weight_t)in_sz)*DTRAIN_SCALE));
+ rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*DTRAIN_SCALE));
+ rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*DTRAIN_SCALE));
unsigned nonz = (unsigned)lambdas.size_nonzero();
rep.update_counter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);
rep.update_gcounter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);