summaryrefslogtreecommitdiff
path: root/dtrain
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain')
-rw-r--r--dtrain/README.md20
-rw-r--r--dtrain/dtrain.cc50
-rw-r--r--dtrain/dtrain.h3
3 files changed, 70 insertions, 3 deletions
diff --git a/dtrain/README.md b/dtrain/README.md
index b453c649..bc96ed18 100644
--- a/dtrain/README.md
+++ b/dtrain/README.md
@@ -43,7 +43,9 @@ Uncertain, known bugs, problems
FIXME
-----
-none
+merge dtrain part-* files
+mapred count shard sents
+
Data
----
@@ -61,3 +63,19 @@ ep-v6.de-en.cs.loo p
p: prep, e: extract, g: grammar, d: dtrain
</pre>
+
+Experiments
+-----------
+features
+ TODO
+
+"lm open better than lm closed when tuned"
+
+mira100-10
+mira100-17
+
+baselines
+ mira
+ pro
+ vest
+
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 27315358..277d4e14 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -24,6 +24,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
+#ifdef DTRAIN_LOCAL
+ ("refs,r", po::value<string>(), "references for local mode")
+#endif
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -45,6 +48,12 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "When using 'hstreaming' the 'output' param should be '-'.";
return false;
}
+#ifdef DTRAIN_LOCAL
+ if ((*cfg)["input"].as<string>() == "-") {
+ cerr << "Can't use stdin as input with this binary. Recompile without DTRAIN_LOCAL" << endl;
+ return false;
+ }
+#endif
if ((*cfg)["sample_from"].as<string>() != "kbest"
&& (*cfg)["sample_from"].as<string>() != "forest") {
cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
@@ -148,13 +157,19 @@ main(int argc, char** argv)
string input_fn = cfg["input"].as<string>();
ReadFile input(input_fn);
// buffer input for t > 0
- vector<string> src_str_buf; // source strings
+ vector<string> src_str_buf; // source strings (decoder takes only strings)
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
vector<string> weights_files; // remember weights for each iteration
+ // where temp files go
string tmp_path = cfg["tmp"].as<string>();
+#ifdef DTRAIN_LOCAL
+ string refs_fn = cfg["refs"].as<string>();
+ ReadFile refs(refs_fn);
+#else
string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars");
ogzstream grammar_buf_out;
grammar_buf_out.open(grammar_buf_fn.c_str());
+#endif
unsigned in_sz = UINT_MAX; // input index, input size
vector<pair<score_t, score_t> > all_scores;
@@ -174,6 +189,9 @@ main(int argc, char** argv)
if (cfg.count("input_weights"))
cerr << setw(25) << "weights in" << cfg["input_weights"].as<string>() << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
+#ifdef DTRAIN_LOCAL
+ cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
+#endif
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (sample_from == "kbest")
cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
@@ -191,8 +209,10 @@ main(int argc, char** argv)
time_t start, end;
time(&start);
+#ifndef DTRAIN_LOCAL
igzstream grammar_buf_in;
if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str());
+#endif
score_t score_sum = 0.;
score_t model_sum(0);
unsigned ii = 0, nup = 0, npairs = 0;
@@ -239,8 +259,9 @@ main(int argc, char** argv)
lambdas.init_vector(&dense_weights);
// getting input
+ vector<WordID> ref_ids; // reference as vector<WordID>
+#ifndef DTRAIN_LOCAL
vector<string> in_split; // input: sid\tsrc\tref\tpsg
- vector<WordID> ref_ids; // reference as vector<WordID>
if (t == 0) {
// handling input
split_in(in, in_split);
@@ -280,6 +301,24 @@ main(int argc, char** argv)
observer->SetRef(ref_ids_buf[ii]);
decoder.Decode(src_str_buf[ii], observer);
}
+#else
+ if (t == 0) {
+ string r_;
+ getline(*refs, r_);
+ vector<string> ref_tok;
+ boost::split(ref_tok, r_, boost::is_any_of(" "));
+ register_and_convert(ref_tok, ref_ids);
+ ref_ids_buf.push_back(ref_ids);
+ src_str_buf.push_back(in);
+ } else {
+ ref_ids = ref_ids_buf[ii];
+ }
+ observer->SetRef(ref_ids);
+ if (t == 0)
+ decoder.Decode(in, observer);
+ else
+ decoder.Decode(src_str_buf[ii], observer);
+#endif
// get (scored) samples
vector<ScoredHyp>* samples = observer->GetSamples();
@@ -345,10 +384,15 @@ main(int argc, char** argv)
if (t == 0) {
in_sz = ii; // remember size of input (# lines)
+ }
+
+#ifndef DTRAIN_LOCAL
+ if (t == 0) {
grammar_buf_out.close();
} else {
grammar_buf_in.close();
}
+#endif
// print some stats
score_t score_avg = score_sum/(score_t)in_sz;
@@ -406,7 +450,9 @@ main(int argc, char** argv)
} // outer loop
+#ifndef DTRAIN_LOCAL
unlink(grammar_buf_fn.c_str());
+#endif
if (!noup) {
if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl;
diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h
index f4d32ecb..6c9decf4 100644
--- a/dtrain/dtrain.h
+++ b/dtrain/dtrain.h
@@ -2,6 +2,7 @@
#define _DTRAIN_COMMON_H_
+
#include <iomanip>
#include <climits>
#include <string.h>
@@ -14,6 +15,8 @@
#include "filelib.h"
+//#define DTRAIN_LOCAL
+
#define DTRAIN_DOTS 100 // when to display a '.'
#define DTRAIN_GRAMMAR_DELIM "########EOS########"