From 7417ad2225b4c049cb7cb7122a717e8c8b6e5eaa Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Fri, 28 Oct 2011 14:20:30 +0200
Subject: added support for standard dev tuning setting
---
dtrain/dtrain.cc | 50 ++++++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 48 insertions(+), 2 deletions(-)
(limited to 'dtrain/dtrain.cc')
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 27315358..277d4e14 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -24,6 +24,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("gamma", po::value()->default_value(0), "gamma for SVM (0 for perceptron)")
("tmp", po::value()->default_value("/tmp"), "temp dir to use")
("select_weights", po::value()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
+#ifdef DTRAIN_LOCAL
+ ("refs,r", po::value(), "references for local mode")
+#endif
("noup", po::value()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -45,6 +48,12 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "When using 'hstreaming' the 'output' param should be '-'.";
return false;
}
+#ifdef DTRAIN_LOCAL
+ if ((*cfg)["input"].as() == "-") {
+ cerr << "Can't use stdin as input with this binary. Recompile without DTRAIN_LOCAL" << endl;
+ return false;
+ }
+#endif
if ((*cfg)["sample_from"].as() != "kbest"
&& (*cfg)["sample_from"].as() != "forest") {
cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as() << "', use 'kbest' or 'forest'." << endl;
@@ -148,13 +157,19 @@ main(int argc, char** argv)
string input_fn = cfg["input"].as();
ReadFile input(input_fn);
// buffer input for t > 0
- vector src_str_buf; // source strings
+ vector src_str_buf; // source strings (decoder takes only strings)
vector > ref_ids_buf; // references as WordID vecs
vector weights_files; // remember weights for each iteration
+ // where temp files go
string tmp_path = cfg["tmp"].as();
+#ifdef DTRAIN_LOCAL
+ string refs_fn = cfg["refs"].as();
+ ReadFile refs(refs_fn);
+#else
string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars");
ogzstream grammar_buf_out;
grammar_buf_out.open(grammar_buf_fn.c_str());
+#endif
unsigned in_sz = UINT_MAX; // input index, input size
vector > all_scores;
@@ -174,6 +189,9 @@ main(int argc, char** argv)
if (cfg.count("input_weights"))
cerr << setw(25) << "weights in" << cfg["input_weights"].as() << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
+#ifdef DTRAIN_LOCAL
+ cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
+#endif
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (sample_from == "kbest")
cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
@@ -191,8 +209,10 @@ main(int argc, char** argv)
time_t start, end;
time(&start);
+#ifndef DTRAIN_LOCAL
igzstream grammar_buf_in;
if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str());
+#endif
score_t score_sum = 0.;
score_t model_sum(0);
unsigned ii = 0, nup = 0, npairs = 0;
@@ -239,8 +259,9 @@ main(int argc, char** argv)
lambdas.init_vector(&dense_weights);
// getting input
+ vector ref_ids; // reference as vector
+#ifndef DTRAIN_LOCAL
vector in_split; // input: sid\tsrc\tref\tpsg
- vector ref_ids; // reference as vector
if (t == 0) {
// handling input
split_in(in, in_split);
@@ -280,6 +301,24 @@ main(int argc, char** argv)
observer->SetRef(ref_ids_buf[ii]);
decoder.Decode(src_str_buf[ii], observer);
}
+#else
+ if (t == 0) {
+ string r_;
+ getline(*refs, r_);
+ vector ref_tok;
+ boost::split(ref_tok, r_, boost::is_any_of(" "));
+ register_and_convert(ref_tok, ref_ids);
+ ref_ids_buf.push_back(ref_ids);
+ src_str_buf.push_back(in);
+ } else {
+ ref_ids = ref_ids_buf[ii];
+ }
+ observer->SetRef(ref_ids);
+ if (t == 0)
+ decoder.Decode(in, observer);
+ else
+ decoder.Decode(src_str_buf[ii], observer);
+#endif
// get (scored) samples
vector* samples = observer->GetSamples();
@@ -345,10 +384,15 @@ main(int argc, char** argv)
if (t == 0) {
in_sz = ii; // remember size of input (# lines)
+ }
+
+#ifndef DTRAIN_LOCAL
+ if (t == 0) {
grammar_buf_out.close();
} else {
grammar_buf_in.close();
}
+#endif
// print some stats
score_t score_avg = score_sum/(score_t)in_sz;
@@ -406,7 +450,9 @@ main(int argc, char** argv)
} // outer loop
+#ifndef DTRAIN_LOCAL
unlink(grammar_buf_fn.c_str());
+#endif
if (!noup) {
if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl;
--
cgit v1.2.3