diff options
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index f679c9f6..0a94f7aa 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -45,21 +45,25 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "When using 'hstreaming' the 'output' param should be '-'."; return false; } - if ((*cfg)["filter"].as<string>() != "unique" + if ((*cfg)["sample_from"].as<string>() != "kbest" + && (*cfg)["sample_from"].as<string>() != "forest") { + cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; + return false; + } + if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "unique" && (*cfg)["filter"].as<string>() != "no") { cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'unique' or 'no'." << endl; + return false; } if ((*cfg)["pair_sampling"].as<string>() != "all" - && (*cfg)["pair_sampling"].as<string>() != "rand") { + && (*cfg)["pair_sampling"].as<string>() != "rand" && (*cfg)["pair_sampling"].as<string>() != "108010") { cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "', use 'all' or 'rand'." << endl; - } - if ((*cfg)["sample_from"].as<string>() != "kbest" - && (*cfg)["sample_from"].as<string>() != "forest") { - cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; + return false; } if ((*cfg)["select_weights"].as<string>() != "last" - && (*cfg)["select_weights"].as<string>() != "best") { + && (*cfg)["select_weights"].as<string>() != "best" && (*cfg)["select_weights"].as<string>() != "VOID") { cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl; + return false; } return true; } @@ -410,27 +414,26 @@ main(int argc, char** argv) unlink(grammar_buf_fn.c_str()); if (!noup) { - if (!quiet) cerr << endl << "writing weights file to '" << output_fn << "' ..." << endl; + if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl; if (select_weights == "last") { // last - WriteFile out(output_fn); - ostream& o = *out.stream(); + WriteFile of(output_fn); // works with '-' + ostream& o = *of.stream(); o.precision(17); o << _np; for (SparseVector<double>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { if (it->second == 0) continue; o << FD::Convert(it->first) << '\t' << it->second << endl; } - if (hstreaming) cout << "__SHARD_COUNT__\t1" << endl; - } else if (select_weights == "VOID") { // do nothing + } else if (select_weights == "VOID") { // do nothing with the weights } else { // best if (output_fn != "-") { - CopyFile(weights_files[best_it], output_fn); + CopyFile(weights_files[best_it], output_fn); // always gzipped } else { - ReadFile(weights_files[best_it]); + ReadFile bestw(weights_files[best_it]); string o; cout.precision(17); cout << _np; - while(getline(*input, o)) cout << o << endl; + while(getline(*bestw, o)) cout << o << endl; } for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) { unlink(it->c_str()); @@ -438,6 +441,7 @@ main(int argc, char** argv) unlink(it->c_str()); } } + if (output_fn == "-" && hstreaming) cout << "__SHARD_COUNT__\t1" << endl; if (!quiet) cerr << "done" << endl; } |