summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc140
1 files changed, 70 insertions, 70 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 4f29c499..4c5972a1 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -8,7 +8,7 @@ using namespace dtrain;
bool
-dtrain_init(int argc, char** argv, po::variables_map* cfg)
+dtrain_init(int argc, char** argv, po::variables_map* conf)
{
po::options_description ini("Configuration File Options");
ini.add_options()
@@ -51,49 +51,49 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("quiet,q", po::value<bool>()->zero_tokens(), "be quiet")
("verbose,v", po::value<bool>()->zero_tokens(), "be verbose");
cl.add(ini);
- po::store(parse_command_line(argc, argv, cl), *cfg);
- if (cfg->count("config")) {
- ifstream ini_f((*cfg)["config"].as<string>().c_str());
- po::store(po::parse_config_file(ini_f, ini), *cfg);
+ po::store(parse_command_line(argc, argv, cl), *conf);
+ if (conf->count("config")) {
+ ifstream ini_f((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(ini_f, ini), *conf);
}
- po::notify(*cfg);
- if (!cfg->count("decoder_config")) {
+ po::notify(*conf);
+ if (!conf->count("decoder_config")) {
cerr << cl << endl;
return false;
}
- if ((*cfg)["sample_from"].as<string>() != "kbest"
- && (*cfg)["sample_from"].as<string>() != "forest") {
- cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
+ if ((*conf)["sample_from"].as<string>() != "kbest"
+ && (*conf)["sample_from"].as<string>() != "forest") {
+ cerr << "Wrong 'sample_from' param: '" << (*conf)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
return false;
}
- if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "uniq" &&
- (*cfg)["filter"].as<string>() != "not") {
- cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'not'." << endl;
+ if ((*conf)["sample_from"].as<string>() == "kbest" && (*conf)["filter"].as<string>() != "uniq" &&
+ (*conf)["filter"].as<string>() != "not") {
+ cerr << "Wrong 'filter' param: '" << (*conf)["filter"].as<string>() << "', use 'uniq' or 'not'." << endl;
return false;
}
- if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "XYX" &&
- (*cfg)["pair_sampling"].as<string>() != "PRO" && (*cfg)["pair_sampling"].as<string>() != "output_pairs") {
- cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
+ if ((*conf)["pair_sampling"].as<string>() != "all" && (*conf)["pair_sampling"].as<string>() != "XYX" &&
+ (*conf)["pair_sampling"].as<string>() != "PRO" && (*conf)["pair_sampling"].as<string>() != "output_pairs") {
+ cerr << "Wrong 'pair_sampling' param: '" << (*conf)["pair_sampling"].as<string>() << "'." << endl;
return false;
}
- if (cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") {
+ if (conf->count("hi_lo") && (*conf)["pair_sampling"].as<string>() != "XYX") {
cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl;
}
- if ((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) {
+ if ((*conf)["hi_lo"].as<float>() > 0.5 || (*conf)["hi_lo"].as<float>() < 0.01) {
cerr << "hi_lo must lie in [0.01, 0.5]" << endl;
return false;
}
- if (!cfg->count("bitext")) {
+ if (!conf->count("bitext")) {
cerr << "No training data given." << endl;
return false;
}
- if ((*cfg)["pair_threshold"].as<score_t>() < 0) {
+ if ((*conf)["pair_threshold"].as<score_t>() < 0) {
cerr << "The threshold must be >= 0!" << endl;
return false;
}
- if ((*cfg)["select_weights"].as<string>() != "last" && (*cfg)["select_weights"].as<string>() != "best" &&
- (*cfg)["select_weights"].as<string>() != "avg" && (*cfg)["select_weights"].as<string>() != "VOID") {
- cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
+ if ((*conf)["select_weights"].as<string>() != "last" && (*conf)["select_weights"].as<string>() != "best" &&
+ (*conf)["select_weights"].as<string>() != "avg" && (*conf)["select_weights"].as<string>() != "VOID") {
+ cerr << "Wrong 'select_weights' param: '" << (*conf)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
return false;
}
return true;
@@ -103,60 +103,60 @@ int
main(int argc, char** argv)
{
// handle most parameters
- po::variables_map cfg;
- if (!dtrain_init(argc, argv, &cfg)) exit(1); // something is wrong
+ po::variables_map conf;
+ if (!dtrain_init(argc, argv, &conf)) exit(1); // something is wrong
bool quiet = false;
- if (cfg.count("quiet")) quiet = true;
+ if (conf.count("quiet")) quiet = true;
bool verbose = false;
- if (cfg.count("verbose")) verbose = true;
+ if (conf.count("verbose")) verbose = true;
bool noup = false;
- if (cfg.count("noup")) noup = true;
+ if (conf.count("noup")) noup = true;
bool rescale = false;
- if (cfg.count("rescale")) rescale = true;
+ if (conf.count("rescale")) rescale = true;
bool keep = false;
- if (cfg.count("keep")) keep = true;
+ if (conf.count("keep")) keep = true;
bool fix_features = false;
- if (cfg.count("fix_features")) fix_features = true;
-
- const unsigned k = cfg["k"].as<unsigned>();
- const unsigned N = cfg["N"].as<unsigned>();
- const unsigned T = cfg["epochs"].as<unsigned>();
- const unsigned stop_after = cfg["stop_after"].as<unsigned>();
- const string filter_type = cfg["filter"].as<string>();
- const string sample_from = cfg["sample_from"].as<string>();
- const string pair_sampling = cfg["pair_sampling"].as<string>();
- const score_t pair_threshold = cfg["pair_threshold"].as<score_t>();
- const string select_weights = cfg["select_weights"].as<string>();
- const string output_ranking = cfg["output_ranking"].as<string>();
- const float hi_lo = cfg["hi_lo"].as<float>();
- const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>();
- const unsigned max_pairs = cfg["max_pairs"].as<unsigned>();
- int repeat = cfg["repeat"].as<unsigned>();
+ if (conf.count("fix_features")) fix_features = true;
+
+ const unsigned k = conf["k"].as<unsigned>();
+ const unsigned N = conf["N"].as<unsigned>();
+ const unsigned T = conf["epochs"].as<unsigned>();
+ const unsigned stop_after = conf["stop_after"].as<unsigned>();
+ const string filter_type = conf["filter"].as<string>();
+ const string sample_from = conf["sample_from"].as<string>();
+ const string pair_sampling = conf["pair_sampling"].as<string>();
+ const score_t pair_threshold = conf["pair_threshold"].as<score_t>();
+ const string select_weights = conf["select_weights"].as<string>();
+ const string output_ranking = conf["output_ranking"].as<string>();
+ const float hi_lo = conf["hi_lo"].as<float>();
+ const score_t approx_bleu_d = conf["approx_bleu_d"].as<score_t>();
+ const unsigned max_pairs = conf["max_pairs"].as<unsigned>();
+ int repeat = conf["repeat"].as<unsigned>();
bool check = false;
- if (cfg.count("check")) check = true;
- weight_t loss_margin = cfg["loss_margin"].as<weight_t>();
+ if (conf.count("check")) check = true;
+ weight_t loss_margin = conf["loss_margin"].as<weight_t>();
bool batch = false;
- if (cfg.count("batch")) batch = true;
+ if (conf.count("batch")) batch = true;
if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max();
- const string pclr = cfg["pclr"].as<string>();
+ const string pclr = conf["pclr"].as<string>();
bool average = false;
if (select_weights == "avg")
average = true;
vector<string> print_weights;
- if (cfg.count("print_weights"))
- boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
+ if (conf.count("print_weights"))
+ boost::split(print_weights, conf["print_weights"].as<string>(), boost::is_any_of(" "));
// setup decoder
register_feature_functions();
SetSilent(true);
- ReadFile ini_rf(cfg["decoder_config"].as<string>());
+ ReadFile ini_rf(conf["decoder_config"].as<string>());
if (!quiet)
- cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
+ cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl;
Decoder decoder(ini_rf.stream());
// scoring metric/scorer
- string scorer_str = cfg["scorer"].as<string>();
+ string scorer_str = conf["scorer"].as<string>();
LocalScorer* scorer;
if (scorer_str == "bleu") {
scorer = static_cast<BleuScorer*>(new BleuScorer);
@@ -196,8 +196,8 @@ main(int argc, char** argv)
vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
SparseVector<weight_t> lambdas, cumulative_penalties, w_average, fixed;
- if (cfg.count("input_weights")) {
- Weights::InitFromFile(cfg["input_weights"].as<string>(), &decoder_weights);
+ if (conf.count("input_weights")) {
+ Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights);
if (fix_features) {
Weights::InitSparseVector(decoder_weights, &fixed);
SparseVector<weight_t>::iterator it = fixed.begin();
@@ -209,8 +209,8 @@ main(int argc, char** argv)
Weights::InitSparseVector(decoder_weights, &lambdas);
// meta params for perceptron, SVM
- weight_t eta = cfg["learning_rate"].as<weight_t>();
- weight_t gamma = cfg["gamma"].as<weight_t>();
+ weight_t eta = conf["learning_rate"].as<weight_t>();
+ weight_t gamma = conf["gamma"].as<weight_t>();
// faster perceptron: consider only misranked pairs, see
bool faster_perceptron = false;
@@ -221,19 +221,19 @@ main(int argc, char** argv)
bool l1clip = false;
bool l1cumul = false;
weight_t l1_reg = 0;
- if (cfg["l1_reg"].as<string>() != "none") {
- string s = cfg["l1_reg"].as<string>();
+ if (conf["l1_reg"].as<string>() != "none") {
+ string s = conf["l1_reg"].as<string>();
if (s == "naive") l1naive = true;
else if (s == "clip") l1clip = true;
else if (s == "cumul") l1cumul = true;
- l1_reg = cfg["l1_reg_strength"].as<weight_t>();
+ l1_reg = conf["l1_reg_strength"].as<weight_t>();
}
// output
- string output_fn = cfg["output"].as<string>();
+ string output_fn = conf["output"].as<string>();
// input
string input_fn;
- ReadFile input(cfg["bitext"].as<string>());
+ ReadFile input(conf["bitext"].as<string>());
// buffer input for t > 0
vector<string> src_str_buf; // source strings (decoder takes only strings)
vector<vector<vector<WordID> > > refs_as_ids_buf; // references as WordID vecs
@@ -244,7 +244,7 @@ main(int argc, char** argv)
unsigned best_it = 0;
float overall_time = 0.;
- // output cfg
+ // output conf
if (!quiet) {
cerr << _p5;
cerr << endl << "dtrain" << endl << "Parameters:" << endl;
@@ -267,19 +267,19 @@ main(int argc, char** argv)
cerr << setw(25) << "hi lo " << hi_lo << endl;
cerr << setw(25) << "pair threshold " << pair_threshold << endl;
cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
- if (cfg.count("l1_reg"))
- cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
+ if (conf.count("l1_reg"))
+ cerr << setw(25) << "l1 reg " << l1_reg << " '" << conf["l1_reg"].as<string>() << "'" << endl;
if (rescale)
cerr << setw(25) << "rescale " << rescale << endl;
cerr << setw(25) << "pclr " << pclr << endl;
cerr << setw(25) << "max pairs " << max_pairs << endl;
cerr << setw(25) << "repeat " << repeat << endl;
//cerr << setw(25) << "test k-best " << test_k_best << endl;
- cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
+ cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
- if (cfg.count("input_weights"))
- cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl;
+ if (conf.count("input_weights"))
+ cerr << setw(25) << "weights in " << "'" << conf["input_weights"].as<string>() << "'" << endl;
if (stop_after > 0)
cerr << setw(25) << "stop_after " << stop_after << endl;
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl;