summaryrefslogtreecommitdiff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-11-13 16:17:02 +0100
committerPatrick Simianer <p@simianer.de>2011-11-13 16:17:02 +0100
commitcc2fb8549e9729ecf2d61dc771a7c348feb106f6 (patch)
treef1182b36bbc3f642aed01f9a70f45fa57709a55e /dtrain/dtrain.cc
parentbf5dd9905851113f5ebb38f207b6218c37a4f113 (diff)
removed hgsampler, more stats, unit_weight_vector arg
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--dtrain/dtrain.cc26
1 files changed, 18 insertions, 8 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 5c95c7f1..79047fd9 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -19,12 +19,13 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_*, smooth_*, approx_*")
("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences")
("print_weights", po::value<string>(), "weights to print on each iteration")
- ("hstreaming", po::value<string>()->default_value("N/A"), "run in hadoop streaming mode, arg is a task id")
+ ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id")
("learning_rate", po::value<weight_t>()->default_value(0.0005), "learning rate")
("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
("keep_w", po::value<bool>()->zero_tokens(), "protocol weights for each iteration")
+ ("unit_weight_vector", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -46,7 +47,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
return false;
}
if (cfg->count("hstreaming") && (*cfg)["output"].as<string>() != "-") {
- cerr << "When using 'hstreaming' the 'output' param should be '-'.";
+ cerr << "When using 'hstreaming' the 'output' param should be '-'." << endl;
return false;
}
#ifdef DTRAIN_LOCAL
@@ -98,6 +99,8 @@ main(int argc, char** argv)
task_id = cfg["hstreaming"].as<string>();
cerr.precision(17);
}
+ bool unit_weight_vector = false;
+ if (cfg.count("unit_weight_vector")) unit_weight_vector = true;
HSReporter rep(task_id);
bool keep_w = false;
if (cfg.count("keep_w")) keep_w = true;
@@ -226,7 +229,7 @@ main(int argc, char** argv)
#endif
score_t score_sum = 0.;
score_t model_sum(0);
- unsigned ii = 0, nup = 0, npairs = 0;
+ unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0;
if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
while(true)
@@ -369,21 +372,25 @@ main(int argc, char** argv)
if (rank_error > 0) {
SparseVector<weight_t> diff_vec = it->second.f - it->first.f;
lambdas.plus_eq_v_times_s(diff_vec, eta);
- nup++;
+ rank_errors++;
}
+ if (margin < 1) margin_violations++;
} else {
// SVM
score_t margin = it->first.model - it->second.model;
if (rank_error > 0 || margin < 1) {
SparseVector<weight_t> diff_vec = it->second.f - it->first.f;
lambdas.plus_eq_v_times_s(diff_vec, eta);
- nup++;
+ if (rank_error > 0) rank_errors++;
+ if (margin < 1) margin_violations++;
}
// regularization
lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
}
}
}
+
+ if (unit_weight_vector && sample_from == "forest") lambdas /= lambdas.l2norm();
++ii;
@@ -437,15 +444,18 @@ main(int argc, char** argv)
cerr << _p << " (" << model_diff << ")" << endl;
cerr << " avg #pairs: ";
cerr << _np << npairs/(float)in_sz << endl;
- cerr << " avg #up: ";
- cerr << nup/(float)in_sz << endl;
+ cerr << " avg #rank err: ";
+ cerr << rank_errors/(float)in_sz << endl;
+ cerr << " avg #margin viol: ";
+ cerr << margin_violations/float(in_sz) << endl;
}
if (hstreaming) {
rep.update_counter("Score 1best avg #"+boost::lexical_cast<string>(t+1), score_avg);
rep.update_counter("Model 1best avg #"+boost::lexical_cast<string>(t+1), model_avg);
rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), npairs/(weight_t)in_sz);
- rep.update_counter("Updates avg #"+boost::lexical_cast<string>(t+1), nup/(weight_t)in_sz);
+ rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), rank_errors/(weight_t)in_sz);
+ rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), margin_violations/(weight_t)in_sz);
unsigned nonz = (unsigned)lambdas.size_nonzero();
rep.update_counter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);
rep.update_gcounter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);