summary | refs | log | tree | commit | diff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- dtrain/dtrain.cc 34
1 files changed, 19 insertions, 15 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index f679c9f6..0a94f7aa 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -45,21 +45,25 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "When using 'hstreaming' the 'output' param should be '-'.";
return false;
}
- if ((*cfg)["filter"].as<string>() != "unique"
+ if ((*cfg)["sample_from"].as<string>() != "kbest"
+ && (*cfg)["sample_from"].as<string>() != "forest") {
+ cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
+ return false;
+ }
+ if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "unique"
&& (*cfg)["filter"].as<string>() != "no") {
cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'unique' or 'no'." << endl;
+ return false;
}
if ((*cfg)["pair_sampling"].as<string>() != "all"
- && (*cfg)["pair_sampling"].as<string>() != "rand") {
+ && (*cfg)["pair_sampling"].as<string>() != "rand" && (*cfg)["pair_sampling"].as<string>() != "108010") {
cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "', use 'all' or 'rand'." << endl;
- }
- if ((*cfg)["sample_from"].as<string>() != "kbest"
- && (*cfg)["sample_from"].as<string>() != "forest") {
- cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
+ return false;
}
if ((*cfg)["select_weights"].as<string>() != "last"
- && (*cfg)["select_weights"].as<string>() != "best") {
+ && (*cfg)["select_weights"].as<string>() != "best" && (*cfg)["select_weights"].as<string>() != "VOID") {
cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
+ return false;
}
return true;
}
@@ -410,27 +414,26 @@ main(int argc, char** argv)
unlink(grammar_buf_fn.c_str());
if (!noup) {
- if (!quiet) cerr << endl << "writing weights file to '" << output_fn << "' ..." << endl;
+ if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl;
if (select_weights == "last") { // last
- WriteFile out(output_fn);
- ostream& o = *out.stream();
+ WriteFile of(output_fn); // works with '-'
+ ostream& o = *of.stream();
o.precision(17);
o << _np;
for (SparseVector<double>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
if (it->second == 0) continue;
o << FD::Convert(it->first) << '\t' << it->second << endl;
}
- if (hstreaming) cout << "__SHARD_COUNT__\t1" << endl;
- } else if (select_weights == "VOID") { // do nothing
+ } else if (select_weights == "VOID") { // do nothing with the weights
} else { // best
if (output_fn != "-") {
- CopyFile(weights_files[best_it], output_fn);
+ CopyFile(weights_files[best_it], output_fn); // always gzipped
} else {
- ReadFile(weights_files[best_it]);
+ ReadFile bestw(weights_files[best_it]);
string o;
cout.precision(17);
cout << _np;
- while(getline(*input, o)) cout << o << endl;
+ while(getline(*bestw, o)) cout << o << endl;
}
for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) {
unlink(it->c_str());
@@ -438,6 +441,7 @@ main(int argc, char** argv)
unlink(it->c_str());
}
}
+ if (output_fn == "-" && hstreaming) cout << "__SHARD_COUNT__\t1" << endl;
if (!quiet) cerr << "done" << endl;
}