summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.h')
-rw-r--r--training/dtrain/dtrain.h47
1 files changed, 30 insertions, 17 deletions
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index 2636fa89..1d07defa 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -43,44 +43,57 @@ inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); }
inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); }
inline ostream& _p4(ostream& out) { return out << setprecision(4); }
-void
+bool
dtrain_init(int argc, char** argv, po::variables_map* conf)
{
- po::options_description ini("Configuration File Options");
- ini.add_options()
+ po::options_description opts("Configuration File Options");
+ opts.add_options()
("bitext,b", po::value<string>(), "bitext")
("decoder_conf,C", po::value<string>(), "configuration file for decoder")
- ("iterations,T", po::value<size_t>()->default_value(10), "number of iterations T (per shard)")
+ ("iterations,T", po::value<size_t>()->default_value(15), "number of iterations T (per shard)")
("k", po::value<size_t>()->default_value(100), "size of kbest list")
- ("learning_rate,l", po::value<weight_t>()->default_value(1.0), "learning rate")
+ ("learning_rate,l", po::value<weight_t>()->default_value(0.00001), "learning rate")
("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength")
- ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron")
- ("score,s", po::value<string>()->default_value("nakov"), "per-sentence BLEU approx.")
+ ("margin,m", po::value<weight_t>()->default_value(1.0), "margin for margin perceptron")
+ ("score,s", po::value<string>()->default_value("chiang"), "per-sentence BLEU approx.")
("N", po::value<size_t>()->default_value(4), "N for BLEU approximation")
("input_weights,w", po::value<string>(), "input weights file")
- ("average,a", po::value<bool>()->default_value(false), "output average weights")
- ("keep,K", po::value<bool>()->default_value(false), "output a weight file per iteration")
+ ("average,a", po::bool_switch()->default_value(true), "output average weights")
+ ("keep,K", po::bool_switch()->default_value(false), "output a weight file per iteration")
("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"),
"list of weights to print after each iteration");
- po::options_description cl("Command Line Options");
- cl.add_options()
- ("conf,c", po::value<string>(), "dtrain configuration file");
- cl.add(ini);
- po::store(parse_command_line(argc, argv, cl), *conf);
+ po::options_description clopts("Command Line Options");
+ clopts.add_options()
+ ("conf,c", po::value<string>(), "dtrain configuration file")
+ ("help,h", po::bool_switch()->default_value(false), "display options");
+ opts.add(clopts);
+ po::store(parse_command_line(argc, argv, opts), *conf);
+ cerr << "dtrain" << endl << endl;
+ if (conf->count("help")) {
+ cerr << opts << endl;
+
+ return false;
+ }
if (conf->count("conf")) {
ifstream f((*conf)["conf"].as<string>().c_str());
- po::store(po::parse_config_file(f, ini), *conf);
+ po::store(po::parse_config_file(f, opts), *conf);
}
po::notify(*conf);
if (!conf->count("decoder_conf")) {
cerr << "Missing decoder configuration." << endl;
- assert(false);
+ cerr << opts << endl;
+
+ return false;
}
if (!conf->count("bitext")) {
cerr << "No input given." << endl;
- assert(false);
+ cerr << opts << endl;
+
+ return false;
}
+
+ return true;
}
} // namespace