summaryrefslogtreecommitdiff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-10-28 14:20:30 +0200
committerPatrick Simianer <p@simianer.de>2011-10-28 14:20:30 +0200
commit7417ad2225b4c049cb7cb7122a717e8c8b6e5eaa (patch)
treee9c8a6bb35e8c09e75cdb7c43767d4da847a3d64 /dtrain/dtrain.cc
parentaaeb2dec23ff9257a9fc86ba49ee8d97f18138cd (diff)
added support for standard dev tuning setting
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--dtrain/dtrain.cc50
1 files changed, 48 insertions, 2 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 27315358..277d4e14 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -24,6 +24,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
+#ifdef DTRAIN_LOCAL
+ ("refs,r", po::value<string>(), "references for local mode")
+#endif
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -45,6 +48,12 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "When using 'hstreaming' the 'output' param should be '-'.";
return false;
}
+#ifdef DTRAIN_LOCAL
+ if ((*cfg)["input"].as<string>() == "-") {
+ cerr << "Can't use stdin as input with this binary. Recompile without DTRAIN_LOCAL" << endl;
+ return false;
+ }
+#endif
if ((*cfg)["sample_from"].as<string>() != "kbest"
&& (*cfg)["sample_from"].as<string>() != "forest") {
cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
@@ -148,13 +157,19 @@ main(int argc, char** argv)
string input_fn = cfg["input"].as<string>();
ReadFile input(input_fn);
// buffer input for t > 0
- vector<string> src_str_buf; // source strings
+ vector<string> src_str_buf; // source strings (decoder takes only strings)
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
vector<string> weights_files; // remember weights for each iteration
+ // where temp files go
string tmp_path = cfg["tmp"].as<string>();
+#ifdef DTRAIN_LOCAL
+ string refs_fn = cfg["refs"].as<string>();
+ ReadFile refs(refs_fn);
+#else
string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars");
ogzstream grammar_buf_out;
grammar_buf_out.open(grammar_buf_fn.c_str());
+#endif
unsigned in_sz = UINT_MAX; // input index, input size
vector<pair<score_t, score_t> > all_scores;
@@ -174,6 +189,9 @@ main(int argc, char** argv)
if (cfg.count("input_weights"))
cerr << setw(25) << "weights in" << cfg["input_weights"].as<string>() << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
+#ifdef DTRAIN_LOCAL
+ cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
+#endif
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (sample_from == "kbest")
cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
@@ -191,8 +209,10 @@ main(int argc, char** argv)
time_t start, end;
time(&start);
+#ifndef DTRAIN_LOCAL
igzstream grammar_buf_in;
if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str());
+#endif
score_t score_sum = 0.;
score_t model_sum(0);
unsigned ii = 0, nup = 0, npairs = 0;
@@ -239,8 +259,9 @@ main(int argc, char** argv)
lambdas.init_vector(&dense_weights);
// getting input
+ vector<WordID> ref_ids; // reference as vector<WordID>
+#ifndef DTRAIN_LOCAL
vector<string> in_split; // input: sid\tsrc\tref\tpsg
- vector<WordID> ref_ids; // reference as vector<WordID>
if (t == 0) {
// handling input
split_in(in, in_split);
@@ -280,6 +301,24 @@ main(int argc, char** argv)
observer->SetRef(ref_ids_buf[ii]);
decoder.Decode(src_str_buf[ii], observer);
}
+#else
+ if (t == 0) {
+ string r_;
+ getline(*refs, r_);
+ vector<string> ref_tok;
+ boost::split(ref_tok, r_, boost::is_any_of(" "));
+ register_and_convert(ref_tok, ref_ids);
+ ref_ids_buf.push_back(ref_ids);
+ src_str_buf.push_back(in);
+ } else {
+ ref_ids = ref_ids_buf[ii];
+ }
+ observer->SetRef(ref_ids);
+ if (t == 0)
+ decoder.Decode(in, observer);
+ else
+ decoder.Decode(src_str_buf[ii], observer);
+#endif
// get (scored) samples
vector<ScoredHyp>* samples = observer->GetSamples();
@@ -345,10 +384,15 @@ main(int argc, char** argv)
if (t == 0) {
in_sz = ii; // remember size of input (# lines)
+ }
+
+#ifndef DTRAIN_LOCAL
+ if (t == 0) {
grammar_buf_out.close();
} else {
grammar_buf_in.close();
}
+#endif
// print some stats
score_t score_avg = score_sum/(score_t)in_sz;
@@ -406,7 +450,9 @@ main(int argc, char** argv)
} // outer loop
+#ifndef DTRAIN_LOCAL
unlink(grammar_buf_fn.c_str());
+#endif
if (!noup) {
if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl;