diff options
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 67 |
1 files changed, 23 insertions, 44 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 823a50de..737326f8 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -12,9 +12,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) { po::options_description ini("Configuration File Options"); ini.add_options() - ("input", po::value<string>(), "input file (src)") - ("refs,r", po::value<string>(), "references") - ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt'") + ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt ||| tgt ||| ...'") ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") ("decoder_config", po::value<string>(), "configuration file for cdec") @@ -84,8 +82,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "hi_lo must lie in [0.01, 0.5]" << endl; return false; } - if ((cfg->count("input")>0 || cfg->count("refs")>0) && cfg->count("bitext")>0) { - cerr << "Provide 'input' and 'refs' or 'bitext', not both." << endl; + if (!cfg->count("bitext")) { + cerr << "No training data given." << endl; return false; } if ((*cfg)["pair_threshold"].as<score_t>() < 0) { @@ -221,24 +219,11 @@ main(int argc, char** argv) // output string output_fn = cfg["output"].as<string>(); // input - bool read_bitext = false; string input_fn; - if (cfg.count("bitext")) { - read_bitext = true; - input_fn = cfg["bitext"].as<string>(); - } else { - input_fn = cfg["input"].as<string>(); - } - ReadFile input(input_fn); + ReadFile input(cfg["bitext"].as<string>()); // buffer input for t > 0 vector<string> src_str_buf; // source strings (decoder takes only strings) - vector<vector<WordID> > ref_ids_buf; // references as WordID vecs - ReadFile refs; - string refs_fn; - if (!read_bitext) { - refs_fn = cfg["refs"].as<string>(); - refs.Init(refs_fn); - } + vector<vector<vector<WordID> > > refs_as_ids_buf; // references as WordID vecs unsigned in_sz = std::numeric_limits<unsigned>::max(); // input index, input size vector<pair<score_t, score_t> > all_scores; @@ -280,8 +265,6 @@ main(int argc, char** argv) //cerr << setw(25) << "test k-best " << test_k_best << endl; cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; - if (!read_bitext) - cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; if (cfg.count("input_weights")) cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl; @@ -311,17 +294,13 @@ main(int argc, char** argv) { string in; - vector<string> ref; + vector<string> refs; bool next = false, stop = false; // next iteration or premature stop if (t == 0) { if(!getline(*input, in)) next = true; - if(read_bitext) { - vector<string> strs; - boost::algorithm::split_regex(strs, in, boost::regex(" \\|\\|\\| ")); - in = strs[0]; - strs.erase(strs.begin()); - ref = strs; - } + boost::algorithm::split_regex(refs, in, boost::regex(" \\|\\|\\| ")); + in = refs[0]; + refs.erase(refs.begin()); } else { if (ii == in_sz) next = true; // stop if we reach the end of our input } @@ -356,20 +335,19 @@ main(int argc, char** argv) lambdas.init_vector(&decoder_weights); // getting input - vector<vector<WordID> ref_ids; // reference as vector<WordID> if (t == 0) { - if (!read_bitext) { - getline(*refs, ref); + vector<vector<WordID> > cur_refs; + for (auto r: refs) { + vector<WordID> cur_ref; + vector<string> tok; + boost::split(tok, r, boost::is_any_of(" ")); + register_and_convert(tok, cur_ref); + cur_refs.push_back(cur_ref); } - vector<string> ref_tok; - boost::split(ref_tok, ref, boost::is_any_of(" ")); - register_and_convert(ref_tok, ref_ids); - ref_ids_buf.push_back(ref_ids); + refs_as_ids_buf.push_back(cur_refs); src_str_buf.push_back(in); - } else { - ref_ids = ref_ids_buf[ii]; } - observer->SetRef(ref_ids); + observer->SetRef(refs_as_ids_buf[ii]); if (t == 0) decoder.Decode(in, observer); else @@ -379,10 +357,11 @@ main(int argc, char** argv) vector<ScoredHyp>* samples = observer->GetSamples(); if (verbose) { - cerr << "--- ref for " << ii << ": "; - if (t > 0) printWordIDVec(ref_ids_buf[ii]); - else printWordIDVec(ref_ids); - cerr << endl; + cerr << "--- refs for " << ii << ": "; + for (auto r: refs_as_ids_buf[ii]) { + printWordIDVec(r); + cerr << endl; + } for (unsigned u = 0; u < samples->size(); u++) { cerr << _p2 << _np << "[" << u << ". '"; printWordIDVec((*samples)[u].w); |