summary | refs | log | tree | commit | diff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- dtrain/dtrain.cc 53
1 files changed, 27 insertions, 26 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index fb6c6880..e7a1244c 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -29,8 +29,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input")
("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
- ("funny", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
- ("fselect", po::value<weight_t>()->default_value(-1), "select top x percent of features after each epoch")
+ ("inc_correct", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
+ ("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent of features after each epoch")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -113,9 +113,9 @@ main(int argc, char** argv)
HSReporter rep(task_id);
bool keep = false;
if (cfg.count("keep")) keep = true;
- bool funny = false;
- if (cfg.count("funny"))
- funny = true;
+ bool inc_correct = false;
+ if (cfg.count("inc_correct"))
+ inc_correct = true;
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
@@ -158,10 +158,9 @@ main(int argc, char** argv)
}
vector<score_t> bleu_weights;
scorer->Init(N, bleu_weights);
- if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl;
// setup decoder observer
- MT19937 rng; // random number generator
+ MT19937 rng; // random number generator, only for forest sampling
HypSampler* observer;
if (sample_from == "kbest")
observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type));
@@ -225,6 +224,7 @@ main(int argc, char** argv)
cerr << setw(25) << "k " << k << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
+ cerr << setw(25) << "scorer '" << scorer_str << "'" << endl;
cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
if (sample_from == "kbest")
cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
@@ -235,8 +235,8 @@ main(int argc, char** argv)
cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
if (cfg.count("l1_reg"))
cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
- if (funny)
- cerr << setw(25) << "funny " << funny << endl;
+ if (inc_correct)
+ cerr << setw(25) << "inc. correct " << inc_correct << endl;
if (rescale)
cerr << setw(25) << "rescale " << rescale << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
@@ -246,7 +246,7 @@ main(int argc, char** argv)
#endif
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (cfg.count("input_weights"))
- cerr << setw(25) << "weights in" << cfg["input_weights"].as<string>() << endl;
+ cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl;
if (cfg.count("stop-after"))
cerr << setw(25) << "stop_after " << stop_after << endl;
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;
@@ -279,7 +279,7 @@ main(int argc, char** argv)
} else {
if (ii == in_sz) next = true; // stop if we reach the end of our input
}
- // stop after X sentences (but still iterate for those)
+ // stop after X sentences (but still go on for those)
if (stop_after > 0 && stop_after == ii && !next) stop = true;
// produce some pretty output
@@ -323,14 +323,17 @@ main(int argc, char** argv)
register_and_convert(ref_tok, ref_ids);
ref_ids_buf.push_back(ref_ids);
// process and set grammar
- bool broken_grammar = true;
+ bool broken_grammar = true; // ignore broken grammars
for (string::iterator it = in.begin(); it != in.end(); it++) {
if (!isspace(*it)) {
broken_grammar = false;
break;
}
}
- if (broken_grammar) continue;
+ if (broken_grammar) {
+ cerr << "Broken grammar for " << ii+1 << "! Ignoring this input." << endl;
+ continue;
+ }
boost::replace_all(in, "\t", "\n");
in += "\n";
grammar_buf_out << in << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl;
@@ -389,7 +392,7 @@ main(int argc, char** argv)
}
}
- score_sum += (*samples)[0].score;
+ score_sum += (*samples)[0].score; // stats for 1best
model_sum += (*samples)[0].model;
// weight updates
@@ -415,7 +418,7 @@ main(int argc, char** argv)
lambdas.plus_eq_v_times_s(diff_vec, eta);
rank_errors++;
} else {
- if (funny) {
+ if (inc_correct) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
lambdas.plus_eq_v_times_s(diff_vec, eta);
}
@@ -453,7 +456,7 @@ main(int argc, char** argv)
}
}
} else if (l1cumul) {
- weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
+ weight_t acc_penalty = (ii+1) * l1_reg; // Note: ii is the index of the current input
for (unsigned d = 0; d < lambdas.size(); d++) {
if (lambdas.nonzero(d)) {
weight_t v = lambdas.get(d);
@@ -515,7 +518,7 @@ main(int argc, char** argv)
model_diff = model_avg;
}
- unsigned nonz;
+ unsigned nonz = 0;
if (!quiet || hstreaming) nonz = (unsigned)lambdas.size_nonzero();
if (!quiet) {
@@ -524,18 +527,18 @@ main(int argc, char** argv)
cerr << setw(18) << *it << " = " << lambdas.get(FD::Convert(*it)) << endl;
}
cerr << " ---" << endl;
- cerr << _np << " 1best avg score: " << score_avg;
+ cerr << _np << " 1best avg score: " << score_avg;
cerr << _p << " (" << score_diff << ")" << endl;
- cerr << _np << "1best avg model score: " << model_avg;
+ cerr << _np << " 1best avg model score: " << model_avg;
cerr << _p << " (" << model_diff << ")" << endl;
- cerr << " avg #pairs: ";
+ cerr << " avg # pairs: ";
cerr << _np << npairs/(float)in_sz << endl;
- cerr << " avg #rank err: ";
+ cerr << " avg # rank err: ";
cerr << rank_errors/(float)in_sz << endl;
- cerr << " avg #margin viol: ";
+ cerr << " avg # margin viol: ";
cerr << margin_violations/(float)in_sz << endl;
- cerr << " non0 feature count: " << nonz << endl;
- cerr << " avg f count: ";
+ cerr << " non0 feature count: " << nonz << endl;
+ cerr << " avg f count: ";
cerr << feature_count/(float)pair_count << endl;
}
@@ -628,7 +631,5 @@ main(int argc, char** argv)
cerr << best_it+1 << " [SCORE '" << scorer_str << "'=" << max_score << "]." << endl;
cerr << _p2 << "This took " << overall_time/60. << " min." << endl;
}
-
- return 0;
}