diff options
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 61 |
1 files changed, 31 insertions, 30 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 0481cf96..44090242 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -6,23 +6,24 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) { po::options_description ini("Configuration File Options"); ini.add_options() - ("input", po::value<string>()->default_value("-"), "input file") - ("output", po::value<string>()->default_value("-"), "output weights file (or VOID)") - ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") - ("decoder_config", po::value<string>(), "configuration file for cdec") - ("k", po::value<size_t>()->default_value(100), "size of kbest or sample from forest") - ("sample_from", po::value<string>()->default_value("kbest"), "where to get translations from") - ("filter", po::value<string>()->default_value("unique"), "filter kbest list") - ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, rand") - ("N", po::value<size_t>()->default_value(3), "N for Ngrams") - ("epochs", po::value<size_t>()->default_value(2), "# of iterations T") - ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring metric") - ("stop_after", po::value<size_t>()->default_value(0), "stop after X input sentences") - ("print_weights", po::value<string>(), "weights to print on each iteration") - ("hstreaming", po::value<bool>()->zero_tokens(), "run in hadoop streaming mode") - ("learning_rate", po::value<double>()->default_value(0.0005), "learning rate") - ("gamma", po::value<double>()->default_value(0.), "gamma for SVM (0 for perceptron)") - ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); + ("input", po::value<string>()->default_value("-"), "input file") + ("output", po::value<string>()->default_value("-"), "output weights file (or VOID)") + ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") + ("decoder_config", po::value<string>(), "configuration file for cdec") + ("k", po::value<unsigned>()->default_value(100), "size of kbest or sample from forest") + ("sample_from", po::value<string>()->default_value("kbest"), "where to get translations from") + ("filter", po::value<string>()->default_value("unique"), "filter kbest list") + ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, rand") + ("N", po::value<unsigned>()->default_value(3), "N for Ngrams") + ("epochs", po::value<unsigned>()->default_value(2), "# of iterations T") + ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring metric") + ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences") + ("print_weights", po::value<string>(), "weights to print on each iteration") + ("hstreaming", po::value<bool>()->zero_tokens(), "run in hadoop streaming mode") + ("learning_rate", po::value<double>()->default_value(0.0005), "learning rate") + ("gamma", po::value<double>()->default_value(0.), "gamma for SVM (0 for perceptron)") + ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use") // FIXME + ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() ("config,c", po::value<string>(), "dtrain config file") @@ -75,10 +76,10 @@ main(int argc, char** argv) hstreaming = true; quiet = true; } - const size_t k = cfg["k"].as<size_t>(); - const size_t N = cfg["N"].as<size_t>(); - const size_t T = cfg["epochs"].as<size_t>(); - const size_t stop_after = cfg["stop_after"].as<size_t>(); + const unsigned k = cfg["k"].as<unsigned>(); + const unsigned N = cfg["N"].as<unsigned>(); + const unsigned T = cfg["epochs"].as<unsigned>(); + const unsigned stop_after = cfg["stop_after"].as<unsigned>(); const string filter_type = cfg["filter"].as<string>(); const string sample_from = cfg["sample_from"].as<string>(); const string pair_sampling = cfg["pair_sampling"].as<string>(); @@ -105,7 +106,7 @@ main(int argc, char** argv) // scoring metric/scorer string scorer_str = cfg["scorer"].as<string>(); - score_t (*scorer)(NgramCounts&, const size_t, const size_t, size_t, vector<score_t>); + score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector<score_t>); if (scorer_str == "bleu") { scorer = &bleu; } else if (scorer_str == "stupid_bleu") { @@ -119,8 +120,8 @@ main(int argc, char** argv) exit(1); } NgramCounts global_counts(N); // counts for 1 best translations - size_t global_hyp_len = 0; // sum hypothesis lengths - size_t global_ref_len = 0; // sum reference lengths + unsigned global_hyp_len = 0; // sum hypothesis lengths + unsigned global_ref_len = 0; // sum reference lengths // ^^^ global_* for approx_bleu vector<score_t> bleu_weights; // we leave this empty -> 1/N if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl; @@ -149,10 +150,10 @@ main(int argc, char** argv) ogzstream grammar_buf_out; grammar_buf_out.open(grammar_buf_fn); - size_t in_sz = 999999999; // input index, input size + unsigned in_sz = 999999999; // input index, input size vector<pair<score_t,score_t> > all_scores; score_t max_score = 0.; - size_t best_it = 0; + unsigned best_it = 0; float overall_time = 0.; // output cfg @@ -178,7 +179,7 @@ main(int argc, char** argv) } - for (size_t t = 0; t < T; t++) // T epochs + for (unsigned t = 0; t < T; t++) // T epochs { time_t start, end; @@ -186,7 +187,7 @@ main(int argc, char** argv) igzstream grammar_buf_in; if (t > 0) grammar_buf_in.open(grammar_buf_fn); score_t score_sum = 0., model_sum = 0.; - size_t ii = 0; + unsigned ii = 0; if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl; while(true) @@ -279,10 +280,10 @@ main(int argc, char** argv) // (local) scoring if (t > 0) ref_ids = ref_ids_buf[ii]; score_t score = 0.; - for (size_t i = 0; i < samples->size(); i++) { + for (unsigned i = 0; i < samples->size(); i++) { NgramCounts counts = make_ngram_counts(ref_ids, (*samples)[i].w, N); if (scorer_str == "approx_bleu") { - size_t hyp_len = 0; + unsigned hyp_len = 0; if (i == 0) { // 'context of 1best translations' global_counts += counts; global_hyp_len += (*samples)[i].w.size(); |