From 035585ee59e593d2b0cc358068d6a5dd639037cc Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Sun, 3 Nov 2013 21:24:51 +0100 Subject: bitext input for dtrain --- training/dtrain/Makefile.am | 2 +- training/dtrain/dtrain.cc | 45 ++++++++++++++++++++------ training/dtrain/dtrain.h | 2 ++ training/dtrain/examples/standard/dtrain.ini | 5 +-- training/dtrain/examples/standard/nc-wmt11.gz | Bin 0 -> 113504 bytes 5 files changed, 41 insertions(+), 13 deletions(-) create mode 100644 training/dtrain/examples/standard/nc-wmt11.gz (limited to 'training/dtrain') diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 844c790d..ecb6c128 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = dtrain dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h -dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a +dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lboost_regex AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 38a9b69a..a496f08a 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -12,8 +12,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) { po::options_description ini("Configuration File Options"); ini.add_options() - ("input", po::value()->default_value("-"), "input file (src)") + ("input", po::value(), "input file (src)") ("refs,r", po::value(), "references") + ("bitext,b", po::value(), "bitext: 'src ||| tgt'") ("output", po::value()->default_value("-"), "output weights file, '-' for STDOUT") ("input_weights", po::value(), "input weights file (e.g. from previous iteration)") ("decoder_config", po::value(), "configuration file for cdec") @@ -73,13 +74,17 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as() << "'." << endl; return false; } - if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as() != "XYX") { + if (cfg->count("hi_lo") && (*cfg)["pair_sampling"].as() != "XYX") { cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl; } - if((*cfg)["hi_lo"].as() > 0.5 || (*cfg)["hi_lo"].as() < 0.01) { + if ((*cfg)["hi_lo"].as() > 0.5 || (*cfg)["hi_lo"].as() < 0.01) { cerr << "hi_lo must lie in [0.01, 0.5]" << endl; return false; } + if ((cfg->count("input")>0 || cfg->count("refs")>0) && cfg->count("bitext")>0) { + cerr << "Provide 'input' and 'refs' or 'bitext', not both." << endl; + return false; + } if ((*cfg)["pair_threshold"].as() < 0) { cerr << "The threshold must be >= 0!" << endl; return false; @@ -208,13 +213,24 @@ main(int argc, char** argv) // output string output_fn = cfg["output"].as(); // input - string input_fn = cfg["input"].as(); + bool read_bitext = false; + string input_fn; + if (cfg.count("bitext")) { + read_bitext = true; + input_fn = cfg["bitext"].as(); + } else { + input_fn = cfg["input"].as(); + } ReadFile input(input_fn); // buffer input for t > 0 vector src_str_buf; // source strings (decoder takes only strings) vector > ref_ids_buf; // references as WordID vecs - string refs_fn = cfg["refs"].as(); - ReadFile refs(refs_fn); + ReadFile refs; + string refs_fn; + if (!read_bitext) { + refs_fn = cfg["refs"].as(); + refs.Init(refs_fn); + } unsigned in_sz = std::numeric_limits::max(); // input index, input size vector > all_scores; @@ -253,7 +269,8 @@ main(int argc, char** argv) cerr << setw(25) << "max pairs " << max_pairs << endl; cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; - cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; + if (!read_bitext) + cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; if (cfg.count("input_weights")) cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as() << "'" << endl; @@ -279,9 +296,16 @@ main(int argc, char** argv) { string in; + string ref; bool next = false, stop = false; // next iteration or premature stop if (t == 0) { if(!getline(*input, in)) next = true; + if(read_bitext) { + vector strs; + boost::algorithm::split_regex(strs, in, boost::regex(" \\|\\|\\| ")); + in = strs[0]; + ref = strs[1]; + } } else { if (ii == in_sz) next = true; // stop if we reach the end of our input } @@ -318,10 +342,11 @@ main(int argc, char** argv) // getting input vector ref_ids; // reference as vector if (t == 0) { - string r_; - getline(*refs, r_); + if (!read_bitext) { + getline(*refs, ref); + } vector ref_tok; - boost::split(ref_tok, r_, boost::is_any_of(" ")); + boost::split(ref_tok, ref, boost::is_any_of(" ")); register_and_convert(ref_tok, ref_ids); ref_ids_buf.push_back(ref_ids); src_str_buf.push_back(in); diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 3981fb39..ccb5ad4d 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -9,6 +9,8 @@ #include #include +#include +#include #include #include "decoder.h" diff --git a/training/dtrain/examples/standard/dtrain.ini b/training/dtrain/examples/standard/dtrain.ini index e6d6382e..7dbb4ff0 100644 --- a/training/dtrain/examples/standard/dtrain.ini +++ b/training/dtrain/examples/standard/dtrain.ini @@ -1,5 +1,6 @@ -input=./nc-wmt11.de.gz -refs=./nc-wmt11.en.gz +#input=./nc-wmt11.de.gz +#refs=./nc-wmt11.en.gz +bitext=./nc-wmt11.gz output=- # a weights file (add .gz for gzip compression) or STDOUT '-' select_weights=VOID # output average (over epochs) weight vector decoder_config=./cdec.ini # config for cdec diff --git a/training/dtrain/examples/standard/nc-wmt11.gz b/training/dtrain/examples/standard/nc-wmt11.gz new file mode 100644 index 00000000..c39c5aef Binary files /dev/null and b/training/dtrain/examples/standard/nc-wmt11.gz differ -- cgit v1.2.3