diff options
Diffstat (limited to 'training/dtrain')
-rw-r--r-- | training/dtrain/dtrain.cc | 140 |
1 files changed, 70 insertions, 70 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 4f29c499..4c5972a1 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -8,7 +8,7 @@ using namespace dtrain; bool -dtrain_init(int argc, char** argv, po::variables_map* cfg) +dtrain_init(int argc, char** argv, po::variables_map* conf) { po::options_description ini("Configuration File Options"); ini.add_options() @@ -51,49 +51,49 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("quiet,q", po::value<bool>()->zero_tokens(), "be quiet") ("verbose,v", po::value<bool>()->zero_tokens(), "be verbose"); cl.add(ini); - po::store(parse_command_line(argc, argv, cl), *cfg); - if (cfg->count("config")) { - ifstream ini_f((*cfg)["config"].as<string>().c_str()); - po::store(po::parse_config_file(ini_f, ini), *cfg); + po::store(parse_command_line(argc, argv, cl), *conf); + if (conf->count("config")) { + ifstream ini_f((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(ini_f, ini), *conf); } - po::notify(*cfg); - if (!cfg->count("decoder_config")) { + po::notify(*conf); + if (!conf->count("decoder_config")) { cerr << cl << endl; return false; } - if ((*cfg)["sample_from"].as<string>() != "kbest" - && (*cfg)["sample_from"].as<string>() != "forest") { - cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; + if ((*conf)["sample_from"].as<string>() != "kbest" + && (*conf)["sample_from"].as<string>() != "forest") { + cerr << "Wrong 'sample_from' param: '" << (*conf)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; return false; } - if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "uniq" && - (*cfg)["filter"].as<string>() != "not") { - cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'not'." << endl; + if ((*conf)["sample_from"].as<string>() == "kbest" && (*conf)["filter"].as<string>() != "uniq" && + (*conf)["filter"].as<string>() != "not") { + cerr << "Wrong 'filter' param: '" << (*conf)["filter"].as<string>() << "', use 'uniq' or 'not'." << endl; return false; } - if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "XYX" && - (*cfg)["pair_sampling"].as<string>() != "PRO" && (*cfg)["pair_sampling"].as<string>() != "output_pairs") { - cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl; + if ((*conf)["pair_sampling"].as<string>() != "all" && (*conf)["pair_sampling"].as<string>() != "XYX" && + (*conf)["pair_sampling"].as<string>() != "PRO" && (*conf)["pair_sampling"].as<string>() != "output_pairs") { + cerr << "Wrong 'pair_sampling' param: '" << (*conf)["pair_sampling"].as<string>() << "'." << endl; return false; } - if (cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") { + if (conf->count("hi_lo") && (*conf)["pair_sampling"].as<string>() != "XYX") { cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl; } - if ((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) { + if ((*conf)["hi_lo"].as<float>() > 0.5 || (*conf)["hi_lo"].as<float>() < 0.01) { cerr << "hi_lo must lie in [0.01, 0.5]" << endl; return false; } - if (!cfg->count("bitext")) { + if (!conf->count("bitext")) { cerr << "No training data given." << endl; return false; } - if ((*cfg)["pair_threshold"].as<score_t>() < 0) { + if ((*conf)["pair_threshold"].as<score_t>() < 0) { cerr << "The threshold must be >= 0!" << endl; return false; } - if ((*cfg)["select_weights"].as<string>() != "last" && (*cfg)["select_weights"].as<string>() != "best" && - (*cfg)["select_weights"].as<string>() != "avg" && (*cfg)["select_weights"].as<string>() != "VOID") { - cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl; + if ((*conf)["select_weights"].as<string>() != "last" && (*conf)["select_weights"].as<string>() != "best" && + (*conf)["select_weights"].as<string>() != "avg" && (*conf)["select_weights"].as<string>() != "VOID") { + cerr << "Wrong 'select_weights' param: '" << (*conf)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl; return false; } return true; @@ -103,60 +103,60 @@ int main(int argc, char** argv) { // handle most parameters - po::variables_map cfg; - if (!dtrain_init(argc, argv, &cfg)) exit(1); // something is wrong + po::variables_map conf; + if (!dtrain_init(argc, argv, &conf)) exit(1); // something is wrong bool quiet = false; - if (cfg.count("quiet")) quiet = true; + if (conf.count("quiet")) quiet = true; bool verbose = false; - if (cfg.count("verbose")) verbose = true; + if (conf.count("verbose")) verbose = true; bool noup = false; - if (cfg.count("noup")) noup = true; + if (conf.count("noup")) noup = true; bool rescale = false; - if (cfg.count("rescale")) rescale = true; + if (conf.count("rescale")) rescale = true; bool keep = false; - if (cfg.count("keep")) keep = true; + if (conf.count("keep")) keep = true; bool fix_features = false; - if (cfg.count("fix_features")) fix_features = true; - - const unsigned k = cfg["k"].as<unsigned>(); - const unsigned N = cfg["N"].as<unsigned>(); - const unsigned T = cfg["epochs"].as<unsigned>(); - const unsigned stop_after = cfg["stop_after"].as<unsigned>(); - const string filter_type = cfg["filter"].as<string>(); - const string sample_from = cfg["sample_from"].as<string>(); - const string pair_sampling = cfg["pair_sampling"].as<string>(); - const score_t pair_threshold = cfg["pair_threshold"].as<score_t>(); - const string select_weights = cfg["select_weights"].as<string>(); - const string output_ranking = cfg["output_ranking"].as<string>(); - const float hi_lo = cfg["hi_lo"].as<float>(); - const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>(); - const unsigned max_pairs = cfg["max_pairs"].as<unsigned>(); - int repeat = cfg["repeat"].as<unsigned>(); + if (conf.count("fix_features")) fix_features = true; + + const unsigned k = conf["k"].as<unsigned>(); + const unsigned N = conf["N"].as<unsigned>(); + const unsigned T = conf["epochs"].as<unsigned>(); + const unsigned stop_after = conf["stop_after"].as<unsigned>(); + const string filter_type = conf["filter"].as<string>(); + const string sample_from = conf["sample_from"].as<string>(); + const string pair_sampling = conf["pair_sampling"].as<string>(); + const score_t pair_threshold = conf["pair_threshold"].as<score_t>(); + const string select_weights = conf["select_weights"].as<string>(); + const string output_ranking = conf["output_ranking"].as<string>(); + const float hi_lo = conf["hi_lo"].as<float>(); + const score_t approx_bleu_d = conf["approx_bleu_d"].as<score_t>(); + const unsigned max_pairs = conf["max_pairs"].as<unsigned>(); + int repeat = conf["repeat"].as<unsigned>(); bool check = false; - if (cfg.count("check")) check = true; - weight_t loss_margin = cfg["loss_margin"].as<weight_t>(); + if (conf.count("check")) check = true; + weight_t loss_margin = conf["loss_margin"].as<weight_t>(); bool batch = false; - if (cfg.count("batch")) batch = true; + if (conf.count("batch")) batch = true; if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max(); - const string pclr = cfg["pclr"].as<string>(); + const string pclr = conf["pclr"].as<string>(); bool average = false; if (select_weights == "avg") average = true; vector<string> print_weights; - if (cfg.count("print_weights")) - boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); + if (conf.count("print_weights")) + boost::split(print_weights, conf["print_weights"].as<string>(), boost::is_any_of(" ")); // setup decoder register_feature_functions(); SetSilent(true); - ReadFile ini_rf(cfg["decoder_config"].as<string>()); + ReadFile ini_rf(conf["decoder_config"].as<string>()); if (!quiet) - cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; + cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl; Decoder decoder(ini_rf.stream()); // scoring metric/scorer - string scorer_str = cfg["scorer"].as<string>(); + string scorer_str = conf["scorer"].as<string>(); LocalScorer* scorer; if (scorer_str == "bleu") { scorer = static_cast<BleuScorer*>(new BleuScorer); @@ -196,8 +196,8 @@ main(int argc, char** argv) vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); SparseVector<weight_t> lambdas, cumulative_penalties, w_average, fixed; - if (cfg.count("input_weights")) { - Weights::InitFromFile(cfg["input_weights"].as<string>(), &decoder_weights); + if (conf.count("input_weights")) { + Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights); if (fix_features) { Weights::InitSparseVector(decoder_weights, &fixed); SparseVector<weight_t>::iterator it = fixed.begin(); @@ -209,8 +209,8 @@ main(int argc, char** argv) Weights::InitSparseVector(decoder_weights, &lambdas); // meta params for perceptron, SVM - weight_t eta = cfg["learning_rate"].as<weight_t>(); - weight_t gamma = cfg["gamma"].as<weight_t>(); + weight_t eta = conf["learning_rate"].as<weight_t>(); + weight_t gamma = conf["gamma"].as<weight_t>(); // faster perceptron: consider only misranked pairs, see bool faster_perceptron = false; @@ -221,19 +221,19 @@ main(int argc, char** argv) bool l1clip = false; bool l1cumul = false; weight_t l1_reg = 0; - if (cfg["l1_reg"].as<string>() != "none") { - string s = cfg["l1_reg"].as<string>(); + if (conf["l1_reg"].as<string>() != "none") { + string s = conf["l1_reg"].as<string>(); if (s == "naive") l1naive = true; else if (s == "clip") l1clip = true; else if (s == "cumul") l1cumul = true; - l1_reg = cfg["l1_reg_strength"].as<weight_t>(); + l1_reg = conf["l1_reg_strength"].as<weight_t>(); } // output - string output_fn = cfg["output"].as<string>(); + string output_fn = conf["output"].as<string>(); // input string input_fn; - ReadFile input(cfg["bitext"].as<string>()); + ReadFile input(conf["bitext"].as<string>()); // buffer input for t > 0 vector<string> src_str_buf; // source strings (decoder takes only strings) vector<vector<vector<WordID> > > refs_as_ids_buf; // references as WordID vecs @@ -244,7 +244,7 @@ main(int argc, char** argv) unsigned best_it = 0; float overall_time = 0.; - // output cfg + // output conf if (!quiet) { cerr << _p5; cerr << endl << "dtrain" << endl << "Parameters:" << endl; @@ -267,19 +267,19 @@ main(int argc, char** argv) cerr << setw(25) << "hi lo " << hi_lo << endl; cerr << setw(25) << "pair threshold " << pair_threshold << endl; cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; - if (cfg.count("l1_reg")) - cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl; + if (conf.count("l1_reg")) + cerr << setw(25) << "l1 reg " << l1_reg << " '" << conf["l1_reg"].as<string>() << "'" << endl; if (rescale) cerr << setw(25) << "rescale " << rescale << endl; cerr << setw(25) << "pclr " << pclr << endl; cerr << setw(25) << "max pairs " << max_pairs << endl; cerr << setw(25) << "repeat " << repeat << endl; //cerr << setw(25) << "test k-best " << test_k_best << endl; - cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; + cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; - if (cfg.count("input_weights")) - cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl; + if (conf.count("input_weights")) + cerr << setw(25) << "weights in " << "'" << conf["input_weights"].as<string>() << "'" << endl; if (stop_after > 0) cerr << setw(25) << "stop_after " << stop_after << endl; if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl; |