From c914b17d0621eb6626a98f86e4d4e118cd589555 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Fri, 27 Apr 2012 16:07:12 +0200 Subject: improved readability, fixes --- dtrain/dtrain.cc | 47 ++++++++++------------------------------------- 1 file changed, 10 insertions(+), 37 deletions(-) (limited to 'dtrain/dtrain.cc') diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 3dee10f2..e817e7ab 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -30,7 +30,6 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("rescale", po::value()->zero_tokens(), "rescale weight vector after each input") ("l1_reg", po::value()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") ("l1_reg_strength", po::value(), "l1 regularization strength") - ("inc_correct", po::value()->zero_tokens(), "include correctly ranked pairs into updates") ("fselect", po::value()->default_value(-1), "TODO select top x percent of features after each epoch") ("approx_bleu_d", po::value()->default_value(0.9), "discount for approx. BLEU") #ifdef DTRAIN_LOCAL @@ -122,9 +121,6 @@ main(int argc, char** argv) HSReporter rep(task_id); bool keep = false; if (cfg.count("keep")) keep = true; - bool inc_correct = false; - if (cfg.count("inc_correct")) - inc_correct = true; const unsigned k = cfg["k"].as(); const unsigned N = cfg["N"].as(); @@ -226,7 +222,6 @@ main(int argc, char** argv) score_t max_score = 0.; unsigned best_it = 0; float overall_time = 0.; - unsigned pair_count = 0, feature_count = 0; // output cfg if (!quiet) { @@ -250,8 +245,6 @@ main(int argc, char** argv) cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; if (cfg.count("l1_reg")) cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as() << "'" << endl; - if (inc_correct) - cerr << setw(25) << "inc. correct " << inc_correct << endl; if (rescale) cerr << setw(25) << "rescale " << rescale << endl; cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl; @@ -420,36 +413,18 @@ main(int argc, char** argv) if (pair_sampling == "PRO") PROsampling(samples, pairs, pair_threshold); npairs += pairs.size(); - pair_count += 2*pairs.size(); for (vector >::iterator it = pairs.begin(); it != pairs.end(); it++) { - score_t rank_error = it->second.score - it->first.score; - feature_count += it->first.f.size() + it->second.f.size(); - if (!gamma) { - // perceptron - if (rank_error > 0) { - SparseVector diff_vec = it->second.f - it->first.f; - lambdas.plus_eq_v_times_s(diff_vec, eta); - rank_errors++; - } else { - if (inc_correct) { - SparseVector diff_vec = it->first.f - it->second.f; - lambdas.plus_eq_v_times_s(diff_vec, eta); - } - } - if (it->first.model - it->second.model < 1) margin_violations++; - } else { - // SVM - score_t margin = it->first.model - it->second.model; - if (rank_error > 0 || margin < 1) { - SparseVector diff_vec = it->second.f - it->first.f; - lambdas.plus_eq_v_times_s(diff_vec, eta); - if (rank_error > 0) rank_errors++; - if (margin < 1) margin_violations++; - } - // regularization - lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); + bool rank_error = it->first.model <= it->second.model; + if (rank_error) rank_errors++; + score_t margin = fabs(it->first.model - it->second.model); + if (!rank_error && margin < 1) margin_violations++; + if (rank_error || (gamma && margin<1)) { + SparseVector diff_vec = it->first.f - it->second.f; + lambdas.plus_eq_v_times_s(diff_vec, eta); + if (gamma) + lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); } } @@ -553,8 +528,6 @@ main(int argc, char** argv) cerr << " avg # margin viol: "; cerr << margin_violations/(float)in_sz << endl; cerr << " non0 feature count: " << nonz << endl; - cerr << " avg f count: "; - cerr << feature_count/(float)pair_count << endl; } if (hstreaming) { @@ -580,7 +553,7 @@ main(int argc, char** argv) overall_time += time_diff; if (!quiet) { cerr << _p2 << _np << "(time " << time_diff/60. << " min, "; - cerr << time_diff/(float)in_sz<< " s/S)" << endl; + cerr << time_diff/in_sz << " s/S)" << endl; } if (t+1 != T && !quiet) cerr << endl; -- cgit v1.2.3