From d86299e31deb81a836b1b2f2a356a6f4b28eda9e Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Thu, 31 May 2012 14:33:59 +0200 Subject: new scorer, stuff --- dtrain/dtrain.cc | 75 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 40 insertions(+), 35 deletions(-) (limited to 'dtrain/dtrain.cc') diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 717d47a2..88413a1d 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -6,38 +6,39 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) { po::options_description ini("Configuration File Options"); ini.add_options() - ("input", po::value()->default_value("-"), "input file") - ("output", po::value()->default_value("-"), "output weights file, '-' for STDOUT") - ("input_weights", po::value(), "input weights file (e.g. from previous iteration)") - ("decoder_config", po::value(), "configuration file for cdec") - ("print_weights", po::value(), "weights to print on each iteration") - ("stop_after", po::value()->default_value(0), "stop after X input sentences") - ("tmp", po::value()->default_value("/tmp"), "temp dir to use") - ("keep", po::value()->zero_tokens(), "keep weights files for each iteration") - ("hstreaming", po::value(), "run in hadoop streaming mode, arg is a task id") - ("epochs", po::value()->default_value(10), "# of iterations T (per shard)") - ("k", po::value()->default_value(100), "how many translations to sample") - ("sample_from", po::value()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'") - ("filter", po::value()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") - ("pair_sampling", po::value()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'") - ("hi_lo", po::value()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5") - ("pair_threshold", po::value()->default_value(0.), "bleu [0,1] threshold to filter pairs") - ("N", po::value()->default_value(4), "N for Ngrams (BLEU)") - ("scorer", po::value()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_") - ("learning_rate", po::value()->default_value(0.0001), "learning rate") - ("gamma", po::value()->default_value(0.), "gamma for SVM (0 for perceptron)") - ("select_weights", po::value()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") - ("rescale", po::value()->zero_tokens(), "rescale weight vector after each input") - ("l1_reg", po::value()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") - ("l1_reg_strength", po::value(), "l1 regularization strength") - ("fselect", po::value()->default_value(-1), "TODO select top x percent (or by threshold) of features after each epoch") - ("approx_bleu_d", po::value()->default_value(0.9), "discount for approx. BLEU") - ("scale_bleu_diff", po::value()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") - ("loss_margin", po::value()->default_value(0.), "update if no error in pref pair but model scores this near") + ("input", po::value()->default_value("-"), "input file") + ("output", po::value()->default_value("-"), "output weights file, '-' for STDOUT") + ("input_weights", po::value(), "input weights file (e.g. from previous iteration)") + ("decoder_config", po::value(), "configuration file for cdec") + ("print_weights", po::value(), "weights to print on each iteration") + ("stop_after", po::value()->default_value(0), "stop after X input sentences") + ("tmp", po::value()->default_value("/tmp"), "temp dir to use") + ("keep", po::value()->zero_tokens(), "keep weights files for each iteration") + ("hstreaming", po::value(), "run in hadoop streaming mode, arg is a task id") + ("epochs", po::value()->default_value(10), "# of iterations T (per shard)") + ("k", po::value()->default_value(100), "how many translations to sample") + ("sample_from", po::value()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'") + ("filter", po::value()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") + ("pair_sampling", po::value()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'") + ("hi_lo", po::value()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5") + ("pair_threshold", po::value()->default_value(0.), "bleu [0,1] threshold to filter pairs") + ("N", po::value()->default_value(4), "N for Ngrams (BLEU)") + ("scorer", po::value()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_, lc_") + ("learning_rate", po::value()->default_value(0.0001), "learning rate") + ("gamma", po::value()->default_value(0.), "gamma for SVM (0 for perceptron)") + ("select_weights", po::value()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") + ("rescale", po::value()->zero_tokens(), "rescale weight vector after each input") + ("l1_reg", po::value()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") + ("l1_reg_strength", po::value(), "l1 regularization strength") + ("fselect", po::value()->default_value(-1), "select top x percent (or by threshold) of features after each epoch NOT IMPL") // TODO + ("approx_bleu_d", po::value()->default_value(0.9), "discount for approx. BLEU") + ("scale_bleu_diff", po::value()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") + ("loss_margin", po::value()->default_value(0.), "update if no error in pref pair but model scores this near") + ("max_pairs", po::value()->default_value(std::numeric_limits::max()), "max. # of pairs per Sent.") #ifdef DTRAIN_LOCAL - ("refs,r", po::value(), "references in local mode") + ("refs,r", po::value(), "references in local mode") #endif - ("noup", po::value()->zero_tokens(), "do not update weights"); + ("noup", po::value()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() ("config,c", po::value(), "dtrain config file") @@ -135,6 +136,7 @@ main(int argc, char** argv) const string select_weights = cfg["select_weights"].as(); const float hi_lo = cfg["hi_lo"].as(); const score_t approx_bleu_d = cfg["approx_bleu_d"].as(); + const unsigned max_pairs = cfg["max_pairs"].as(); weight_t loss_margin = cfg["loss_margin"].as(); if (loss_margin > 9998.) loss_margin = std::numeric_limits::max(); bool scale_bleu_diff = false; @@ -167,6 +169,8 @@ main(int argc, char** argv) scorer = dynamic_cast(new SmoothSingleBleuScorer); } else if (scorer_str == "approx_bleu") { scorer = dynamic_cast(new ApproxBleuScorer(N, approx_bleu_d)); + } else if (scorer_str == "lc_bleu") { + scorer = dynamic_cast(new LinearBleuScorer(N)); } else { cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl; exit(1); @@ -257,6 +261,7 @@ main(int argc, char** argv) cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as() << "'" << endl; if (rescale) cerr << setw(25) << "rescale " << rescale << endl; + cerr << "max pairs " << max_pairs << endl; cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; #ifdef DTRAIN_LOCAL @@ -421,17 +426,17 @@ main(int argc, char** argv) // get pairs vector > pairs; if (pair_sampling == "all") - all_pairs(samples, pairs, pair_threshold); + all_pairs(samples, pairs, pair_threshold, max_pairs); if (pair_sampling == "XYX") - partXYX(samples, pairs, pair_threshold, hi_lo); + partXYX(samples, pairs, pair_threshold, max_pairs, hi_lo); if (pair_sampling == "PRO") - PROsampling(samples, pairs, pair_threshold); + PROsampling(samples, pairs, pair_threshold, max_pairs); npairs += pairs.size(); for (vector >::iterator it = pairs.begin(); it != pairs.end(); it++) { #ifdef DTRAIN_FASTER_PERCEPTRON - bool rank_error = true; // pair filtering already did this for us + bool rank_error = true; // pair sampling already did this for us rank_errors++; score_t margin = std::numeric_limits::max(); #else @@ -498,7 +503,7 @@ main(int argc, char** argv) if (average) w_average += lambdas; - if (scorer_str == "approx_bleu") scorer->Reset(); + if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); if (t == 0) { in_sz = ii; // remember size of input (# lines) -- cgit v1.2.3