-rw-r--r--   training/Makefile.am          4
-rw-r--r--   training/mpi_em_optimize.cc   389
2 files changed, 393 insertions, 0 deletions
diff --git a/training/Makefile.am b/training/Makefile.am
index b046c698..5f3b9bc4 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -11,6 +11,7 @@ bin_PROGRAMS = \
   cllh_filter_grammar \
   mpi_online_optimize \
   mpi_batch_optimize \
+  mpi_em_optimize \
   augment_grammar
 
 noinst_PROGRAMS = \
@@ -25,6 +26,9 @@ mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval
 
 mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc
 mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+mpi_em_optimize_SOURCES = mpi_em_optimize.cc optimize.cc
+mpi_em_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+
 
 if MPI
 bin_PROGRAMS += compute_cllh
diff --git a/training/mpi_em_optimize.cc b/training/mpi_em_optimize.cc
new file mode 100644
index 00000000..48683b15
--- /dev/null
+++ b/training/mpi_em_optimize.cc
@@ -0,0 +1,389 @@
+#include <sstream>
+#include <iostream>
+#include <vector>
+#include <cassert>
+#include <cmath>
+
+#ifdef HAVE_MPI
+#include <mpi.h>
+#endif
+
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "verbose.h"
+#include "hg.h"
+#include "prob.h"
+#include "inside_outside.h"
+#include "ff_register.h"
+#include "decoder.h"
+#include "filelib.h"
+#include "optimize.h"
+#include "fdict.h"
+#include "weights.h"
+#include "sparse_vector.h"
+
+using namespace std;
+using boost::shared_ptr;
+namespace po = boost::program_options;
+
+void SanityCheck(const vector<double>& w) {
+  for (int i = 0; i < w.size(); ++i) {
+    assert(!isnan(w[i]));
+    assert(!isinf(w[i]));
+  }
+}
+
+struct FComp {
+  const vector<double>& w_;
+  FComp(const vector<double>& w) : w_(w) {}
+  bool operator()(int a, int b) const {
+    return fabs(w_[a]) > fabs(w_[b]);
+  }
+};
+
+void ShowLargestFeatures(const vector<double>& w) {
+  vector<int> fnums(w.size());
+  for (int i = 0; i < w.size(); ++i)
+    fnums[i] = i;
+  vector<int>::iterator mid = fnums.begin();
+  mid += (w.size() > 10 ? 10 : w.size());
+  partial_sort(fnums.begin(), mid, fnums.end(), FComp(w));
+  cerr << "TOP FEATURES:";
+  for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) {
+    cerr << ' ' << FD::Convert(*i) << '=' << w[*i];
+  }
+  cerr << endl;
+}
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+  po::options_description opts("Configuration options");
+  opts.add_options()
+        ("input_weights,w",po::value<string>(),"Input feature weights file")
+        ("training_data,t",po::value<string>(),"Training data")
+        ("decoder_config,c",po::value<string>(),"Decoder configuration file")
+        ("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file");
+  po::options_description clo("Command line options");
+  clo.add_options()
+        ("config", po::value<string>(), "Configuration file")
+        ("help,h", "Print this help message and exit");
+  po::options_description dconfig_options, dcmdline_options;
+  dconfig_options.add(opts);
+  dcmdline_options.add(opts).add(clo);
+
+  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+  if (conf->count("config")) {
+    ifstream config((*conf)["config"].as<string>().c_str());
+    po::store(po::parse_config_file(config, dconfig_options), *conf);
+  }
+  po::notify(*conf);
+
+  if (conf->count("help") || !(conf->count("training_data")) || !conf->count("decoder_config")) {
+    cerr << dcmdline_options << endl;
+#ifdef HAVE_MPI
+    MPI::Finalize();
+#endif
+    exit(1);
+  }
+}
+
+void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) {
+  ReadFile rf(fname);
+  istream& in = *rf.stream();
+  string line;
+  int lc = 0;
+  while(in) {
+    getline(in, line);
+    if (!in) break;
+    if (lc % size == rank) c->push_back(line);
+    ++lc;
+  }
+}
+
+static const double kMINUS_EPSILON = -1e-6;
+
+struct TrainingObserver : public DecoderObserver {
+  void Reset() {
+    total_complete = 0;
+    cur_obj = 0;
+    tot_obj = 0;
+    tot.clear();
+  }
+
+  void SetLocalGradientAndObjective(SparseVector<double>* g, double* o) const {
+    *o = tot_obj;
+    *g = tot;
+  }
+
+  virtual void NotifyDecodingStart(const SentenceMetadata& smeta) {
+    cur_obj = 0;
+    state = 1;
+  }
+
+  void ExtractExpectedCounts(Hypergraph* hg) {
+    vector<prob_t> posts;
+    cur.clear();
+    const prob_t z = hg->ComputeEdgePosteriors(1.0, &posts);
+    cur_obj = log(z);
+    for (int i = 0; i < posts.size(); ++i) {
+      const SparseVector<double>& efeats = hg->edges_[i].feature_values_;
+      const double post = static_cast<double>(posts[i] / z);
+      for (SparseVector<double>::const_iterator j = efeats.begin(); j != efeats.end(); ++j)
+        cur.add_value(j->first, post);
+    }
+  }
+
+  // compute model expectations, denominator of objective
+  virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+    assert(state == 1);
+    state = 2;
+    ExtractExpectedCounts(hg);
+  }
+
+  // replace translation forest, since we're doing EM training (we don't know which)
+  virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+    assert(state == 2);
+    state = 3;
+    ExtractExpectedCounts(hg);
+  }
+
+  virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) {
+    ++total_complete;
+    tot_obj += cur_obj;
+    tot += cur;
+  }
+
+  int total_complete;
+  double cur_obj;
+  double tot_obj;
+  SparseVector<double> cur, tot;
+  int state;
+};
+
+void ReadConfig(const string& ini, vector<string>* out) {
+  ReadFile rf(ini);
+  istream& in = *rf.stream();
+  while(in) {
+    string line;
+    getline(in, line);
+    if (!in) continue;
+    out->push_back(line);
+  }
+}
+
+void StoreConfig(const vector<string>& cfg, istringstream* o) {
+  ostringstream os;
+  for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; }
+  o->str(os.str());
+}
+
+struct OptimizableMultinomialFamily {
+  struct CPD {
+    CPD() : z() {}
+    double z;
+    map<WordID, double> c2counts;
+  };
+  map<WordID, CPD> counts;
+  double Value(WordID conditioning, WordID generated) const {
+    map<WordID, CPD>::const_iterator it = counts.find(conditioning);
+    assert(it != counts.end());
+    map<WordID,double>::const_iterator r = it->second.c2counts.find(generated);
+    if (r == it->second.c2counts.end()) return 0;
+    return r->second;
+  }
+  void Increment(WordID conditioning, WordID generated, double count) {
+    CPD& cc = counts[conditioning];
+    cc.z += count;
+    cc.c2counts[generated] += count;
+  }
+  void Optimize() {
+    for (map<WordID, CPD>::iterator i = counts.begin(); i != counts.end(); ++i) {
+      CPD& cpd = i->second;
+      for (map<WordID, double>::iterator j = cpd.c2counts.begin(); j != cpd.c2counts.end(); ++j) {
+        j->second /= cpd.z;
+        // cerr << "P(" << TD::Convert(j->first) << " | " << TD::Convert(i->first) << " ) =  " << j->second << endl;
+      }
+    }
+  }
+  void Clear() {
+    counts.clear();
+  }
+};
+
+struct CountManager {
+  CountManager(size_t num_types) : oms_(num_types) {}
+  virtual ~CountManager();
+  virtual void AddCounts(const SparseVector<double>& c) = 0;
+  void Optimize(SparseVector<double>* weights) {
+    for (int i = 0; i < oms_.size(); ++i) {
+      oms_[i].Optimize();
+    }
+    GetOptimalValues(weights);
+    for (int i = 0; i < oms_.size(); ++i) {
+      oms_[i].Clear();
+    }
+  }
+  virtual void GetOptimalValues(SparseVector<double>* wv) const = 0;
+  vector<OptimizableMultinomialFamily> oms_;
+};
+CountManager::~CountManager() {}
+
+struct TaggerCountManager : public CountManager {
+  // 0 = transitions, 2 = emissions
+  TaggerCountManager() : CountManager(2) {}
+  void AddCounts(const SparseVector<double>& c);
+  void GetOptimalValues(SparseVector<double>* wv) const {
+    for (set<int>::const_iterator it = fids_.begin(); it != fids_.end(); ++it) {
+      int ftype;
+      WordID cond, gen;
+      bool is_optimized = TaggerCountManager::GetFeature(*it, &ftype, &cond, &gen);
+      assert(is_optimized);
+      wv->set_value(*it, log(oms_[ftype].Value(cond, gen)));
+    }
+  }
+  // Id:0:a=1 Bi:a_b=1 Bi:b_c=1 Bi:c_d=1 Uni:a=1 Uni:b=1 Uni:c=1 Uni:d=1 Id:1:b=1 Bi:BOS_a=1 Id:2:c=1
+  static bool GetFeature(const int fid, int* feature_type, WordID* cond, WordID* gen) {
+    const string& feat = FD::Convert(fid);
+    if (feat.size() > 5 && feat[0] == 'I' && feat[1] == 'd' && feat[2] == ':') {
+      // emission
+      const size_t p = feat.rfind(':');
+      assert(p != string::npos);
+      *cond = TD::Convert(feat.substr(p+1));
+      *gen = TD::Convert(feat.substr(3, p - 3));
+      *feature_type = 1;
+      return true;
+    } else if (feat[0] == 'B' && feat.size() > 5 && feat[2] == ':' && feat[1] == 'i') {
+      // transition
+      const size_t p = feat.rfind('_');
+      assert(p != string::npos);
+      *gen = TD::Convert(feat.substr(p+1));
+      *cond = TD::Convert(feat.substr(3, p - 3));
+      *feature_type = 0;
+      return true;
+    } else if (feat[0] == 'U' && feat.size() > 4 && feat[1] == 'n' && feat[2] == 'i' && feat[3] == ':') {
+      // ignore
+      return false;
+    } else {
+      cerr << "Don't know how to deal with feature of type: " << feat << endl;
+      abort();
+    }
+  }
+  set<int> fids_;
+};
+
+void TaggerCountManager::AddCounts(const SparseVector<double>& c) {
+  for (SparseVector<double>::const_iterator it = c.begin(); it != c.end(); ++it) {
+    const double& val = it->second;
+    int ftype;
+    WordID cond, gen;
+    if (GetFeature(it->first, &ftype, &cond, &gen)) {
+      oms_[ftype].Increment(cond, gen, val);
+      fids_.insert(it->first);
+    }
+  }
+}
+
+int main(int argc, char** argv) {
+#ifdef HAVE_MPI
+  MPI::Init(argc, argv);
+  const int size = MPI::COMM_WORLD.Get_size();
+  const int rank = MPI::COMM_WORLD.Get_rank();
+#else
+  const int size = 1;
+  const int rank = 0;
+#endif
+  SetSilent(true);  // turn off verbose decoder output
+  register_feature_functions();
+
+  po::variables_map conf;
+  InitCommandLine(argc, argv, &conf);
+
+  TaggerCountManager tcm;
+
+  // load cdec.ini and set up decoder
+  vector<string> cdec_ini;
+  ReadConfig(conf["decoder_config"].as<string>(), &cdec_ini);
+  istringstream ini;
+  StoreConfig(cdec_ini, &ini);
+  if (rank == 0) cerr << "Loading grammar...\n";
+  Decoder* decoder = new Decoder(&ini);
+  if (decoder->GetConf()["input"].as<string>() != "-") {
+    cerr << "cdec.ini must not set an input file\n";
+#ifdef HAVE_MPI
+    MPI::COMM_WORLD.Abort(1);
+#endif
+  }
+  if (rank == 0) cerr << "Done loading grammar!\n";
+  Weights w;
+  if (conf.count("input_weights"))
+    w.InitFromFile(conf["input_weights"].as<string>());
+
+  double objective = 0;
+  bool converged = false;
+
+  vector<double> lambdas;
+  w.InitVector(&lambdas);
+  vector<string> corpus;
+  ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus);
+  assert(corpus.size() > 0);
+
+  int iteration = 0;
+  TrainingObserver observer;
+  while (!converged) {
+    ++iteration;
+    observer.Reset();
+    if (rank == 0) {
+      cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n";
+    }
+    decoder->SetWeights(lambdas);
+    for (int i = 0; i < corpus.size(); ++i)
+      decoder->Decode(corpus[i], &observer);
+
+    SparseVector<double> x;
+    observer.SetLocalGradientAndObjective(&x, &objective);
+    cerr << "COUNTS = " << x << endl;
+    cerr << "   OBJ = " << objective << endl;
+    tcm.AddCounts(x);
+
+#if 0
+#ifdef HAVE_MPI
+    MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0);
+    MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0);
+    swap(gradient, rcv_grad);
+    objective = to;
+#endif
+#endif
+
+    if (rank == 0) {
+      SparseVector<double> wsv;
+      tcm.Optimize(&wsv);
+
+      w.InitFromVector(wsv);
+      w.InitVector(&lambdas);
+
+      ShowLargestFeatures(lambdas);
+
+      converged = iteration > 100;
+      if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; }
+
+      string fname = "weights.cur.gz";
+      if (converged) { fname = "weights.final.gz"; }
+      ostringstream vv;
+      vv << "Objective = " << objective << "  (ITERATION=" << iteration << ")";
+      const string svv = vv.str();
+      w.WriteToFile(fname, true, &svv);
+    }  // rank == 0
+    int cint = converged;
+#ifdef HAVE_MPI
+    MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0);
+    MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0);
+    MPI::COMM_WORLD.Barrier();
+#endif
+    converged = cint;
+  }
+#ifdef HAVE_MPI
+  MPI::Finalize();
+#endif
+  return 0;
+}
