Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--  dtrain/dtrain.cc  61
1 file changed, 31 insertions(+), 30 deletions(-)
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 0481cf96..44090242 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -6,23 +6,24 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
{
po::options_description ini("Configuration File Options");
ini.add_options()
- ("input", po::value<string>()->default_value("-"), "input file")
- ("output", po::value<string>()->default_value("-"), "output weights file (or VOID)")
- ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
- ("decoder_config", po::value<string>(), "configuration file for cdec")
- ("k", po::value<size_t>()->default_value(100), "size of kbest or sample from forest")
- ("sample_from", po::value<string>()->default_value("kbest"), "where to get translations from")
- ("filter", po::value<string>()->default_value("unique"), "filter kbest list")
- ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, rand")
- ("N", po::value<size_t>()->default_value(3), "N for Ngrams")
- ("epochs", po::value<size_t>()->default_value(2), "# of iterations T")
- ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring metric")
- ("stop_after", po::value<size_t>()->default_value(0), "stop after X input sentences")
- ("print_weights", po::value<string>(), "weights to print on each iteration")
- ("hstreaming", po::value<bool>()->zero_tokens(), "run in hadoop streaming mode")
- ("learning_rate", po::value<double>()->default_value(0.0005), "learning rate")
- ("gamma", po::value<double>()->default_value(0.), "gamma for SVM (0 for perceptron)")
- ("noup", po::value<bool>()->zero_tokens(), "do not update weights");
+ ("input", po::value<string>()->default_value("-"), "input file")
+ ("output", po::value<string>()->default_value("-"), "output weights file (or VOID)")
+ ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
+ ("decoder_config", po::value<string>(), "configuration file for cdec")
+ ("k", po::value<unsigned>()->default_value(100), "size of kbest or sample from forest")
+ ("sample_from", po::value<string>()->default_value("kbest"), "where to get translations from")
+ ("filter", po::value<string>()->default_value("unique"), "filter kbest list")
+ ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, rand")
+ ("N", po::value<unsigned>()->default_value(3), "N for Ngrams")
+ ("epochs", po::value<unsigned>()->default_value(2), "# of iterations T")
+ ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring metric")
+ ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences")
+ ("print_weights", po::value<string>(), "weights to print on each iteration")
+ ("hstreaming", po::value<bool>()->zero_tokens(), "run in hadoop streaming mode")
+ ("learning_rate", po::value<double>()->default_value(0.0005), "learning rate")
+ ("gamma", po::value<double>()->default_value(0.), "gamma for SVM (0 for perceptron)")
+ ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use") // FIXME
+ ("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
("config,c", po::value<string>(), "dtrain config file")
@@ -75,10 +76,10 @@ main(int argc, char** argv)
hstreaming = true;
quiet = true;
}
- const size_t k = cfg["k"].as<size_t>();
- const size_t N = cfg["N"].as<size_t>();
- const size_t T = cfg["epochs"].as<size_t>();
- const size_t stop_after = cfg["stop_after"].as<size_t>();
+ const unsigned k = cfg["k"].as<unsigned>();
+ const unsigned N = cfg["N"].as<unsigned>();
+ const unsigned T = cfg["epochs"].as<unsigned>();
+ const unsigned stop_after = cfg["stop_after"].as<unsigned>();
const string filter_type = cfg["filter"].as<string>();
const string sample_from = cfg["sample_from"].as<string>();
const string pair_sampling = cfg["pair_sampling"].as<string>();
@@ -105,7 +106,7 @@ main(int argc, char** argv)
// scoring metric/scorer
string scorer_str = cfg["scorer"].as<string>();
- score_t (*scorer)(NgramCounts&, const size_t, const size_t, size_t, vector<score_t>);
+ score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector<score_t>);
if (scorer_str == "bleu") {
scorer = &bleu;
} else if (scorer_str == "stupid_bleu") {
@@ -119,8 +120,8 @@ main(int argc, char** argv)
exit(1);
}
NgramCounts global_counts(N); // counts for 1 best translations
- size_t global_hyp_len = 0; // sum hypothesis lengths
- size_t global_ref_len = 0; // sum reference lengths
+ unsigned global_hyp_len = 0; // sum hypothesis lengths
+ unsigned global_ref_len = 0; // sum reference lengths
// ^^^ global_* for approx_bleu
vector<score_t> bleu_weights; // we leave this empty -> 1/N
if (!quiet) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl << endl;
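
For reference, the scorer is picked at runtime through a plain function pointer whose parameter list is what the hunk above changes to unsigned. A small sketch of the same dispatch pattern, with hypothetical stub scorers standing in for bleu/stupid_bleu (names and signatures here are illustrative, not the real dtrain ones):

// illustrative only -- runtime selection of a metric via a function pointer
#include <iostream>
#include <string>
#include <vector>
using namespace std;

typedef double score_t;

// stand-ins for the real scorers; they just return dummy values
score_t bleu_stub(const vector<unsigned>& counts, unsigned hyp_len, unsigned ref_len)
{
  return hyp_len >= ref_len ? 1. : (double)hyp_len / ref_len;
}
score_t stupid_bleu_stub(const vector<unsigned>& counts, unsigned hyp_len, unsigned ref_len)
{
  return counts.empty() ? 0. : 0.5;
}

int main()
{
  string scorer_str = "stupid_bleu";
  // one signature shared by all metrics, so the call site stays generic
  score_t (*scorer)(const vector<unsigned>&, unsigned, unsigned);
  if (scorer_str == "bleu") {
    scorer = &bleu_stub;
  } else if (scorer_str == "stupid_bleu") {
    scorer = &stupid_bleu_stub;
  } else {
    cerr << "unknown scorer '" << scorer_str << "'" << endl;
    return 1;
  }
  vector<unsigned> counts(3, 1);
  cout << scorer(counts, 10u, 12u) << endl;  // calls whichever metric was selected
  return 0;
}
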
@@ -149,10 +150,10 @@ main(int argc, char** argv)
ogzstream grammar_buf_out;
grammar_buf_out.open(grammar_buf_fn);
- size_t in_sz = 999999999; // input index, input size
+ unsigned in_sz = 999999999; // input index, input size
vector<pair<score_t,score_t> > all_scores;
score_t max_score = 0.;
- size_t best_it = 0;
+ unsigned best_it = 0;
float overall_time = 0.;
// output cfg
@@ -178,7 +179,7 @@ main(int argc, char** argv)
}
- for (size_t t = 0; t < T; t++) // T epochs
+ for (unsigned t = 0; t < T; t++) // T epochs
{
time_t start, end;
@@ -186,7 +187,7 @@ main(int argc, char** argv)
igzstream grammar_buf_in;
if (t > 0) grammar_buf_in.open(grammar_buf_fn);
score_t score_sum = 0., model_sum = 0.;
- size_t ii = 0;
+ unsigned ii = 0;
if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
while(true)
@@ -279,10 +280,10 @@ main(int argc, char** argv)
// (local) scoring
if (t > 0) ref_ids = ref_ids_buf[ii];
score_t score = 0.;
- for (size_t i = 0; i < samples->size(); i++) {
+ for (unsigned i = 0; i < samples->size(); i++) {
NgramCounts counts = make_ngram_counts(ref_ids, (*samples)[i].w, N);
if (scorer_str == "approx_bleu") {
- size_t hyp_len = 0;
+ unsigned hyp_len = 0;
if (i == 0) { // 'context of 1best translations'
global_counts += counts;
global_hyp_len += (*samples)[i].w.size();