summaryrefslogtreecommitdiff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-04-27 16:07:12 +0200
committerPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-04-27 16:07:12 +0200
commitc7ac569634c07de169a91c9f4d028ecd3899b4df (patch)
tree4e7ba79ba4265b8e54da3dad21e6398116f94153 /dtrain/dtrain.cc
parent0ac66e310d57f9aea5ddeea900c84df08abfe8c2 (diff)
improved readability, fixes
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--dtrain/dtrain.cc47
1 files changed, 10 insertions, 37 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 3dee10f2..e817e7ab 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -30,7 +30,6 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input")
("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
- ("inc_correct", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent of features after each epoch")
("approx_bleu_d", po::value<score_t>()->default_value(0.9), "discount for approx. BLEU")
#ifdef DTRAIN_LOCAL
@@ -122,9 +121,6 @@ main(int argc, char** argv)
HSReporter rep(task_id);
bool keep = false;
if (cfg.count("keep")) keep = true;
- bool inc_correct = false;
- if (cfg.count("inc_correct"))
- inc_correct = true;
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
@@ -226,7 +222,6 @@ main(int argc, char** argv)
score_t max_score = 0.;
unsigned best_it = 0;
float overall_time = 0.;
- unsigned pair_count = 0, feature_count = 0;
// output cfg
if (!quiet) {
@@ -250,8 +245,6 @@ main(int argc, char** argv)
cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
if (cfg.count("l1_reg"))
cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
- if (inc_correct)
- cerr << setw(25) << "inc. correct " << inc_correct << endl;
if (rescale)
cerr << setw(25) << "rescale " << rescale << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
@@ -420,36 +413,18 @@ main(int argc, char** argv)
if (pair_sampling == "PRO")
PROsampling(samples, pairs, pair_threshold);
npairs += pairs.size();
- pair_count += 2*pairs.size();
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
- score_t rank_error = it->second.score - it->first.score;
- feature_count += it->first.f.size() + it->second.f.size();
- if (!gamma) {
- // perceptron
- if (rank_error > 0) {
- SparseVector<weight_t> diff_vec = it->second.f - it->first.f;
- lambdas.plus_eq_v_times_s(diff_vec, eta);
- rank_errors++;
- } else {
- if (inc_correct) {
- SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
- lambdas.plus_eq_v_times_s(diff_vec, eta);
- }
- }
- if (it->first.model - it->second.model < 1) margin_violations++;
- } else {
- // SVM
- score_t margin = it->first.model - it->second.model;
- if (rank_error > 0 || margin < 1) {
- SparseVector<weight_t> diff_vec = it->second.f - it->first.f;
- lambdas.plus_eq_v_times_s(diff_vec, eta);
- if (rank_error > 0) rank_errors++;
- if (margin < 1) margin_violations++;
- }
- // regularization
- lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
+ bool rank_error = it->first.model <= it->second.model;
+ if (rank_error) rank_errors++;
+ score_t margin = fabs(it->first.model - it->second.model);
+ if (!rank_error && margin < 1) margin_violations++;
+ if (rank_error || (gamma && margin<1)) {
+ SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
+ lambdas.plus_eq_v_times_s(diff_vec, eta);
+ if (gamma)
+ lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
}
}
@@ -553,8 +528,6 @@ main(int argc, char** argv)
cerr << " avg # margin viol: ";
cerr << margin_violations/(float)in_sz << endl;
cerr << " non0 feature count: " << nonz << endl;
- cerr << " avg f count: ";
- cerr << feature_count/(float)pair_count << endl;
}
if (hstreaming) {
@@ -580,7 +553,7 @@ main(int argc, char** argv)
overall_time += time_diff;
if (!quiet) {
cerr << _p2 << _np << "(time " << time_diff/60. << " min, ";
- cerr << time_diff/(float)in_sz<< " s/S)" << endl;
+ cerr << time_diff/in_sz << " s/S)" << endl;
}
if (t+1 != T && !quiet) cerr << endl;