summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/dtrain/Makefile.am2
-rw-r--r--training/dtrain/dtrain.cc45
-rw-r--r--training/dtrain/dtrain.h2
-rw-r--r--training/dtrain/examples/standard/dtrain.ini5
-rw-r--r--training/dtrain/examples/standard/nc-wmt11.gzbin0 -> 113504 bytes
5 files changed, 41 insertions, 13 deletions
diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am
index 844c790d..ecb6c128 100644
--- a/training/dtrain/Makefile.am
+++ b/training/dtrain/Makefile.am
@@ -1,7 +1,7 @@
bin_PROGRAMS = dtrain
dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h
-dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a
+dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lboost_regex
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 38a9b69a..a496f08a 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -12,8 +12,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
{
po::options_description ini("Configuration File Options");
ini.add_options()
- ("input", po::value<string>()->default_value("-"), "input file (src)")
+ ("input", po::value<string>(), "input file (src)")
("refs,r", po::value<string>(), "references")
+ ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt'")
("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
("decoder_config", po::value<string>(), "configuration file for cdec")
@@ -73,13 +74,17 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
return false;
}
- if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") {
+ if (cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") {
cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl;
}
- if((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) {
+ if ((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) {
cerr << "hi_lo must lie in [0.01, 0.5]" << endl;
return false;
}
+ if ((cfg->count("input")>0 || cfg->count("refs")>0) && cfg->count("bitext")>0) {
+ cerr << "Provide 'input' and 'refs' or 'bitext', not both." << endl;
+ return false;
+ }
if ((*cfg)["pair_threshold"].as<score_t>() < 0) {
cerr << "The threshold must be >= 0!" << endl;
return false;
@@ -208,13 +213,24 @@ main(int argc, char** argv)
// output
string output_fn = cfg["output"].as<string>();
// input
- string input_fn = cfg["input"].as<string>();
+ bool read_bitext = false;
+ string input_fn;
+ if (cfg.count("bitext")) {
+ read_bitext = true;
+ input_fn = cfg["bitext"].as<string>();
+ } else {
+ input_fn = cfg["input"].as<string>();
+ }
ReadFile input(input_fn);
// buffer input for t > 0
vector<string> src_str_buf; // source strings (decoder takes only strings)
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
- string refs_fn = cfg["refs"].as<string>();
- ReadFile refs(refs_fn);
+ ReadFile refs;
+ string refs_fn;
+ if (!read_bitext) {
+ refs_fn = cfg["refs"].as<string>();
+ refs.Init(refs_fn);
+ }
unsigned in_sz = std::numeric_limits<unsigned>::max(); // input index, input size
vector<pair<score_t, score_t> > all_scores;
@@ -253,7 +269,8 @@ main(int argc, char** argv)
cerr << setw(25) << "max pairs " << max_pairs << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
- cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
+ if (!read_bitext)
+ cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (cfg.count("input_weights"))
cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl;
@@ -279,9 +296,16 @@ main(int argc, char** argv)
{
string in;
+ string ref;
bool next = false, stop = false; // next iteration or premature stop
if (t == 0) {
if(!getline(*input, in)) next = true;
+ if(read_bitext) {
+ vector<string> strs;
+ boost::algorithm::split_regex(strs, in, boost::regex(" \\|\\|\\| "));
+ in = strs[0];
+ ref = strs[1];
+ }
} else {
if (ii == in_sz) next = true; // stop if we reach the end of our input
}
@@ -318,10 +342,11 @@ main(int argc, char** argv)
// getting input
vector<WordID> ref_ids; // reference as vector<WordID>
if (t == 0) {
- string r_;
- getline(*refs, r_);
+ if (!read_bitext) {
+ getline(*refs, ref);
+ }
vector<string> ref_tok;
- boost::split(ref_tok, r_, boost::is_any_of(" "));
+ boost::split(ref_tok, ref, boost::is_any_of(" "));
register_and_convert(ref_tok, ref_ids);
ref_ids_buf.push_back(ref_ids);
src_str_buf.push_back(in);
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index 3981fb39..ccb5ad4d 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -9,6 +9,8 @@
#include <string.h>
#include <boost/algorithm/string.hpp>
+#include <boost/regex.hpp>
+#include <boost/algorithm/string/regex.hpp>
#include <boost/program_options.hpp>
#include "decoder.h"
diff --git a/training/dtrain/examples/standard/dtrain.ini b/training/dtrain/examples/standard/dtrain.ini
index e6d6382e..7dbb4ff0 100644
--- a/training/dtrain/examples/standard/dtrain.ini
+++ b/training/dtrain/examples/standard/dtrain.ini
@@ -1,5 +1,6 @@
-input=./nc-wmt11.de.gz
-refs=./nc-wmt11.en.gz
+#input=./nc-wmt11.de.gz
+#refs=./nc-wmt11.en.gz
+bitext=./nc-wmt11.gz
output=- # a weights file (add .gz for gzip compression) or STDOUT '-'
select_weights=VOID # output average (over epochs) weight vector
decoder_config=./cdec.ini # config for cdec
diff --git a/training/dtrain/examples/standard/nc-wmt11.gz b/training/dtrain/examples/standard/nc-wmt11.gz
new file mode 100644
index 00000000..c39c5aef
--- /dev/null
+++ b/training/dtrain/examples/standard/nc-wmt11.gz
Binary files differ