diff options
Diffstat (limited to 'dtrain/dtrain.cc')
| -rw-r--r-- | dtrain/dtrain.cc | 107 | 
1 file changed, 64 insertions, 43 deletions
| diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 0853173f..3d3aa2d3 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -6,32 +6,33 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)  {    po::options_description ini("Configuration File Options");    ini.add_options() -    ("input",          po::value<string>()->default_value("-"),                                                "input file") -    ("output",         po::value<string>()->default_value("-"),                       "output weights file, '-' for STDOUT") -    ("input_weights",  po::value<string>(),                             "input weights file (e.g. from previous iteration)") -    ("decoder_config", po::value<string>(),                                                   "configuration file for cdec") -    ("sample_from",    po::value<string>()->default_value("kbest"),      "where to sample translations from: kbest, forest") -    ("k",              po::value<unsigned>()->default_value(100),                         "how many translations to sample") -    ("filter",         po::value<string>()->default_value("unique"),                        "filter kbest list: no, unique") -    ("pair_sampling",  po::value<string>()->default_value("all"),                  "how to sample pairs: all, rand, 108010") -    ("N",              po::value<unsigned>()->default_value(3),                                       "N for Ngrams (BLEU)") -    ("epochs",         po::value<unsigned>()->default_value(2),                             "# of iterations T (per shard)")  -    ("scorer",         po::value<string>()->default_value("stupid_bleu"),     "scoring: bleu, stupid_*, smooth_*, approx_*") -    ("stop_after",     po::value<unsigned>()->default_value(0),                              "stop after X input sentences") -    ("print_weights",  po::value<string>(),                                            "weights to print on each iteration") -    ("hstreaming",     po::value<string>(),                         
       "run in hadoop streaming mode, arg is a task id") -    ("learning_rate",  po::value<weight_t>()->default_value(0.0005),                                        "learning rate") -    ("gamma",          po::value<weight_t>()->default_value(0),                          "gamma for SVM (0 for perceptron)") -    ("tmp",            po::value<string>()->default_value("/tmp"),                                        "temp dir to use") -    ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") -    ("keep_w",         po::value<bool>()->zero_tokens(),                              "protocol weights for each iteration") -    ("unit_weight_vector", po::value<bool>()->zero_tokens(),                       "Rescale weight vector after each input") -    ("l1_reg",         po::value<string>()->default_value("no"),         "apply l1 regularization as in Tsuroka et al 2010") -    ("l1_reg_strength", po::value<weight_t>(),                                                 "l1 regularization strength") +    ("input",           po::value<string>()->default_value("-"),                                                "input file") +    ("output",          po::value<string>()->default_value("-"),                       "output weights file, '-' for STDOUT") +    ("input_weights",   po::value<string>(),                             "input weights file (e.g. 
from previous iteration)") +    ("decoder_config",  po::value<string>(),                                                   "configuration file for cdec") +    ("sample_from",     po::value<string>()->default_value("kbest"),      "where to sample translations from: kbest, forest") +    ("k",               po::value<unsigned>()->default_value(100),                         "how many translations to sample") +    ("filter",          po::value<string>()->default_value("uniq"),                            "filter kbest list: no, uniq") +    ("pair_sampling",   po::value<string>()->default_value("all"),             "how to sample pairs: all, 5050, 108010, PRO") +    ("N",               po::value<unsigned>()->default_value(3),                                       "N for Ngrams (BLEU)") +    ("epochs",          po::value<unsigned>()->default_value(2),                             "# of iterations T (per shard)")  +    ("scorer",          po::value<string>()->default_value("stupid_bleu"),     "scoring: bleu, stupid_*, smooth_*, approx_*") +    ("learning_rate",   po::value<weight_t>()->default_value(0.0005),                                        "learning rate") +    ("gamma",           po::value<weight_t>()->default_value(0),                          "gamma for SVM (0 for perceptron)") +    ("select_weights",  po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") +    ("unit_wv",         po::value<bool>()->zero_tokens(),                           "Rescale weight vector after each input") +    ("l1_reg",          po::value<string>()->default_value("no"),         "apply l1 regularization as in Tsuroka et al 2010") +    ("l1_reg_strength", po::value<weight_t>(),                                                  "l1 regularization strength") +    ("update_ok",       po::value<bool>()->zero_tokens(),                      "include correctly ranked pairs into updates") +    ("stop_after",      po::value<unsigned>()->default_value(0),  
                            "stop after X input sentences") +    ("keep_w",          po::value<bool>()->zero_tokens(),                            "keep weights files for each iteration") +    ("print_weights",   po::value<string>(),                                            "weights to print on each iteration") +    ("hstreaming",      po::value<string>(),                                "run in hadoop streaming mode, arg is a task id") +    ("tmp",             po::value<string>()->default_value("/tmp"),                                        "temp dir to use")  #ifdef DTRAIN_LOCAL -    ("refs,r",         po::value<string>(),                                                      "references in local mode") +    ("refs,r",         po::value<string>(),                                                       "references in local mode")  #endif -    ("noup",           po::value<bool>()->zero_tokens(),                                            "do not update weights"); +    ("noup",           po::value<bool>()->zero_tokens(),                                             "do not update weights");    po::options_description cl("Command Line Options");    cl.add_options()      ("config,c",         po::value<string>(),              "dtrain config file") @@ -63,13 +64,14 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)      cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;      return false;    } -  if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "unique" +  if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "uniq"         && (*cfg)["filter"].as<string>() != "no") { -    cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'unique' or 'no'." << endl; +    cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'no'." 
