diff options
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 53 |
1 files changed, 27 insertions, 26 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index fb6c6880..e7a1244c 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -29,8 +29,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input") ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") - ("funny", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") - ("fselect", po::value<weight_t>()->default_value(-1), "select top x percent of features after each epoch") + ("inc_correct", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") + ("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent of features after each epoch") #ifdef DTRAIN_LOCAL ("refs,r", po::value<string>(), "references in local mode") #endif @@ -113,9 +113,9 @@ main(int argc, char** argv) HSReporter rep(task_id); bool keep = false; if (cfg.count("keep")) keep = true; - bool funny = false; - if (cfg.count("funny")) - funny = true; + bool inc_correct = false; + if (cfg.count("inc_correct")) + inc_correct = true; const unsigned k = cfg["k"].as<unsigned>(); const unsigned N = cfg["N"].as<unsigned>(); @@ -158,10 +158,9 @@ main(int argc, char** argv) } vector<score_t> bleu_weights; scorer->Init(N, bleu_weights); - if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl; // setup decoder observer - MT19937 rng; // random number generator + MT19937 rng; // random number generator, only for forest sampling HypSampler* observer; if (sample_from == "kbest") observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type)); @@ -225,6 +224,7 @@ main(int argc, char** argv) cerr << setw(25) << "k " << k << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; + cerr << setw(25) << "scorer '" << scorer_str << "'" << endl; cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; if (sample_from == "kbest") cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; @@ -235,8 +235,8 @@ main(int argc, char** argv) cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; if (cfg.count("l1_reg")) cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl; - if (funny) - cerr << setw(25) << "funny " << funny << endl; + if (inc_correct) + cerr << setw(25) << "inc. correct " << inc_correct << endl; if (rescale) cerr << setw(25) << "rescale " << rescale << endl; cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; @@ -246,7 +246,7 @@ main(int argc, char** argv) #endif cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; if (cfg.count("input_weights")) - cerr << setw(25) << "weights in" << cfg["input_weights"].as<string>() << endl; + cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl; if (cfg.count("stop-after")) cerr << setw(25) << "stop_after " << stop_after << endl; if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl; @@ -279,7 +279,7 @@ main(int argc, char** argv) } else { if (ii == in_sz) next = true; // stop if we reach the end of our input } - // stop after X sentences (but still iterate for those) + // stop after X sentences (but still go on for those) if (stop_after > 0 && stop_after == ii && !next) stop = true; // produce some pretty output @@ -323,14 +323,17 @@ main(int argc, char** argv) register_and_convert(ref_tok, ref_ids); ref_ids_buf.push_back(ref_ids); // process and set grammar - bool broken_grammar = true; + bool broken_grammar = true; // ignore broken grammars for (string::iterator it = in.begin(); it != in.end(); it++) { if (!isspace(*it)) { broken_grammar = false; break; } } - if (broken_grammar) continue; + if (broken_grammar) { + cerr << "Broken grammar for " << ii+1 << "! Ignoring this input." << endl; + continue; + } boost::replace_all(in, "\t", "\n"); in += "\n"; grammar_buf_out << in << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl; @@ -389,7 +392,7 @@ main(int argc, char** argv) } } - score_sum += (*samples)[0].score; + score_sum += (*samples)[0].score; // stats for 1best model_sum += (*samples)[0].model; // weight updates @@ -415,7 +418,7 @@ main(int argc, char** argv) lambdas.plus_eq_v_times_s(diff_vec, eta); rank_errors++; } else { - if (funny) { + if (inc_correct) { SparseVector<weight_t> diff_vec = it->first.f - it->second.f; lambdas.plus_eq_v_times_s(diff_vec, eta); } @@ -453,7 +456,7 @@ main(int argc, char** argv) } } } else if (l1cumul) { - weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input + weight_t acc_penalty = (ii+1) * l1_reg; // Note: ii is the index of the current input for (unsigned d = 0; d < lambdas.size(); d++) { if (lambdas.nonzero(d)) { weight_t v = lambdas.get(d); @@ -515,7 +518,7 @@ main(int argc, char** argv) model_diff = model_avg; } - unsigned nonz; + unsigned nonz = 0; if (!quiet || hstreaming) nonz = (unsigned)lambdas.size_nonzero(); if (!quiet) { @@ -524,18 +527,18 @@ main(int argc, char** argv) cerr << setw(18) << *it << " = " << lambdas.get(FD::Convert(*it)) << endl; } cerr << " ---" << endl; - cerr << _np << " 1best avg score: " << score_avg; + cerr << _np << " 1best avg score: " << score_avg; cerr << _p << " (" << score_diff << ")" << endl; - cerr << _np << "1best avg model score: " << model_avg; + cerr << _np << " 1best avg model score: " << model_avg; cerr << _p << " (" << model_diff << ")" << endl; - cerr << " avg #pairs: "; + cerr << " avg # pairs: "; cerr << _np << npairs/(float)in_sz << endl; - cerr << " avg #rank err: "; + cerr << " avg # rank err: "; cerr << rank_errors/(float)in_sz << endl; - cerr << " avg #margin viol: "; + cerr << " avg # margin viol: "; cerr << margin_violations/(float)in_sz << endl; - cerr << " non0 feature count: " << nonz << endl; - cerr << " avg f count: "; + cerr << " non0 feature count: " << nonz << endl; + cerr << " avg f count: "; cerr << feature_count/(float)pair_count << endl; } @@ -628,7 +631,5 @@ main(int argc, char** argv) cerr << best_it+1 << " [SCORE '" << scorer_str << "'=" << max_score << "]." << endl; cerr << _p2 << "This took " << overall_time/60. << " min." << endl; } - - return 0; } |