diff options
Diffstat (limited to 'dtrain')
| -rw-r--r-- | dtrain/dtrain.cc | 82 | ||||
| -rw-r--r-- | dtrain/dtrain.h | 14 | ||||
| -rw-r--r-- | dtrain/hstreaming/red-test | 7 | ||||
| -rw-r--r-- | dtrain/test/example/dtrain.ini | 5 | 
4 files changed, 79 insertions, 29 deletions
| diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 9b1bbe68..69f83633 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -23,6 +23,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)      ("learning_rate",  po::value<double>()->default_value(0.0005),                      "learning rate")      ("gamma",          po::value<double>()->default_value(0.),       "gamma for SVM (0 for perceptron)")      ("tmp",            po::value<string>()->default_value("/tmp"),                    "temp dir to use") +    ("select_weights", po::value<string>()->default_value("last"),    "output 'best' or 'last' weights")      ("noup",           po::value<bool>()->zero_tokens(),                        "do not update weights");    po::options_description cl("Command Line Options");    cl.add_options() @@ -56,6 +57,10 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)         && (*cfg)["sample_from"].as<string>() != "forest") {      cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;    } +  if ((*cfg)["select_weights"].as<string>() != "last" +       && (*cfg)["select_weights"].as<string>() != "best") { +    cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl; +  }    return true;  } @@ -83,6 +88,7 @@ main(int argc, char** argv)    const string filter_type = cfg["filter"].as<string>();    const string sample_from = cfg["sample_from"].as<string>();    const string pair_sampling = cfg["pair_sampling"].as<string>(); +  const string select_weights = cfg["select_weights"].as<string>();    vector<string> print_weights;    if (cfg.count("print_weights"))      boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); @@ -140,13 +146,11 @@ main(int argc, char** argv)      // buffer input for t > 0    vector<string> src_str_buf;          // source strings    vector<vector<WordID> > ref_ids_buf; // references as WordID vecs -  // this is for writing the grammar buffer file -  char grammar_buf_fn[1024]; -  strcpy(grammar_buf_fn, cfg["tmp"].as<string>().c_str()); -  strcat(grammar_buf_fn, "/dtrain-grammars-XXXXXX"); -  mkstemp(grammar_buf_fn); +  vector<string> weights_files;        // remember weights for each iteration +  string tmp_path = cfg["tmp"].as<string>(); +  string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars");    ogzstream grammar_buf_out; -  grammar_buf_out.open(grammar_buf_fn); +  grammar_buf_out.open(grammar_buf_fn.c_str());    unsigned in_sz = 999999999; // input index, input size    vector<pair<score_t,score_t> > all_scores; @@ -173,6 +177,7 @@ main(int argc, char** argv)      cerr << setw(25) << "gamma " << gamma << endl;      cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;      cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; +    cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;      if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;    } @@ -183,7 +188,7 @@ main(int argc, char** argv)    time_t start, end;      time(&start);    igzstream grammar_buf_in; -  if (t > 0) grammar_buf_in.open(grammar_buf_fn); +  if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str());    score_t score_sum = 0., model_sum = 0.;    unsigned ii = 0;    if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl; @@ -221,7 +226,7 @@ main(int argc, char** argv)          }        }      } -     +         // next iteration      if (next || stop) break; @@ -244,8 +249,8 @@ main(int argc, char** argv)        ref_ids_buf.push_back(ref_ids);        // process and set grammar        bool broken_grammar = true; -      for (string::iterator ti = in_split[3].begin(); ti != in_split[3].end(); ti++) { -        if (!isspace(*ti)) { +      for (string::iterator it = in_split[3].begin(); it != in_split[3].end(); it++) { +        if (!isspace(*it)) {            broken_grammar = false;            break;          } @@ -300,14 +305,14 @@ main(int argc, char** argv)        if (pair_sampling == "108010")          sample108010(samples, pairs); -      for (vector<pair<ScoredHyp,ScoredHyp> >::iterator ti = pairs.begin(); -            ti != pairs.end(); ti++) { +      for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); +           it != pairs.end(); it++) {          SparseVector<double> dv; -        if (ti->first.score - ti->second.score < 0) { -          dv = ti->second.f - ti->first.f; +        if (it->first.score - it->second.score < 0) { +          dv = it->second.f - it->first.f;        //} else { -        //dv = ti->first - ti->second; +        //dv = it->first - it->second;        //}            dv.add_value(FD::Convert("__bias"), -1); @@ -317,7 +322,7 @@ main(int argc, char** argv)            lambdas += dv * eta;            if (verbose) { -            /*cerr << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl; +            /*cerr << "{{ f("<< it->first_rank <<") > f(" << it->second_rank << ") but g(i)="<< it->first_score <<" < g(j)="<< it->second_score << " so update" << endl;              cerr << " i  " << TD::GetString(samples->sents[ti->first_rank]) << endl;              cerr << "    " << samples->feats[ti->first_rank] << endl;              cerr << " j  " << TD::GetString(samples->sents[ti->second_rank]) << endl; @@ -393,23 +398,46 @@ main(int argc, char** argv)    if (noup) break; +  // write weights to file +  if (select_weights == "best") { +    weights.InitFromVector(lambdas); +    string infix = "dtrain-weights-" + boost::lexical_cast<string>(t); +    string w_fn = gettmpf(tmp_path, infix, "gz"); +    weights.WriteToFile(w_fn, true);  +    weights_files.push_back(w_fn); +  } +    } // outer loop -  unlink(grammar_buf_fn); +  unlink(grammar_buf_fn.c_str());    if (!noup) { -    if (!quiet) cerr << endl << "writing weights file '" << cfg["output"].as<string>() << "' ..."; -    if (cfg["output"].as<string>() == "-") { -      cout << _p9; -      for (SparseVector<double>::const_iterator ti = lambdas.begin(); -            ti != lambdas.end(); ++ti) { -	if (ti->second == 0) continue; -        cout << _np << FD::Convert(ti->first) << "\t" << ti->second << endl; +    if (!quiet) cerr << endl << "writing weights file to '" << cfg["output"].as<string>() << "' ..."; +    if (select_weights == "last") { // last +      WriteFile out(cfg["output"].as<string>()); +      ostream& o = *out.stream(); +      o.precision(17); +      o << _np; +      for (SparseVector<double>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { +	    if (it->second == 0) continue; +        o << FD::Convert(it->first) << '\t' << it->second << endl;        }        if (hstreaming) cout << "__SHARD_COUNT__\t1" << endl; -    } else if (cfg["output"].as<string>() != "VOID") { -      weights.InitFromVector(lambdas); -      weights.WriteToFile(cfg["output"].as<string>(), true); +    } else { // best +      if (cfg["output"].as<string>() != "-") { +        CopyFile(weights_files[best_it], cfg["output"].as<string>());  +      } else { +        ReadFile(weights_files[best_it]); +        string o; +        cout.precision(17); +        cout << _np; +        while(getline(*input, o)) cout << o << endl;  +      } +      for (vector<string>::iterator it = weights_files.begin(); it != weights_files.end(); ++it) { +        unlink(it->c_str()); +        it->erase(it->end()-3, it->end()); +        unlink(it->c_str()); +      }      }      if (!quiet) cerr << "done" << endl;    } diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h index c1d910aa..34464e3c 100644 --- a/dtrain/dtrain.h +++ b/dtrain/dtrain.h @@ -24,6 +24,20 @@ inline void register_and_convert(const vector<string>& strs, vector<WordID>& ids      ids.push_back(TD::Convert(*it));  } +inline string gettmpf(const string path, const string infix, const string suffix="") { +  char fn[1024]; +  strcpy(fn, path.c_str()); +  strcat(fn, "/"); +  strcat(fn, infix.c_str()); +  strcat(fn, "-XXXXXX"); +  mkstemp(fn); +  if (suffix != "") { // we will get 2 files +    strcat(fn, "."); +    strcat(fn, suffix.c_str()); +  } +  return string(fn); +} +  inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); }  inline ostream& _p(ostream& out)  { return out << setiosflags(ios::showpos); }  inline ostream& _p2(ostream& out) { return out << setprecision(2); } diff --git a/dtrain/hstreaming/red-test b/dtrain/hstreaming/red-test new file mode 100644 index 00000000..b86e7894 --- /dev/null +++ b/dtrain/hstreaming/red-test @@ -0,0 +1,7 @@ +a	1 +b	2 +c	3.5 +a	1 +b	2 +c	3.5 +__SHARD_COUNT__	2 diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini index 068074c4..c560a3a6 100644 --- a/dtrain/test/example/dtrain.ini +++ b/dtrain/test/example/dtrain.ini @@ -1,12 +1,13 @@  decoder_config=test/example/cdec.ini  k=100  N=3 -epochs=3 +epochs=1000  input=test/example/nc-1k.gz  scorer=stupid_bleu  output=/tmp/weights.gz -stop_after=1000 +stop_after=10  sample_from=kbest  pair_sampling=all  print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough  tmp=/tmp +select_weights=best | 
