summaryrefslogtreecommitdiff
path: root/dtrain
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain')
-rw-r--r--dtrain/dtrain.cc34
-rw-r--r--dtrain/test/example/dtrain.ini10
2 files changed, 24 insertions, 20 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index f679c9f6..0a94f7aa 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -45,21 +45,25 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "When using 'hstreaming' the 'output' param should be '-'.";
return false;
}
- if ((*cfg)["filter"].as<string>() != "unique"
+ if ((*cfg)["sample_from"].as<string>() != "kbest"
+ && (*cfg)["sample_from"].as<string>() != "forest") {
+ cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
+ return false;
+ }
+ if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "unique"
&& (*cfg)["filter"].as<string>() != "no") {
cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'unique' or 'no'." << endl;
+ return false;
}
if ((*cfg)["pair_sampling"].as<string>() != "all"
- && (*cfg)["pair_sampling"].as<string>() != "rand") {
+ && (*cfg)["pair_sampling"].as<string>() != "rand" && (*cfg)["pair_sampling"].as<string>() != "108010") {
cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "', use 'all' or 'rand'." << endl;
- }
- if ((*cfg)["sample_from"].as<string>() != "kbest"
- && (*cfg)["sample_from"].as<string>() != "forest") {
- cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
+ return false;
}
if ((*cfg)["select_weights"].as<string>() != "last"
- && (*cfg)["select_weights"].as<string>() != "best") {
+ && (*cfg)["select_weights"].as<string>() != "best" && (*cfg)["select_weights"].as<string>() != "VOID") {
cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
+ return false;
}
return true;
}
@@ -410,27 +414,26 @@ main(int argc, char** argv)
unlink(grammar_buf_fn.c_str());
if (!noup) {
- if (!quiet) cerr << endl << "writing weights file to '" << output_fn << "' ..." << endl;
+ if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl;
if (select_weights == "last") { // last
- WriteFile out(output_fn);
- ostream& o = *out.stream();
+ WriteFile of(output_fn); // works with '-'
+ ostream& o = *of.stream();
o.precision(17);
o << _np;
for (SparseVector<double>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
if (it->second == 0) continue;
o << FD::Convert(it->first) << '\t' << it->second << endl;
}
- if (hstreaming) cout << "__SHARD_COUNT__\t1" << endl;
- } else if (select_weights == "VOID") { // do nothing
+ } else if (select_weights == "VOID") { // do nothing with the weights
} else { // best
if (output_fn != "-") {
- CopyFile(weights_files[best_it], output_fn);
+ CopyFile(weights_files[best_it], output_fn); // always gzipped
} else {
- ReadFile(weights_files[best_it]);
+ ReadFile bestw(weights_files[best_it]);
string o;
cout.precision(17);
cout << _np;
- while(getline(*input, o)) cout << o << endl;
+ while(getline(*bestw, o)) cout << o << endl;
}
for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) {
unlink(it->c_str());
@@ -438,6 +441,7 @@ main(int argc, char** argv)
unlink(it->c_str());
}
}
+ if (output_fn == "-" && hstreaming) cout << "__SHARD_COUNT__\t1" << endl;
if (!quiet) cerr << "done" << endl;
}
diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini
index 1e841824..9b83193a 100644
--- a/dtrain/test/example/dtrain.ini
+++ b/dtrain/test/example/dtrain.ini
@@ -1,14 +1,14 @@
decoder_config=test/example/cdec.ini
k=100
N=3
-gamma=0
-epochs=4
+gamma=0.00001
+epochs=2
input=test/example/nc-1k-tabs.gz
scorer=stupid_bleu
output=-
-stop_after=100
+stop_after=10
sample_from=kbest
-pair_sampling=all
-select_weights=VOID
+pair_sampling=108010
+select_weights=best
print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough
tmp=/tmp