diff options
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 50 |
1 files changed, 48 insertions, 2 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 27315358..277d4e14 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -24,6 +24,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)") ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use") ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") +#ifdef DTRAIN_LOCAL + ("refs,r", po::value<string>(), "references for local mode") +#endif ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() @@ -45,6 +48,12 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "When using 'hstreaming' the 'output' param should be '-'."; return false; } +#ifdef DTRAIN_LOCAL + if ((*cfg)["input"].as<string>() == "-") { + cerr << "Can't use stdin as input with this binary. Recompile without DTRAIN_LOCAL" << endl; + return false; + } +#endif if ((*cfg)["sample_from"].as<string>() != "kbest" && (*cfg)["sample_from"].as<string>() != "forest") { cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; @@ -148,13 +157,19 @@ main(int argc, char** argv) string input_fn = cfg["input"].as<string>(); ReadFile input(input_fn); // buffer input for t > 0 - vector<string> src_str_buf; // source strings + vector<string> src_str_buf; // source strings (decoder takes only strings) vector<vector<WordID> > ref_ids_buf; // references as WordID vecs vector<string> weights_files; // remember weights for each iteration + // where temp files go string tmp_path = cfg["tmp"].as<string>(); +#ifdef DTRAIN_LOCAL + string refs_fn = cfg["refs"].as<string>(); + ReadFile refs(refs_fn); +#else string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars"); ogzstream grammar_buf_out; grammar_buf_out.open(grammar_buf_fn.c_str()); +#endif unsigned in_sz = UINT_MAX; // input index, input size vector<pair<score_t, score_t> > all_scores; @@ -174,6 +189,9 @@ main(int argc, char** argv) if (cfg.count("input_weights")) cerr << setw(25) << "weights in" << cfg["input_weights"].as<string>() << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; +#ifdef DTRAIN_LOCAL + cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; +#endif cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; if (sample_from == "kbest") cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; @@ -191,8 +209,10 @@ main(int argc, char** argv) time_t start, end; time(&start); +#ifndef DTRAIN_LOCAL igzstream grammar_buf_in; if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str()); +#endif score_t score_sum = 0.; score_t model_sum(0); unsigned ii = 0, nup = 0, npairs = 0; @@ -239,8 +259,9 @@ main(int argc, char** argv) lambdas.init_vector(&dense_weights); // getting input + vector<WordID> ref_ids; // reference as vector<WordID> +#ifndef DTRAIN_LOCAL vector<string> in_split; // input: sid\tsrc\tref\tpsg - vector<WordID> ref_ids; // reference as vector<WordID> if (t == 0) { // handling input split_in(in, in_split); @@ -280,6 +301,24 @@ main(int argc, char** argv) observer->SetRef(ref_ids_buf[ii]); decoder.Decode(src_str_buf[ii], observer); } +#else + if (t == 0) { + string r_; + getline(*refs, r_); + vector<string> ref_tok; + boost::split(ref_tok, r_, boost::is_any_of(" ")); + register_and_convert(ref_tok, ref_ids); + ref_ids_buf.push_back(ref_ids); + src_str_buf.push_back(in); + } else { + ref_ids = ref_ids_buf[ii]; + } + observer->SetRef(ref_ids); + if (t == 0) + decoder.Decode(in, observer); + else + decoder.Decode(src_str_buf[ii], observer); +#endif // get (scored) samples vector<ScoredHyp>* samples = observer->GetSamples(); @@ -345,10 +384,15 @@ main(int argc, char** argv) if (t == 0) { in_sz = ii; // remember size of input (# lines) + } + +#ifndef DTRAIN_LOCAL + if (t == 0) { grammar_buf_out.close(); } else { grammar_buf_in.close(); } +#endif // print some stats score_t score_avg = score_sum/(score_t)in_sz; @@ -406,7 +450,9 @@ main(int argc, char** argv) } // outer loop +#ifndef DTRAIN_LOCAL unlink(grammar_buf_fn.c_str()); +#endif if (!noup) { if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl; |