Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--  dtrain/dtrain.cc  |  52
1 files changed, 37 insertions, 15 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 4668ad66..2fe7afd7 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -24,6 +24,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
     ("gamma",          po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
     ("tmp",            po::value<string>()->default_value("/tmp"), "temp dir to use")
     ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
+    ("keep_w",         po::value<bool>()->zero_tokens(), "protocol weights for each iteration")
 #ifdef DTRAIN_LOCAL
     ("refs,r",         po::value<string>(), "references in local mode")
 #endif
@@ -92,7 +93,12 @@ main(int argc, char** argv)
   bool hstreaming = false;
   if (cfg.count("hstreaming")) {
     hstreaming = true;
+    quiet = true;
+    cerr.precision(17);
   }
+  bool keep_w = false;
+  if (cfg.count("keep_w")) keep_w = true;
+
   const unsigned k = cfg["k"].as<unsigned>();
   const unsigned N = cfg["N"].as<unsigned>();
   const unsigned T = cfg["epochs"].as<unsigned>();
@@ -104,7 +110,7 @@ main(int argc, char** argv)
   vector<string> print_weights;
   if (cfg.count("print_weights"))
     boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
-  
+
   // setup decoder
   register_feature_functions();
   SetSilent(true);
@@ -151,6 +157,7 @@ main(int argc, char** argv)
   weight_t eta = cfg["learning_rate"].as<weight_t>();
   weight_t gamma = cfg["gamma"].as<weight_t>();
+
   // output
   string output_fn = cfg["output"].as<string>();
   // input
   string input_fn = cfg["input"].as<string>();
@@ -158,9 +165,9 @@ main(int argc, char** argv)
   // buffer input for t > 0
   vector<string> src_str_buf;          // source strings (decoder takes only strings)
   vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
-  vector<string> weights_files;        // remember weights for each iteration
   // where temp files go
   string tmp_path = cfg["tmp"].as<string>();
+  vector<string> w_tmp_files; // used for protocol_w
 #ifdef DTRAIN_LOCAL
   string refs_fn = cfg["refs"].as<string>();
   ReadFile refs(refs_fn);
@@ -169,7 +176,7 @@ main(int argc, char** argv)
   ogzstream grammar_buf_out;
   grammar_buf_out.open(grammar_buf_fn.c_str());
 #endif
-  
+
   unsigned in_sz = UINT_MAX; // input index, input size
   vector<pair<score_t, score_t> > all_scores;
   score_t max_score = 0.;
@@ -206,6 +213,8 @@ main(int argc, char** argv)
   for (unsigned t = 0; t < T; t++) // T epochs
   {
 
+  if (hstreaming) cerr << "reporter:status:Iteration #" << t+1 << " of " << T << endl;
+
   time_t start, end;
   time(&start);
 #ifndef DTRAIN_LOCAL
@@ -231,7 +240,7 @@ main(int argc, char** argv)
     if (stop_after > 0 && stop_after == ii && !next) stop = true;
 
     // produce some pretty output
-    if (!hstreaming && !quiet && !verbose) {
+    if (!quiet && !verbose) {
       if (ii == 0) cerr << " ";
       if ((ii+1) % (DTRAIN_DOTS) == 0) {
         cerr << ".";
@@ -375,10 +384,12 @@ main(int argc, char** argv)
 
     ++ii;
 
-    if (hstreaming) cerr << "reporter:counter:dtrain,sid," << ii << endl;
+    if (hstreaming) cerr << "reporter:counter:dtrain,count,1" << endl;
 
   } // input loop
 
+  if (hstreaming && t == 0) cerr << "reporter:counter:dtrain,|input|," << ii+1 << endl;
+
   if (scorer_str == "approx_bleu") scorer->Reset();
 
   if (t == 0) {
@@ -404,6 +415,11 @@ main(int argc, char** argv)
     score_diff = score_avg;
     model_diff = model_avg;
   }
+  if (hstreaming) {
+    cerr << "reporter:counter:dtrain,score avg it " << t+1 << "," << score_avg << endl;
+    cerr << "reporter:counter:dtrain,model avg it " << t+1 << "," << model_avg << endl;
+  }
+
   if (!quiet) {
     cerr << _p5 << _p << "WEIGHTS" << endl;
     for (vector<string>::iterator it = print_weights.begin(); it != print_weights.end(); it++) {
@@ -439,12 +455,10 @@ main(int argc, char** argv)
   if (noup) break;
 
   // write weights to file
-  if (select_weights == "best") {
-    string infix = "dtrain-weights-" + boost::lexical_cast<string>(t);
+  if (select_weights == "best" || keep_w) {
     lambdas.init_vector(&dense_weights);
-    string w_fn = gettmpf(tmp_path, infix, "gz");
+    string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz";
     Weights::WriteToFile(w_fn, dense_weights, true);
-    weights_files.push_back(w_fn);
   }
 
   } // outer loop
@@ -467,18 +481,19 @@ main(int argc, char** argv)
   } else if (select_weights == "VOID") { // do nothing with the weights
   } else { // best
     if (output_fn != "-") {
-      CopyFile(weights_files[best_it], output_fn); // always gzipped
+      CopyFile("weights."+boost::lexical_cast<string>(best_it)+".gz", output_fn);
     } else {
-      ReadFile bestw(weights_files[best_it]);
+      ReadFile bestw("weights."+boost::lexical_cast<string>(best_it)+".gz");
       string o;
       cout.precision(17);
       cout << _np;
       while(getline(*bestw, o)) cout << o << endl;
     }
-    for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) {
-      unlink(it->c_str());
-      it->erase(it->end()-3, it->end());
-      unlink(it->c_str());
+    if (!keep_w) {
+      for (unsigned i = 0; i < T; i++) {
+        string s = "weights." + boost::lexical_cast<string>(i) + ".gz";
+        unlink(s.c_str());
+      }
     }
   }
   if (output_fn == "-" && hstreaming) cout << "__SHARD_COUNT__\t1" << endl;
@@ -491,6 +506,13 @@ main(int argc, char** argv)
     cerr << _p2 << "This took " << overall_time/60. << " min." << endl;
   }
 
+  if (keep_w) {
+    cout << endl << "Weight files per iteration:" << endl;
+    for (unsigned i = 0; i < w_tmp_files.size(); i++) {
+      cout << w_tmp_files[i] << endl;
+    }
+  }
+
   return 0;
 }
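For context: the hstreaming branches added above talk to Hadoop Streaming's reporter interface, which picks up specially formatted lines from a task's stderr ("reporter:counter:<group>,<counter>,<amount>" updates a counter, "reporter:status:<message>" sets the task status). Below is a minimal standalone C++ sketch of the lines the patched code emits; the iteration number, epoch count and counter values are illustrative only.

// Sketch of the Hadoop Streaming reporter protocol used by the patch:
// counters and status updates are plain lines written to stderr.
#include <iostream>

int main()
{
  std::cerr.precision(17); // as in the patch, so reported values keep full precision

  // status line, emitted once per epoch in the patch
  std::cerr << "reporter:status:Iteration #" << 1 << " of " << 10 << std::endl;
  // increment counter 'count' of group 'dtrain' by 1 (once per processed example)
  std::cerr << "reporter:counter:dtrain,count," << 1 << std::endl;
  // report the input size once, after the first epoch's input loop
  std::cerr << "reporter:counter:dtrain,|input|," << 1000 << std::endl;

  return 0;
}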