diff options
Diffstat (limited to 'dtrain')
-rw-r--r-- | dtrain/dtrain.cc | 82 | ||||
-rw-r--r-- | dtrain/dtrain.h | 14 | ||||
-rw-r--r-- | dtrain/hstreaming/red-test | 7 | ||||
-rw-r--r-- | dtrain/test/example/dtrain.ini | 5 |
4 files changed, 79 insertions, 29 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 9b1bbe68..69f83633 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -23,6 +23,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("learning_rate", po::value<double>()->default_value(0.0005), "learning rate") ("gamma", po::value<double>()->default_value(0.), "gamma for SVM (0 for perceptron)") ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use") + ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights") ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() @@ -56,6 +57,10 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) && (*cfg)["sample_from"].as<string>() != "forest") { cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; } + if ((*cfg)["select_weights"].as<string>() != "last" + && (*cfg)["select_weights"].as<string>() != "best") { + cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl; + } return true; } @@ -83,6 +88,7 @@ main(int argc, char** argv) const string filter_type = cfg["filter"].as<string>(); const string sample_from = cfg["sample_from"].as<string>(); const string pair_sampling = cfg["pair_sampling"].as<string>(); + const string select_weights = cfg["select_weights"].as<string>(); vector<string> print_weights; if (cfg.count("print_weights")) boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); @@ -140,13 +146,11 @@ main(int argc, char** argv) // buffer input for t > 0 vector<string> src_str_buf; // source strings vector<vector<WordID> > ref_ids_buf; // references as WordID vecs - // this is for writing the grammar buffer file - char grammar_buf_fn[1024]; - strcpy(grammar_buf_fn, cfg["tmp"].as<string>().c_str()); - strcat(grammar_buf_fn, "/dtrain-grammars-XXXXXX"); - mkstemp(grammar_buf_fn); + vector<string> weights_files; // remember weights for each iteration + string tmp_path = cfg["tmp"].as<string>(); + string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars"); ogzstream grammar_buf_out; - grammar_buf_out.open(grammar_buf_fn); + grammar_buf_out.open(grammar_buf_fn.c_str()); unsigned in_sz = 999999999; // input index, input size vector<pair<score_t,score_t> > all_scores; @@ -173,6 +177,7 @@ main(int argc, char** argv) cerr << setw(25) << "gamma " << gamma << endl; cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; + cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl; } @@ -183,7 +188,7 @@ main(int argc, char** argv) time_t start, end; time(&start); igzstream grammar_buf_in; - if (t > 0) grammar_buf_in.open(grammar_buf_fn); + if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str()); score_t score_sum = 0., model_sum = 0.; unsigned ii = 0; if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl; @@ -221,7 +226,7 @@ main(int argc, char** argv) } } } - + // next iteration if (next || stop) break; @@ -244,8 +249,8 @@ main(int argc, char** argv) ref_ids_buf.push_back(ref_ids); // process and set grammar bool broken_grammar = true; - for (string::iterator ti = in_split[3].begin(); ti != in_split[3].end(); ti++) { - if (!isspace(*ti)) { + for (string::iterator it = in_split[3].begin(); it != in_split[3].end(); it++) { + if (!isspace(*it)) { broken_grammar = false; break; } @@ -300,14 +305,14 @@ main(int argc, char** argv) if (pair_sampling == "108010") sample108010(samples, pairs); - for (vector<pair<ScoredHyp,ScoredHyp> >::iterator ti = pairs.begin(); - ti != pairs.end(); ti++) { + for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); + it != pairs.end(); it++) { SparseVector<double> dv; - if (ti->first.score - ti->second.score < 0) { - dv = ti->second.f - ti->first.f; + if (it->first.score - it->second.score < 0) { + dv = it->second.f - it->first.f; //} else { - //dv = ti->first - ti->second; + //dv = it->first - it->second; //} dv.add_value(FD::Convert("__bias"), -1); @@ -317,7 +322,7 @@ main(int argc, char** argv) lambdas += dv * eta; if (verbose) { - /*cerr << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl; + /*cerr << "{{ f("<< it->first_rank <<") > f(" << it->second_rank << ") but g(i)="<< it->first_score <<" < g(j)="<< it->second_score << " so update" << endl; cerr << " i " << TD::GetString(samples->sents[ti->first_rank]) << endl; cerr << " " << samples->feats[ti->first_rank] << endl; cerr << " j " << TD::GetString(samples->sents[ti->second_rank]) << endl; @@ -393,23 +398,46 @@ main(int argc, char** argv) if (noup) break; + // write weights to file + if (select_weights == "best") { + weights.InitFromVector(lambdas); + string infix = "dtrain-weights-" + boost::lexical_cast<string>(t); + string w_fn = gettmpf(tmp_path, infix, "gz"); + weights.WriteToFile(w_fn, true); + weights_files.push_back(w_fn); + } + } // outer loop - unlink(grammar_buf_fn); + unlink(grammar_buf_fn.c_str()); if (!noup) { - if (!quiet) cerr << endl << "writing weights file '" << cfg["output"].as<string>() << "' ..."; - if (cfg["output"].as<string>() == "-") { - cout << _p9; - for (SparseVector<double>::const_iterator ti = lambdas.begin(); - ti != lambdas.end(); ++ti) { - if (ti->second == 0) continue; - cout << _np << FD::Convert(ti->first) << "\t" << ti->second << endl; + if (!quiet) cerr << endl << "writing weights file to '" << cfg["output"].as<string>() << "' ..."; + if (select_weights == "last") { // last + WriteFile out(cfg["output"].as<string>()); + ostream& o = *out.stream(); + o.precision(17); + o << _np; + for (SparseVector<double>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { + if (it->second == 0) continue; + o << FD::Convert(it->first) << '\t' << it->second << endl; } if (hstreaming) cout << "__SHARD_COUNT__\t1" << endl; - } else if (cfg["output"].as<string>() != "VOID") { - weights.InitFromVector(lambdas); - weights.WriteToFile(cfg["output"].as<string>(), true); + } else { // best + if (cfg["output"].as<string>() != "-") { + CopyFile(weights_files[best_it], cfg["output"].as<string>()); + } else { + ReadFile(weights_files[best_it]); + string o; + cout.precision(17); + cout << _np; + while(getline(*input, o)) cout << o << endl; + } + for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) { + unlink(it->c_str()); + it->erase(it->end()-3, it->end()); + unlink(it->c_str()); + } } if (!quiet) cerr << "done" << endl; } diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h index c1d910aa..34464e3c 100644 --- a/dtrain/dtrain.h +++ b/dtrain/dtrain.h @@ -24,6 +24,20 @@ inline void register_and_convert(const vector<string>& strs, vector<WordID>& ids ids.push_back(TD::Convert(*it)); } +inline string gettmpf(const string path, const string infix, const string suffix="") { + char fn[1024]; + strcpy(fn, path.c_str()); + strcat(fn, "/"); + strcat(fn, infix.c_str()); + strcat(fn, "-XXXXXX"); + mkstemp(fn); + if (suffix != "") { // we will get 2 files + strcat(fn, "."); + strcat(fn, suffix.c_str()); + } + return string(fn); +} + inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); } inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); } inline ostream& _p2(ostream& out) { return out << setprecision(2); } diff --git a/dtrain/hstreaming/red-test b/dtrain/hstreaming/red-test new file mode 100644 index 00000000..b86e7894 --- /dev/null +++ b/dtrain/hstreaming/red-test @@ -0,0 +1,7 @@ +a 1 +b 2 +c 3.5 +a 1 +b 2 +c 3.5 +__SHARD_COUNT__ 2 diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini index 068074c4..c560a3a6 100644 --- a/dtrain/test/example/dtrain.ini +++ b/dtrain/test/example/dtrain.ini @@ -1,12 +1,13 @@ decoder_config=test/example/cdec.ini k=100 N=3 -epochs=3 +epochs=1000 input=test/example/nc-1k.gz scorer=stupid_bleu output=/tmp/weights.gz -stop_after=1000 +stop_after=10 sample_from=kbest pair_sampling=all print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough tmp=/tmp +select_weights=best |