<< endl;      return false;    }    if ((*cfg)["pair_sampling"].as<string>() != "all" -       && (*cfg)["pair_sampling"].as<string>() != "rand" && (*cfg)["pair_sampling"].as<string>() != "108010") { +       && (*cfg)["pair_sampling"].as<string>() != "5050" && (*cfg)["pair_sampling"].as<string>() != "108010" +       && (*cfg)["pair_sampling"].as<string>() != "PRO") {      cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "', use 'all' or 'rand'." << endl;      return false;    } @@ -101,11 +103,14 @@ main(int argc, char** argv)      task_id = cfg["hstreaming"].as<string>();      cerr.precision(17);    } -  bool unit_weight_vector = false; -  if (cfg.count("unit_weight_vector")) unit_weight_vector = true; +  bool unit_wv = false; +  if (cfg.count("unit_wv")) unit_wv = true;    HSReporter rep(task_id);    bool keep_w = false;    if (cfg.count("keep_w")) keep_w = true; +  bool update_ok = false; +  if (cfg.count("update_ok")) +    update_ok = true;    const unsigned k = cfg["k"].as<unsigned>();    const unsigned N = cfg["N"].as<unsigned>();  @@ -118,7 +123,7 @@ main(int argc, char** argv)    vector<string> print_weights;    if (cfg.count("print_weights"))      boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); -   +    // setup decoder    register_feature_functions();    SetSilent(true); @@ -187,7 +192,7 @@ main(int argc, char** argv)    vector<vector<WordID> > ref_ids_buf; // references as WordID vecs    // where temp files go    string tmp_path = cfg["tmp"].as<string>(); -  vector<string> w_tmp_files; // used for protocol_w +  vector<string> w_tmp_files; // used for keep_w   #ifdef DTRAIN_LOCAL    string refs_fn = cfg["refs"].as<string>();    ReadFile refs(refs_fn); @@ -226,6 +231,12 @@ main(int argc, char** argv)      cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;      cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;      cerr << setw(25) << "select 
weights " << "'" << select_weights << "'" << endl; +    if (cfg.count("l1_reg")) +      cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl; +    if (update_ok) +      cerr << setw(25) << "up ok " << update_ok << endl; +    if (unit_wv) +      cerr << setw(25) << "unit weight vec " << unit_wv << endl;      if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;    } @@ -320,7 +331,7 @@ main(int argc, char** argv)        // get buffered grammar        string grammar_str;        while (true) { -        string rule;   +        string rule;          getline(grammar_buf_in, rule);          if (boost::starts_with(rule, DTRAIN_GRAMMAR_DELIM)) break;          grammar_str += rule + "\n"; @@ -372,13 +383,15 @@ main(int argc, char** argv)      if (!noup) {        vector<pair<ScoredHyp,ScoredHyp> > pairs;        if (pair_sampling == "all") -        sample_all_pairs(samples, pairs); -      if (pair_sampling == "rand") -        sample_rand_pairs(samples, pairs, &rng); +        all_pairs(samples, pairs); +      if (pair_sampling == "5050") +        rand_pairs_5050(samples, pairs, &rng);        if (pair_sampling == "108010") -        sample108010(samples, pairs); +        multpart108010(samples, pairs); +      if (pair_sampling == "PRO") +        PROsampling(samples, pairs);        npairs += pairs.size(); -        +        for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();             it != pairs.end(); it++) {          score_t rank_error = it->second.score - it->first.score; @@ -388,6 +401,11 @@ main(int argc, char** argv)              SparseVector<weight_t> diff_vec = it->second.f - it->first.f;              lambdas.plus_eq_v_times_s(diff_vec, eta);              rank_errors++; +          } else { +            if (update_ok) { +              SparseVector<weight_t> diff_vec = it->first.f - it->second.f; +              lambdas.plus_eq_v_times_s(diff_vec, eta); +            }            }    
        if (it->first.model - it->second.model < 1) margin_violations++;          } else { @@ -404,6 +422,8 @@ main(int argc, char** argv)          }        } +      //////// +      // TEST THIS        // reset cumulative_penalties after 1 iter?         // do this only once per INPUT (not per pair)        if (l1naive) { @@ -439,8 +459,9 @@ main(int argc, char** argv)          }        }      } +    //////// -    if (unit_weight_vector && sample_from == "forest") lambdas /= lambdas.l2norm(); +    if (unit_wv && sample_from == "forest") lambdas /= lambdas.l2norm();      ++ii; @@ -501,11 +522,11 @@ main(int argc, char** argv)    }    if (hstreaming) { -    rep.update_counter("Score 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(score_avg*_SCALE));  -    rep.update_counter("Model 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(model_avg*_SCALE));  -    rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), (unsigned)((npairs/(weight_t)in_sz)*_SCALE));  -    rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*_SCALE));  -    rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*_SCALE));  +    rep.update_counter("Score 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(score_avg*DTRAIN_SCALE));  +    rep.update_counter("Model 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(model_avg*DTRAIN_SCALE));  +    rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), (unsigned)((npairs/(weight_t)in_sz)*DTRAIN_SCALE));  +    rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*DTRAIN_SCALE));  +    rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*DTRAIN_SCALE));       unsigned nonz = (unsigned)lambdas.size_nonzero();      rep.update_counter("Non zero 
feature count #"+boost::lexical_cast<string>(t+1), nonz);       rep.update_gcounter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz); | 
