summaryrefslogtreecommitdiff
path: root/dtrain
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain')
-rw-r--r--dtrain/dtrain.cc82
-rw-r--r--dtrain/dtrain.h14
-rw-r--r--dtrain/hstreaming/red-test7
-rw-r--r--dtrain/test/example/dtrain.ini5
4 files changed, 79 insertions, 29 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 9b1bbe68..69f83633 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -23,6 +23,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("learning_rate", po::value<double>()->default_value(0.0005), "learning rate")
("gamma", po::value<double>()->default_value(0.), "gamma for SVM (0 for perceptron)")
("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
+ ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights")
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -56,6 +57,10 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
&& (*cfg)["sample_from"].as<string>() != "forest") {
cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
}
+ if ((*cfg)["select_weights"].as<string>() != "last"
+ && (*cfg)["select_weights"].as<string>() != "best") {
+ cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
+ }
return true;
}
@@ -83,6 +88,7 @@ main(int argc, char** argv)
const string filter_type = cfg["filter"].as<string>();
const string sample_from = cfg["sample_from"].as<string>();
const string pair_sampling = cfg["pair_sampling"].as<string>();
+ const string select_weights = cfg["select_weights"].as<string>();
vector<string> print_weights;
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
@@ -140,13 +146,11 @@ main(int argc, char** argv)
// buffer input for t > 0
vector<string> src_str_buf; // source strings
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
- // this is for writing the grammar buffer file
- char grammar_buf_fn[1024];
- strcpy(grammar_buf_fn, cfg["tmp"].as<string>().c_str());
- strcat(grammar_buf_fn, "/dtrain-grammars-XXXXXX");
- mkstemp(grammar_buf_fn);
+ vector<string> weights_files; // remember weights for each iteration
+ string tmp_path = cfg["tmp"].as<string>();
+ string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars");
ogzstream grammar_buf_out;
- grammar_buf_out.open(grammar_buf_fn);
+ grammar_buf_out.open(grammar_buf_fn.c_str());
unsigned in_sz = 999999999; // input index, input size
vector<pair<score_t,score_t> > all_scores;
@@ -173,6 +177,7 @@ main(int argc, char** argv)
cerr << setw(25) << "gamma " << gamma << endl;
cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
+ cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;
}
@@ -183,7 +188,7 @@ main(int argc, char** argv)
time_t start, end;
time(&start);
igzstream grammar_buf_in;
- if (t > 0) grammar_buf_in.open(grammar_buf_fn);
+ if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str());
score_t score_sum = 0., model_sum = 0.;
unsigned ii = 0;
if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
@@ -221,7 +226,7 @@ main(int argc, char** argv)
}
}
}
-
+
// next iteration
if (next || stop) break;
@@ -244,8 +249,8 @@ main(int argc, char** argv)
ref_ids_buf.push_back(ref_ids);
// process and set grammar
bool broken_grammar = true;
- for (string::iterator ti = in_split[3].begin(); ti != in_split[3].end(); ti++) {
- if (!isspace(*ti)) {
+ for (string::iterator it = in_split[3].begin(); it != in_split[3].end(); it++) {
+ if (!isspace(*it)) {
broken_grammar = false;
break;
}
@@ -300,14 +305,14 @@ main(int argc, char** argv)
if (pair_sampling == "108010")
sample108010(samples, pairs);
- for (vector<pair<ScoredHyp,ScoredHyp> >::iterator ti = pairs.begin();
- ti != pairs.end(); ti++) {
+ for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
+ it != pairs.end(); it++) {
SparseVector<double> dv;
- if (ti->first.score - ti->second.score < 0) {
- dv = ti->second.f - ti->first.f;
+ if (it->first.score - it->second.score < 0) {
+ dv = it->second.f - it->first.f;
//} else {
- //dv = ti->first - ti->second;
+ //dv = it->first - it->second;
//}
dv.add_value(FD::Convert("__bias"), -1);
@@ -317,7 +322,7 @@ main(int argc, char** argv)
lambdas += dv * eta;
if (verbose) {
- /*cerr << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
+ /*cerr << "{{ f("<< it->first_rank <<") > f(" << it->second_rank << ") but g(i)="<< it->first_score <<" < g(j)="<< it->second_score << " so update" << endl;
cerr << " i " << TD::GetString(samples->sents[ti->first_rank]) << endl;
cerr << " " << samples->feats[ti->first_rank] << endl;
cerr << " j " << TD::GetString(samples->sents[ti->second_rank]) << endl;
@@ -393,23 +398,46 @@ main(int argc, char** argv)
if (noup) break;
+ // write weights to file
+ if (select_weights == "best") {
+ weights.InitFromVector(lambdas);
+ string infix = "dtrain-weights-" + boost::lexical_cast<string>(t);
+ string w_fn = gettmpf(tmp_path, infix, "gz");
+ weights.WriteToFile(w_fn, true);
+ weights_files.push_back(w_fn);
+ }
+
} // outer loop
- unlink(grammar_buf_fn);
+ unlink(grammar_buf_fn.c_str());
if (!noup) {
- if (!quiet) cerr << endl << "writing weights file '" << cfg["output"].as<string>() << "' ...";
- if (cfg["output"].as<string>() == "-") {
- cout << _p9;
- for (SparseVector<double>::const_iterator ti = lambdas.begin();
- ti != lambdas.end(); ++ti) {
- if (ti->second == 0) continue;
- cout << _np << FD::Convert(ti->first) << "\t" << ti->second << endl;
+ if (!quiet) cerr << endl << "writing weights file to '" << cfg["output"].as<string>() << "' ...";
+ if (select_weights == "last") { // last
+ WriteFile out(cfg["output"].as<string>());
+ ostream& o = *out.stream();
+ o.precision(17);
+ o << _np;
+ for (SparseVector<double>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
+ if (it->second == 0) continue;
+ o << FD::Convert(it->first) << '\t' << it->second << endl;
}
if (hstreaming) cout << "__SHARD_COUNT__\t1" << endl;
- } else if (cfg["output"].as<string>() != "VOID") {
- weights.InitFromVector(lambdas);
- weights.WriteToFile(cfg["output"].as<string>(), true);
+ } else { // best
+ if (cfg["output"].as<string>() != "-") {
+ CopyFile(weights_files[best_it], cfg["output"].as<string>());
+ } else {
+ ReadFile(weights_files[best_it]);
+ string o;
+ cout.precision(17);
+ cout << _np;
+ while(getline(*input, o)) cout << o << endl;
+ }
+ for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) {
+ unlink(it->c_str());
+ it->erase(it->end()-3, it->end());
+ unlink(it->c_str());
+ }
}
if (!quiet) cerr << "done" << endl;
}
diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h
index c1d910aa..34464e3c 100644
--- a/dtrain/dtrain.h
+++ b/dtrain/dtrain.h
@@ -24,6 +24,20 @@ inline void register_and_convert(const vector<string>& strs, vector<WordID>& ids
ids.push_back(TD::Convert(*it));
}
+inline string gettmpf(const string path, const string infix, const string suffix="") {
+ char fn[1024];
+ strcpy(fn, path.c_str());
+ strcat(fn, "/");
+ strcat(fn, infix.c_str());
+ strcat(fn, "-XXXXXX");
+ mkstemp(fn);
+ if (suffix != "") { // we will get 2 files
+ strcat(fn, ".");
+ strcat(fn, suffix.c_str());
+ }
+ return string(fn);
+}
+
inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); }
inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); }
inline ostream& _p2(ostream& out) { return out << setprecision(2); }
diff --git a/dtrain/hstreaming/red-test b/dtrain/hstreaming/red-test
new file mode 100644
index 00000000..b86e7894
--- /dev/null
+++ b/dtrain/hstreaming/red-test
@@ -0,0 +1,7 @@
+a 1
+b 2
+c 3.5
+a 1
+b 2
+c 3.5
+__SHARD_COUNT__ 2
diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini
index 068074c4..c560a3a6 100644
--- a/dtrain/test/example/dtrain.ini
+++ b/dtrain/test/example/dtrain.ini
@@ -1,12 +1,13 @@
decoder_config=test/example/cdec.ini
k=100
N=3
-epochs=3
+epochs=1000
input=test/example/nc-1k.gz
scorer=stupid_bleu
output=/tmp/weights.gz
-stop_after=1000
+stop_after=10
sample_from=kbest
pair_sampling=all
print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough
tmp=/tmp
+select_weights=best