diff options
 training/dtrain/dtrain.cc | 39
 training/dtrain/dtrain.h  | 43
 training/dtrain/update.h  | 50
 3 files changed, 105 insertions(+), 27 deletions(-)
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 9ca048c0..b39fff3e 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -12,17 +12,20 @@ main(int argc, char** argv)    po::variables_map conf;    if (!dtrain_init(argc, argv, &conf))      return 1; -  const size_t k              = conf["k"].as<size_t>(); -  const string score_name     = conf["score"].as<string>(); -  const size_t N              = conf["N"].as<size_t>(); -  const size_t T              = conf["iterations"].as<size_t>(); -  const weight_t eta          = conf["learning_rate"].as<weight_t>(); -  const weight_t margin       = conf["margin"].as<weight_t>(); -  const bool average          = conf["average"].as<bool>(); -  const bool structured       = conf["struct"].as<bool>(); -  const weight_t l1_reg       = conf["l1_reg"].as<weight_t>(); -  const bool keep             = conf["keep"].as<bool>(); -  const string output_fn      = conf["output"].as<string>(); +  const size_t k                 = conf["k"].as<size_t>(); +  const string score_name        = conf["score"].as<string>(); +  const size_t N                 = conf["N"].as<size_t>(); +  const size_t T                 = conf["iterations"].as<size_t>(); +  const weight_t eta             = conf["learning_rate"].as<weight_t>(); +  const weight_t margin          = conf["margin"].as<weight_t>(); +  const bool average             = conf["average"].as<bool>(); +  const bool structured          = conf["struct"].as<bool>(); +  const weight_t l1_reg          = conf["l1_reg"].as<weight_t>(); +  const bool keep                = conf["keep"].as<bool>(); +  const bool noup                = conf["disable_learning"].as<bool>(); +  const string output_fn         = conf["output"].as<string>(); +  const string output_data_which = conf["output_data"].as<string>(); +  const bool output_data         = output_data_which!="";    vector<string> print_weights;    boost::split(print_weights, conf["print_weights"].as<string>(),           
      boost::is_any_of(" ")); @@ -162,7 +165,19 @@ main(int argc, char** argv)      feature_count += observer->GetFeatureCount();      list_sz += observer->GetSize(); +    if (output_data) { +      if (output_data_which == "kbest") { +        OutputKbest(samples); +      } else if (output_data_which == "default") { +        OutputMultipartitePairs(samples, margin); +      } else if (output_data_which == "all") { +        OutputAllPairs(samples); +      } +    } +      // get pairs and update +    if (!noup) { +      SparseVector<weight_t> updates;      if (structured)        num_up += CollectUpdatesStruct(samples, updates); @@ -204,6 +219,8 @@ main(int argc, char** argv)        }      } +    } // noup +      i++;    } // input loop diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 6633b4f9..0bbb5c9b 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -48,22 +48,24 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)  {    po::options_description opts("Configuration File Options");    opts.add_options() -    ("bitext,b",          po::value<string>(),                                                  "bitext") -    ("decoder_conf,C",    po::value<string>(),                          "configuration file for decoder") -    ("iterations,T",      po::value<size_t>()->default_value(15),   "number of iterations T (per shard)") -    ("k",                 po::value<size_t>()->default_value(100),                  "size of kbest list") -    ("learning_rate,l",   po::value<weight_t>()->default_value(0.00001),                 "learning rate") -    ("l1_reg,r",          po::value<weight_t>()->default_value(0.),         "l1 regularization strength") -    ("margin,m",          po::value<weight_t>()->default_value(1.0),      "margin for margin perceptron") -    ("score,s",           po::value<string>()->default_value("chiang"),      "per-sentence BLEU approx.") -    ("N",                 po::value<size_t>()->default_value(4),          
    "N for BLEU approximation") -    ("input_weights,w",   po::value<string>(),                                      "input weights file") -    ("average,a",         po::bool_switch()->default_value(true),               "output average weights") -    ("keep,K",            po::bool_switch()->default_value(false),  "output a weight file per iteration") -    ("struct,S",          po::bool_switch()->default_value(false),       "structured SGD with hope/fear") -    ("output,o",          po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") -    ("print_weights,P",   po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), -                                                         "list of weights to print after each iteration"); +    ("bitext,b",           po::value<string>(),                                                      "bitext") +    ("decoder_conf,C",     po::value<string>(),                              "configuration file for decoder") +    ("iterations,T",       po::value<size_t>()->default_value(15),       "number of iterations T (per shard)") +    ("k",                  po::value<size_t>()->default_value(100),                      "size of kbest list") +    ("learning_rate,l",    po::value<weight_t>()->default_value(0.00001),                     "learning rate") +    ("l1_reg,r",           po::value<weight_t>()->default_value(0.),             "l1 regularization strength") +    ("margin,m",           po::value<weight_t>()->default_value(1.0),          "margin for margin perceptron") +    ("score,s",            po::value<string>()->default_value("chiang"),          "per-sentence BLEU approx.") +    ("N",                  po::value<size_t>()->default_value(4),                  "N for BLEU approximation") +    ("input_weights,w",    po::value<string>(),                                          "input 
weights file") +    ("average,a",          po::bool_switch()->default_value(true),                   "output average weights") +    ("keep,K",             po::bool_switch()->default_value(false),      "output a weight file per iteration") +    ("struct,S",           po::bool_switch()->default_value(false),           "structured SGD with hope/fear") +    ("output,o",           po::value<string>()->default_value("-"),     "output weights file, '-' for STDOUT") +    ("disable_learning,X", po::bool_switch()->default_value(false),                        "disable learning") +    ("output_data,D",      po::value<string>()->default_value(""), "output data to STDOUT; arg. is 'kbest', 'default' or 'all'") +    ("print_weights,P",    po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), +                                                             "list of weights to print after each iteration");    po::options_description clopts("Command Line Options");    clopts.add_options()      ("conf,c", po::value<string>(), "dtrain configuration file") @@ -93,6 +95,15 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)      return false;    } +  if ((*conf)["output_data"].as<string>() != "") { +    if ((*conf)["output_data"].as<string>() != "kbest" && +        (*conf)["output_data"].as<string>() != "default" && +        (*conf)["output_data"].as<string>() != "all") { +      cerr << "Wrong 'output_data' argument: "; +      cerr << (*conf)["output_data"].as<string>() << endl; +      return false; +    } +  }    return true;  } diff --git a/training/dtrain/update.h b/training/dtrain/update.h index 6f42e5bd..83dc3186 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -94,6 +94,56 @@ CollectUpdatesStruct(vector<ScoredHyp>* s,    return updates.size();  } +inline void +OutputKbest(vector<ScoredHyp>* s) +{ +  sort(s->begin(), 
s->end(), _cmp); +  size_t i = 0; +  for (auto k: *s) { +    cout << i << "\t" << k.gold << "\t" << k.model << " \t" << k.f << endl; +    i++; +  } +} + +inline void +OutputMultipartitePairs(vector<ScoredHyp>* s, +                               weight_t margin=0., +                               bool all=true) +{ +  size_t sz = s->size(); +  sort(s->begin(), s->end(), _cmp); +  size_t sep = round(sz*0.1); +  for (size_t i = 0; i < sep; i++) { +    for (size_t j = sep; j < sz; j++) { +      if (!all && _good((*s)[i], (*s)[j], margin)) +        continue; +      cout << (*s)[i].f-(*s)[j].f << endl; +    } +  } +  size_t sep_lo = sz-sep; +  for (size_t i = sep; i < sep_lo; i++) { +    for (size_t j = sep_lo; j < sz; j++) { +      if (!all && _good((*s)[i], (*s)[j], margin)) +        continue; +      cout << (*s)[i].f-(*s)[j].f << endl; +    } +  } +} + +inline void +OutputAllPairs(vector<ScoredHyp>* s) +{ +  size_t sz = s->size(); +  sort(s->begin(), s->end(), _cmp); +  for (size_t i = 0; i < sz-1; i++) { +    for (size_t j = i+1; j < sz; j++) { +      if ((*s)[i].gold == (*s)[j].gold) +        continue; +      cout << (*s)[i].f-(*s)[j].f << endl; +    } +  } +} +  } // namespace  #endif  | 
