summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc67
1 files changed, 23 insertions, 44 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 823a50de..737326f8 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -12,9 +12,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
{
po::options_description ini("Configuration File Options");
ini.add_options()
- ("input", po::value<string>(), "input file (src)")
- ("refs,r", po::value<string>(), "references")
- ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt'")
+ ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt ||| tgt ||| ...'")
("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
("decoder_config", po::value<string>(), "configuration file for cdec")
@@ -84,8 +82,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "hi_lo must lie in [0.01, 0.5]" << endl;
return false;
}
- if ((cfg->count("input")>0 || cfg->count("refs")>0) && cfg->count("bitext")>0) {
- cerr << "Provide 'input' and 'refs' or 'bitext', not both." << endl;
+ if (!cfg->count("bitext")) {
+ cerr << "No training data given." << endl;
return false;
}
if ((*cfg)["pair_threshold"].as<score_t>() < 0) {
@@ -221,24 +219,11 @@ main(int argc, char** argv)
// output
string output_fn = cfg["output"].as<string>();
// input
- bool read_bitext = false;
string input_fn;
- if (cfg.count("bitext")) {
- read_bitext = true;
- input_fn = cfg["bitext"].as<string>();
- } else {
- input_fn = cfg["input"].as<string>();
- }
- ReadFile input(input_fn);
+ ReadFile input(cfg["bitext"].as<string>());
// buffer input for t > 0
vector<string> src_str_buf; // source strings (decoder takes only strings)
- vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
- ReadFile refs;
- string refs_fn;
- if (!read_bitext) {
- refs_fn = cfg["refs"].as<string>();
- refs.Init(refs_fn);
- }
+ vector<vector<vector<WordID> > > refs_as_ids_buf; // references as WordID vecs
unsigned in_sz = std::numeric_limits<unsigned>::max(); // input index, input size
vector<pair<score_t, score_t> > all_scores;
@@ -280,8 +265,6 @@ main(int argc, char** argv)
//cerr << setw(25) << "test k-best " << test_k_best << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
- if (!read_bitext)
- cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (cfg.count("input_weights"))
cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl;
@@ -311,17 +294,13 @@ main(int argc, char** argv)
{
string in;
- vector<string> ref;
+ vector<string> refs;
bool next = false, stop = false; // next iteration or premature stop
if (t == 0) {
if(!getline(*input, in)) next = true;
- if(read_bitext) {
- vector<string> strs;
- boost::algorithm::split_regex(strs, in, boost::regex(" \\|\\|\\| "));
- in = strs[0];
- strs.erase(strs.begin());
- ref = strs;
- }
+ boost::algorithm::split_regex(refs, in, boost::regex(" \\|\\|\\| "));
+ in = refs[0];
+ refs.erase(refs.begin());
} else {
if (ii == in_sz) next = true; // stop if we reach the end of our input
}
@@ -356,20 +335,19 @@ main(int argc, char** argv)
lambdas.init_vector(&decoder_weights);
// getting input
- vector<vector<WordID> ref_ids; // reference as vector<WordID>
if (t == 0) {
- if (!read_bitext) {
- getline(*refs, ref);
+ vector<vector<WordID> > cur_refs;
+ for (auto r: refs) {
+ vector<WordID> cur_ref;
+ vector<string> tok;
+ boost::split(tok, r, boost::is_any_of(" "));
+ register_and_convert(tok, cur_ref);
+ cur_refs.push_back(cur_ref);
}
- vector<string> ref_tok;
- boost::split(ref_tok, ref, boost::is_any_of(" "));
- register_and_convert(ref_tok, ref_ids);
- ref_ids_buf.push_back(ref_ids);
+ refs_as_ids_buf.push_back(cur_refs);
src_str_buf.push_back(in);
- } else {
- ref_ids = ref_ids_buf[ii];
}
- observer->SetRef(ref_ids);
+ observer->SetRef(refs_as_ids_buf[ii]);
if (t == 0)
decoder.Decode(in, observer);
else
@@ -379,10 +357,11 @@ main(int argc, char** argv)
vector<ScoredHyp>* samples = observer->GetSamples();
if (verbose) {
- cerr << "--- ref for " << ii << ": ";
- if (t > 0) printWordIDVec(ref_ids_buf[ii]);
- else printWordIDVec(ref_ids);
- cerr << endl;
+ cerr << "--- refs for " << ii << ": ";
+ for (auto r: refs_as_ids_buf[ii]) {
+ printWordIDVec(r);
+ cerr << endl;
+ }
for (unsigned u = 0; u < samples->size(); u++) {
cerr << _p2 << _np << "[" << u << ". '";
printWordIDVec((*samples)[u].w);