summaryrefslogtreecommitdiff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-11-10 23:07:57 +0100
committerPatrick Simianer <p@simianer.de>2011-11-10 23:07:57 +0100
commitb7e58c8f9c96417d2530be21bd00662b343d6bcd (patch)
treecbcb890e9f3e76bd8c602af279db82b159fbf0f1 /dtrain/dtrain.cc
parent27498c35e5be9da1e05f48b3b67a425301bf9fd4 (diff)
some more reporting in hstreaming, keep weights option
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--dtrain/dtrain.cc52
1 files changed, 37 insertions, 15 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 4668ad66..2fe7afd7 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -24,6 +24,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
+ ("keep_w", po::value<bool>()->zero_tokens(), "protocol weights for each iteration")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -92,7 +93,12 @@ main(int argc, char** argv)
bool hstreaming = false;
if (cfg.count("hstreaming")) {
hstreaming = true;
+ quiet = true;
+ cerr.precision(17);
}
+ bool keep_w = false;
+ if (cfg.count("keep_w")) keep_w = true;
+
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
const unsigned T = cfg["epochs"].as<unsigned>();
@@ -104,7 +110,7 @@ main(int argc, char** argv)
vector<string> print_weights;
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
-
+
// setup decoder
register_feature_functions();
SetSilent(true);
@@ -151,6 +157,7 @@ main(int argc, char** argv)
weight_t eta = cfg["learning_rate"].as<weight_t>();
weight_t gamma = cfg["gamma"].as<weight_t>();
+ // output
string output_fn = cfg["output"].as<string>();
// input
string input_fn = cfg["input"].as<string>();
@@ -158,9 +165,9 @@ main(int argc, char** argv)
// buffer input for t > 0
vector<string> src_str_buf; // source strings (decoder takes only strings)
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
- vector<string> weights_files; // remember weights for each iteration
// where temp files go
string tmp_path = cfg["tmp"].as<string>();
+ vector<string> w_tmp_files; // used for protocol_w
#ifdef DTRAIN_LOCAL
string refs_fn = cfg["refs"].as<string>();
ReadFile refs(refs_fn);
@@ -169,7 +176,7 @@ main(int argc, char** argv)
ogzstream grammar_buf_out;
grammar_buf_out.open(grammar_buf_fn.c_str());
#endif
-
+
unsigned in_sz = UINT_MAX; // input index, input size
vector<pair<score_t, score_t> > all_scores;
score_t max_score = 0.;
@@ -206,6 +213,8 @@ main(int argc, char** argv)
for (unsigned t = 0; t < T; t++) // T epochs
{
+ if (hstreaming) cerr << "reporter:status:Iteration #" << t+1 << " of " << T << endl;
+
time_t start, end;
time(&start);
#ifndef DTRAIN_LOCAL
@@ -231,7 +240,7 @@ main(int argc, char** argv)
if (stop_after > 0 && stop_after == ii && !next) stop = true;
// produce some pretty output
- if (!hstreaming && !quiet && !verbose) {
+ if (!quiet && !verbose) {
if (ii == 0) cerr << " ";
if ((ii+1) % (DTRAIN_DOTS) == 0) {
cerr << ".";
@@ -375,10 +384,12 @@ main(int argc, char** argv)
++ii;
- if (hstreaming) cerr << "reporter:counter:dtrain,sid," << ii << endl;
+ if (hstreaming) cerr << "reporter:counter:dtrain,count,1" << endl;
} // input loop
+ if (hstreaming && t == 0) cerr << "reporter:counter:dtrain,|input|," << ii+1 << endl;
+
if (scorer_str == "approx_bleu") scorer->Reset();
if (t == 0) {
@@ -404,6 +415,11 @@ main(int argc, char** argv)
score_diff = score_avg;
model_diff = model_avg;
}
+ if (hstreaming) {
+ cerr << "reporter:counter:dtrain,score avg it " << t+1 << "," << score_avg << endl;
+ cerr << "reporter:counter:dtrain,model avg it " << t+1 << "," << model_avg << endl;
+ }
+
if (!quiet) {
cerr << _p5 << _p << "WEIGHTS" << endl;
for (vector<string>::iterator it = print_weights.begin(); it != print_weights.end(); it++) {
@@ -439,12 +455,10 @@ main(int argc, char** argv)
if (noup) break;
// write weights to file
- if (select_weights == "best") {
- string infix = "dtrain-weights-" + boost::lexical_cast<string>(t);
+ if (select_weights == "best" || keep_w) {
lambdas.init_vector(&dense_weights);
- string w_fn = gettmpf(tmp_path, infix, "gz");
+ string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz";
Weights::WriteToFile(w_fn, dense_weights, true);
- weights_files.push_back(w_fn);
}
} // outer loop
@@ -467,18 +481,19 @@ main(int argc, char** argv)
} else if (select_weights == "VOID") { // do nothing with the weights
} else { // best
if (output_fn != "-") {
- CopyFile(weights_files[best_it], output_fn); // always gzipped
+ CopyFile("weights."+boost::lexical_cast<string>(best_it)+".gz", output_fn);
} else {
- ReadFile bestw(weights_files[best_it]);
+ ReadFile bestw("weights."+boost::lexical_cast<string>(best_it)+".gz");
string o;
cout.precision(17);
cout << _np;
while(getline(*bestw, o)) cout << o << endl;
}
- for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) {
- unlink(it->c_str());
- it->erase(it->end()-3, it->end());
- unlink(it->c_str());
+ if (!keep_w) {
+ for (unsigned i = 0; i < T; i++) {
+ string s = "weights." + boost::lexical_cast<string>(i) + ".gz";
+ unlink(s.c_str());
+ }
}
}
if (output_fn == "-" && hstreaming) cout << "__SHARD_COUNT__\t1" << endl;
@@ -491,6 +506,13 @@ main(int argc, char** argv)
cerr << _p2 << "This took " << overall_time/60. << " min." << endl;
}
+ if (keep_w) {
+ cout << endl << "Weight files per iteration:" << endl;
+ for (unsigned i = 0; i < w_tmp_files.size(); i++) {
+ cout << w_tmp_files[i] << endl;
+ }
+ }
+
return 0;
}