summaryrefslogtreecommitdiff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-09-30 00:33:06 +0200
committerPatrick Simianer <p@simianer.de>2011-09-30 00:33:06 +0200
commit58f4ff5b79a545d59e21e77511a4b74c99b63d56 (patch)
tree5239aca8878eb6e6002fd7e983da041a6465dd34 /dtrain/dtrain.cc
parent78fb5d2761551f4a1a4f4e8c19be88dc0348f3d9 (diff)
added iteration selection param
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--dtrain/dtrain.cc82
1 files changed, 55 insertions, 27 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 9b1bbe68..69f83633 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -23,6 +23,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("learning_rate", po::value<double>()->default_value(0.0005), "learning rate")
("gamma", po::value<double>()->default_value(0.), "gamma for SVM (0 for perceptron)")
("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
+ ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights")
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -56,6 +57,10 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
&& (*cfg)["sample_from"].as<string>() != "forest") {
cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
}
+ if ((*cfg)["select_weights"].as<string>() != "last"
+ && (*cfg)["select_weights"].as<string>() != "best") {
+ cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
+ }
return true;
}
@@ -83,6 +88,7 @@ main(int argc, char** argv)
const string filter_type = cfg["filter"].as<string>();
const string sample_from = cfg["sample_from"].as<string>();
const string pair_sampling = cfg["pair_sampling"].as<string>();
+ const string select_weights = cfg["select_weights"].as<string>();
vector<string> print_weights;
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
@@ -140,13 +146,11 @@ main(int argc, char** argv)
// buffer input for t > 0
vector<string> src_str_buf; // source strings
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
- // this is for writing the grammar buffer file
- char grammar_buf_fn[1024];
- strcpy(grammar_buf_fn, cfg["tmp"].as<string>().c_str());
- strcat(grammar_buf_fn, "/dtrain-grammars-XXXXXX");
- mkstemp(grammar_buf_fn);
+ vector<string> weights_files; // remember weights for each iteration
+ string tmp_path = cfg["tmp"].as<string>();
+ string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars");
ogzstream grammar_buf_out;
- grammar_buf_out.open(grammar_buf_fn);
+ grammar_buf_out.open(grammar_buf_fn.c_str());
unsigned in_sz = 999999999; // input index, input size
vector<pair<score_t,score_t> > all_scores;
@@ -173,6 +177,7 @@ main(int argc, char** argv)
cerr << setw(25) << "gamma " << gamma << endl;
cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
+ cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;
}
@@ -183,7 +188,7 @@ main(int argc, char** argv)
time_t start, end;
time(&start);
igzstream grammar_buf_in;
- if (t > 0) grammar_buf_in.open(grammar_buf_fn);
+ if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str());
score_t score_sum = 0., model_sum = 0.;
unsigned ii = 0;
if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
@@ -221,7 +226,7 @@ main(int argc, char** argv)
}
}
}
-
+
// next iteration
if (next || stop) break;
@@ -244,8 +249,8 @@ main(int argc, char** argv)
ref_ids_buf.push_back(ref_ids);
// process and set grammar
bool broken_grammar = true;
- for (string::iterator ti = in_split[3].begin(); ti != in_split[3].end(); ti++) {
- if (!isspace(*ti)) {
+ for (string::iterator it = in_split[3].begin(); it != in_split[3].end(); it++) {
+ if (!isspace(*it)) {
broken_grammar = false;
break;
}
@@ -300,14 +305,14 @@ main(int argc, char** argv)
if (pair_sampling == "108010")
sample108010(samples, pairs);
- for (vector<pair<ScoredHyp,ScoredHyp> >::iterator ti = pairs.begin();
- ti != pairs.end(); ti++) {
+ for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
+ it != pairs.end(); it++) {
SparseVector<double> dv;
- if (ti->first.score - ti->second.score < 0) {
- dv = ti->second.f - ti->first.f;
+ if (it->first.score - it->second.score < 0) {
+ dv = it->second.f - it->first.f;
//} else {
- //dv = ti->first - ti->second;
+ //dv = it->first - it->second;
//}
dv.add_value(FD::Convert("__bias"), -1);
@@ -317,7 +322,7 @@ main(int argc, char** argv)
lambdas += dv * eta;
if (verbose) {
- /*cerr << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
+ /*cerr << "{{ f("<< it->first_rank <<") > f(" << it->second_rank << ") but g(i)="<< it->first_score <<" < g(j)="<< it->second_score << " so update" << endl;
cerr << " i " << TD::GetString(samples->sents[ti->first_rank]) << endl;
cerr << " " << samples->feats[ti->first_rank] << endl;
cerr << " j " << TD::GetString(samples->sents[ti->second_rank]) << endl;
@@ -393,23 +398,46 @@ main(int argc, char** argv)
if (noup) break;
+ // write weights to file
+ if (select_weights == "best") {
+ weights.InitFromVector(lambdas);
+ string infix = "dtrain-weights-" + boost::lexical_cast<string>(t);
+ string w_fn = gettmpf(tmp_path, infix, "gz");
+ weights.WriteToFile(w_fn, true);
+ weights_files.push_back(w_fn);
+ }
+
} // outer loop
- unlink(grammar_buf_fn);
+ unlink(grammar_buf_fn.c_str());
if (!noup) {
- if (!quiet) cerr << endl << "writing weights file '" << cfg["output"].as<string>() << "' ...";
- if (cfg["output"].as<string>() == "-") {
- cout << _p9;
- for (SparseVector<double>::const_iterator ti = lambdas.begin();
- ti != lambdas.end(); ++ti) {
- if (ti->second == 0) continue;
- cout << _np << FD::Convert(ti->first) << "\t" << ti->second << endl;
+ if (!quiet) cerr << endl << "writing weights file to '" << cfg["output"].as<string>() << "' ...";
+ if (select_weights == "last") { // last
+ WriteFile out(cfg["output"].as<string>());
+ ostream& o = *out.stream();
+ o.precision(17);
+ o << _np;
+ for (SparseVector<double>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
+ if (it->second == 0) continue;
+ o << FD::Convert(it->first) << '\t' << it->second << endl;
}
if (hstreaming) cout << "__SHARD_COUNT__\t1" << endl;
- } else if (cfg["output"].as<string>() != "VOID") {
- weights.InitFromVector(lambdas);
- weights.WriteToFile(cfg["output"].as<string>(), true);
+ } else { // best
+ if (cfg["output"].as<string>() != "-") {
+ CopyFile(weights_files[best_it], cfg["output"].as<string>());
+ } else {
+ ReadFile(weights_files[best_it]);
+ string o;
+ cout.precision(17);
+ cout << _np;
+ while(getline(*input, o)) cout << o << endl;
+ }
+ for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) {
+ unlink(it->c_str());
+ it->erase(it->end()-3, it->end());
+ unlink(it->c_str());
+ }
}
if (!quiet) cerr << "done" << endl;
}