diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/dtrain/Makefile.am | 2 | ||||
| -rw-r--r-- | training/dtrain/dtrain.cc | 45 | ||||
| -rw-r--r-- | training/dtrain/dtrain.h | 2 | ||||
| -rw-r--r-- | training/dtrain/examples/standard/dtrain.ini | 5 | ||||
| -rw-r--r-- | training/dtrain/examples/standard/nc-wmt11.gz | bin | 0 -> 113504 bytes | 
5 files changed, 41 insertions, 13 deletions
| diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 844c790d..ecb6c128 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,7 +1,7 @@  bin_PROGRAMS = dtrain  dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h -dtrain_LDADD   = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a +dtrain_LDADD   = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lboost_regex  AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 38a9b69a..a496f08a 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -12,8 +12,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)  {    po::options_description ini("Configuration File Options");    ini.add_options() -    ("input",             po::value<string>()->default_value("-"),                                             "input file (src)") +    ("input",             po::value<string>(),                                                                 "input file (src)")      ("refs,r",            po::value<string>(),                                                                       "references") +    ("bitext,b",          po::value<string>(),                                                            "bitext: 'src ||| tgt'")      ("output",            po::value<string>()->default_value("-"),                          "output weights file, '-' for STDOUT")      ("input_weights",     po::value<string>(),                                "input weights file (e.g. from previous iteration)")      ("decoder_config",    po::value<string>(),                                                      "configuration file for cdec") @@ -73,13 +74,17 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)      cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;      return false;    } -  if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") { +  if (cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") {      cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl;    } -  if((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) { +  if ((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) {      cerr << "hi_lo must lie in [0.01, 0.5]" << endl;      return false;    } +  if ((cfg->count("input")>0 || cfg->count("refs")>0) && cfg->count("bitext")>0) { +    cerr << "Provide 'input' and 'refs' or 'bitext', not both." << endl; +    return false; +  }    if ((*cfg)["pair_threshold"].as<score_t>() < 0) {      cerr << "The threshold must be >= 0!" << endl;      return false; @@ -208,13 +213,24 @@ main(int argc, char** argv)    // output    string output_fn = cfg["output"].as<string>();    // input -  string input_fn = cfg["input"].as<string>(); +  bool read_bitext = false; +  string input_fn; +  if (cfg.count("bitext")) { +    read_bitext = true; +    input_fn = cfg["bitext"].as<string>(); +  } else { +    input_fn = cfg["input"].as<string>(); +  }    ReadFile input(input_fn);    // buffer input for t > 0    vector<string> src_str_buf;          // source strings (decoder takes only strings)    vector<vector<WordID> > ref_ids_buf; // references as WordID vecs -  string refs_fn = cfg["refs"].as<string>(); -  ReadFile refs(refs_fn); +  ReadFile refs; +  string refs_fn; +  if (!read_bitext) { +    refs_fn = cfg["refs"].as<string>(); +    refs.Init(refs_fn); +  }    unsigned in_sz = std::numeric_limits<unsigned>::max(); // input index, input size    vector<pair<score_t, score_t> > all_scores; @@ -253,7 +269,8 @@ main(int argc, char** argv)      cerr << setw(25) << "max pairs " << max_pairs << endl;      cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;      cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; -    cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; +    if (!read_bitext) +      cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;      cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;      if (cfg.count("input_weights"))        cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl; @@ -279,9 +296,16 @@ main(int argc, char** argv)    {      string in; +    string ref;      bool next = false, stop = false; // next iteration or premature stop      if (t == 0) {        if(!getline(*input, in)) next = true; +      if(read_bitext) { +        vector<string> strs; +        boost::algorithm::split_regex(strs, in, boost::regex(" \\|\\|\\| ")); +        in = strs[0]; +        ref = strs[1]; +      }      } else {        if (ii == in_sz) next = true; // stop if we reach the end of our input      } @@ -318,10 +342,11 @@ main(int argc, char** argv)      // getting input      vector<WordID> ref_ids; // reference as vector<WordID>      if (t == 0) { -      string r_; -      getline(*refs, r_); +      if (!read_bitext) { +        getline(*refs, ref); +      }        vector<string> ref_tok; -      boost::split(ref_tok, r_, boost::is_any_of(" ")); +      boost::split(ref_tok, ref, boost::is_any_of(" "));        register_and_convert(ref_tok, ref_ids);        ref_ids_buf.push_back(ref_ids);        src_str_buf.push_back(in); diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 3981fb39..ccb5ad4d 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -9,6 +9,8 @@  #include <string.h>  #include <boost/algorithm/string.hpp> +#include <boost/regex.hpp> +#include <boost/algorithm/string/regex.hpp>  #include <boost/program_options.hpp>  #include "decoder.h" diff --git a/training/dtrain/examples/standard/dtrain.ini b/training/dtrain/examples/standard/dtrain.ini index e6d6382e..7dbb4ff0 100644 --- a/training/dtrain/examples/standard/dtrain.ini +++ b/training/dtrain/examples/standard/dtrain.ini @@ -1,5 +1,6 @@ -input=./nc-wmt11.de.gz -refs=./nc-wmt11.en.gz +#input=./nc-wmt11.de.gz +#refs=./nc-wmt11.en.gz +bitext=./nc-wmt11.gz  output=-                  # a weights file (add .gz for gzip compression) or STDOUT '-'  select_weights=VOID       # output average (over epochs) weight vector  decoder_config=./cdec.ini # config for cdec diff --git a/training/dtrain/examples/standard/nc-wmt11.gz b/training/dtrain/examples/standard/nc-wmt11.gzBinary files differ new file mode 100644 index 00000000..c39c5aef --- /dev/null +++ b/training/dtrain/examples/standard/nc-wmt11.gz | 
