diff options
-rw-r--r-- | dtrain/dtrain.cc | 34 | ||||
-rw-r--r-- | dtrain/test/example/dtrain.ini | 10 |
2 files changed, 24 insertions, 20 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index f679c9f6..0a94f7aa 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -45,21 +45,25 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "When using 'hstreaming' the 'output' param should be '-'."; return false; } - if ((*cfg)["filter"].as<string>() != "unique" + if ((*cfg)["sample_from"].as<string>() != "kbest" + && (*cfg)["sample_from"].as<string>() != "forest") { + cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; + return false; + } + if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "unique" && (*cfg)["filter"].as<string>() != "no") { cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'unique' or 'no'." << endl; + return false; } if ((*cfg)["pair_sampling"].as<string>() != "all" - && (*cfg)["pair_sampling"].as<string>() != "rand") { + && (*cfg)["pair_sampling"].as<string>() != "rand" && (*cfg)["pair_sampling"].as<string>() != "108010") { cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "', use 'all' or 'rand'." << endl; - } - if ((*cfg)["sample_from"].as<string>() != "kbest" - && (*cfg)["sample_from"].as<string>() != "forest") { - cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; + return false; } if ((*cfg)["select_weights"].as<string>() != "last" - && (*cfg)["select_weights"].as<string>() != "best") { + && (*cfg)["select_weights"].as<string>() != "best" && (*cfg)["select_weights"].as<string>() != "VOID") { cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl; + return false; } return true; } @@ -410,27 +414,26 @@ main(int argc, char** argv) unlink(grammar_buf_fn.c_str()); if (!noup) { - if (!quiet) cerr << endl << "writing weights file to '" << output_fn << "' ..." << endl; + if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl; if (select_weights == "last") { // last - WriteFile out(output_fn); - ostream& o = *out.stream(); + WriteFile of(output_fn); // works with '-' + ostream& o = *of.stream(); o.precision(17); o << _np; for (SparseVector<double>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { if (it->second == 0) continue; o << FD::Convert(it->first) << '\t' << it->second << endl; } - if (hstreaming) cout << "__SHARD_COUNT__\t1" << endl; - } else if (select_weights == "VOID") { // do nothing + } else if (select_weights == "VOID") { // do nothing with the weights } else { // best if (output_fn != "-") { - CopyFile(weights_files[best_it], output_fn); + CopyFile(weights_files[best_it], output_fn); // always gzipped } else { - ReadFile(weights_files[best_it]); + ReadFile bestw(weights_files[best_it]); string o; cout.precision(17); cout << _np; - while(getline(*input, o)) cout << o << endl; + while(getline(*bestw, o)) cout << o << endl; } for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) { unlink(it->c_str()); @@ -438,6 +441,7 @@ main(int argc, char** argv) unlink(it->c_str()); } } + if (output_fn == "-" && hstreaming) cout << "__SHARD_COUNT__\t1" << endl; if (!quiet) cerr << "done" << endl; } diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini index 1e841824..9b83193a 100644 --- a/dtrain/test/example/dtrain.ini +++ b/dtrain/test/example/dtrain.ini @@ -1,14 +1,14 @@ decoder_config=test/example/cdec.ini k=100 N=3 -gamma=0 -epochs=4 +gamma=0.00001 +epochs=2 input=test/example/nc-1k-tabs.gz scorer=stupid_bleu output=- -stop_after=100 +stop_after=10 sample_from=kbest -pair_sampling=all -select_weights=VOID +pair_sampling=108010 +select_weights=best print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough tmp=/tmp |