diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-25 19:24:56 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-25 19:24:56 -0400 |
commit | 0dc2ea9f2e6a263e5601df820e513cb443c2a716 (patch) | |
tree | a34282163aa1a3a71b65c5e8669359adc7aa7dd7 | |
parent | afb41a09cc10db8b47047630c8db3148dfa5f648 (diff) |
em optimizer- really crappy implementation
-rw-r--r-- | training/Makefile.am | 4 | ||||
-rw-r--r-- | training/mpi_em_optimize.cc | 389 |
2 files changed, 393 insertions, 0 deletions
diff --git a/training/Makefile.am b/training/Makefile.am index b046c698..5f3b9bc4 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -11,6 +11,7 @@ bin_PROGRAMS = \ cllh_filter_grammar \ mpi_online_optimize \ mpi_batch_optimize \ + mpi_em_optimize \ augment_grammar noinst_PROGRAMS = \ @@ -25,6 +26,9 @@ mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_em_optimize_SOURCES = mpi_em_optimize.cc optimize.cc +mpi_em_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + if MPI bin_PROGRAMS += compute_cllh diff --git a/training/mpi_em_optimize.cc b/training/mpi_em_optimize.cc new file mode 100644 index 00000000..48683b15 --- /dev/null +++ b/training/mpi_em_optimize.cc @@ -0,0 +1,389 @@ +#include <sstream> +#include <iostream> +#include <vector> +#include <cassert> +#include <cmath> + +#ifdef HAVE_MPI +#include <mpi.h> +#endif + +#include <boost/shared_ptr.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "optimize.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" + +using namespace std; +using boost::shared_ptr; +namespace po = boost::program_options; + +void SanityCheck(const vector<double>& w) { + for (int i = 0; i < w.size(); ++i) { + assert(!isnan(w[i])); + assert(!isinf(w[i])); + } +} + +struct FComp { + const vector<double>& w_; + FComp(const vector<double>& w) : w_(w) {} + bool operator()(int a, int b) const { + return fabs(w_[a]) > fabs(w_[b]); + } +}; + +void ShowLargestFeatures(const vector<double>& w) { + vector<int> fnums(w.size()); + for (int i = 0; i < w.size(); ++i) + fnums[i] = i; + vector<int>::iterator mid = fnums.begin(); + mid += (w.size() > 10 ? 10 : w.size()); + partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); + cerr << "TOP FEATURES:"; + for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) { + cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; + } + cerr << endl; +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_weights,w",po::value<string>(),"Input feature weights file") + ("training_data,t",po::value<string>(),"Training data") + ("decoder_config,c",po::value<string>(),"Decoder configuration file") + ("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !(conf->count("training_data")) || !conf->count("decoder_config")) { + cerr << dcmdline_options << endl; +#ifdef HAVE_MPI + MPI::Finalize(); +#endif + exit(1); + } +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + void Reset() { + total_complete = 0; + cur_obj = 0; + tot_obj = 0; + tot.clear(); + } + + void SetLocalGradientAndObjective(SparseVector<double>* g, double* o) const { + *o = tot_obj; + *g = tot; + } + + virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { + cur_obj = 0; + state = 1; + } + + void ExtractExpectedCounts(Hypergraph* hg) { + vector<prob_t> posts; + cur.clear(); + const prob_t z = hg->ComputeEdgePosteriors(1.0, &posts); + cur_obj = log(z); + for (int i = 0; i < posts.size(); ++i) { + const SparseVector<double>& efeats = hg->edges_[i].feature_values_; + const double post = static_cast<double>(posts[i] / z); + for (SparseVector<double>::const_iterator j = efeats.begin(); j != efeats.end(); ++j) + cur.add_value(j->first, post); + } + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 1); + state = 2; + ExtractExpectedCounts(hg); + } + + // replace translation forest, since we're doing EM training (we don't know which) + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + ExtractExpectedCounts(hg); + } + + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { + ++total_complete; + tot_obj += cur_obj; + tot += cur; + } + + int total_complete; + double cur_obj; + double tot_obj; + SparseVector<double> cur, tot; + int state; +}; + +void ReadConfig(const string& ini, vector<string>* out) { + ReadFile rf(ini); + istream& in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (!in) continue; + out->push_back(line); + } +} + +void StoreConfig(const vector<string>& cfg, istringstream* o) { + ostringstream os; + for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; } + o->str(os.str()); +} + +struct OptimizableMultinomialFamily { + struct CPD { + CPD() : z() {} + double z; + map<WordID, double> c2counts; + }; + map<WordID, CPD> counts; + double Value(WordID conditioning, WordID generated) const { + map<WordID, CPD>::const_iterator it = counts.find(conditioning); + assert(it != counts.end()); + map<WordID,double>::const_iterator r = it->second.c2counts.find(generated); + if (r == it->second.c2counts.end()) return 0; + return r->second; + } + void Increment(WordID conditioning, WordID generated, double count) { + CPD& cc = counts[conditioning]; + cc.z += count; + cc.c2counts[generated] += count; + } + void Optimize() { + for (map<WordID, CPD>::iterator i = counts.begin(); i != counts.end(); ++i) { + CPD& cpd = i->second; + for (map<WordID, double>::iterator j = cpd.c2counts.begin(); j != cpd.c2counts.end(); ++j) { + j->second /= cpd.z; + // cerr << "P(" << TD::Convert(j->first) << " | " << TD::Convert(i->first) << " ) = " << j->second << endl; + } + } + } + void Clear() { + counts.clear(); + } +}; + +struct CountManager { + CountManager(size_t num_types) : oms_(num_types) {} + virtual ~CountManager(); + virtual void AddCounts(const SparseVector<double>& c) = 0; + void Optimize(SparseVector<double>* weights) { + for (int i = 0; i < oms_.size(); ++i) { + oms_[i].Optimize(); + } + GetOptimalValues(weights); + for (int i = 0; i < oms_.size(); ++i) { + oms_[i].Clear(); + } + } + virtual void GetOptimalValues(SparseVector<double>* wv) const = 0; + vector<OptimizableMultinomialFamily> oms_; +}; +CountManager::~CountManager() {} + +struct TaggerCountManager : public CountManager { + // 0 = transitions, 2 = emissions + TaggerCountManager() : CountManager(2) {} + void AddCounts(const SparseVector<double>& c); + void GetOptimalValues(SparseVector<double>* wv) const { + for (set<int>::const_iterator it = fids_.begin(); it != fids_.end(); ++it) { + int ftype; + WordID cond, gen; + bool is_optimized = TaggerCountManager::GetFeature(*it, &ftype, &cond, &gen); + assert(is_optimized); + wv->set_value(*it, log(oms_[ftype].Value(cond, gen))); + } + } + // Id:0:a=1 Bi:a_b=1 Bi:b_c=1 Bi:c_d=1 Uni:a=1 Uni:b=1 Uni:c=1 Uni:d=1 Id:1:b=1 Bi:BOS_a=1 Id:2:c=1 + static bool GetFeature(const int fid, int* feature_type, WordID* cond, WordID* gen) { + const string& feat = FD::Convert(fid); + if (feat.size() > 5 && feat[0] == 'I' && feat[1] == 'd' && feat[2] == ':') { + // emission + const size_t p = feat.rfind(':'); + assert(p != string::npos); + *cond = TD::Convert(feat.substr(p+1)); + *gen = TD::Convert(feat.substr(3, p - 3)); + *feature_type = 1; + return true; + } else if (feat[0] == 'B' && feat.size() > 5 && feat[2] == ':' && feat[1] == 'i') { + // transition + const size_t p = feat.rfind('_'); + assert(p != string::npos); + *gen = TD::Convert(feat.substr(p+1)); + *cond = TD::Convert(feat.substr(3, p - 3)); + *feature_type = 0; + return true; + } else if (feat[0] == 'U' && feat.size() > 4 && feat[1] == 'n' && feat[2] == 'i' && feat[3] == ':') { + // ignore + return false; + } else { + cerr << "Don't know how to deal with feature of type: " << feat << endl; + abort(); + } + } + set<int> fids_; +}; + +void TaggerCountManager::AddCounts(const SparseVector<double>& c) { + for (SparseVector<double>::const_iterator it = c.begin(); it != c.end(); ++it) { + const double& val = it->second; + int ftype; + WordID cond, gen; + if (GetFeature(it->first, &ftype, &cond, &gen)) { + oms_[ftype].Increment(cond, gen, val); + fids_.insert(it->first); + } + } +} + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + MPI::Init(argc, argv); + const int size = MPI::COMM_WORLD.Get_size(); + const int rank = MPI::COMM_WORLD.Get_rank(); +#else + const int size = 1; + const int rank = 0; +#endif + SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + TaggerCountManager tcm; + + // load cdec.ini and set up decoder + vector<string> cdec_ini; + ReadConfig(conf["decoder_config"].as<string>(), &cdec_ini); + istringstream ini; + StoreConfig(cdec_ini, &ini); + if (rank == 0) cerr << "Loading grammar...\n"; + Decoder* decoder = new Decoder(&ini); + if (decoder->GetConf()["input"].as<string>() != "-") { + cerr << "cdec.ini must not set an input file\n"; +#ifdef HAVE_MPI + MPI::COMM_WORLD.Abort(1); +#endif + } + if (rank == 0) cerr << "Done loading grammar!\n"; + Weights w; + if (conf.count("input_weights")) + w.InitFromFile(conf["input_weights"].as<string>()); + + double objective = 0; + bool converged = false; + + vector<double> lambdas; + w.InitVector(&lambdas); + vector<string> corpus; + ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus); + assert(corpus.size() > 0); + + int iteration = 0; + TrainingObserver observer; + while (!converged) { + ++iteration; + observer.Reset(); + if (rank == 0) { + cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; + } + decoder->SetWeights(lambdas); + for (int i = 0; i < corpus.size(); ++i) + decoder->Decode(corpus[i], &observer); + + SparseVector<double> x; + observer.SetLocalGradientAndObjective(&x, &objective); + cerr << "COUNTS = " << x << endl; + cerr << " OBJ = " << objective << endl; + tcm.AddCounts(x); + +#if 0 +#ifdef HAVE_MPI + MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0); + MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0); + swap(gradient, rcv_grad); + objective = to; +#endif +#endif + + if (rank == 0) { + SparseVector<double> wsv; + tcm.Optimize(&wsv); + + w.InitFromVector(wsv); + w.InitVector(&lambdas); + + ShowLargestFeatures(lambdas); + + converged = iteration > 100; + if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } + + string fname = "weights.cur.gz"; + if (converged) { fname = "weights.final.gz"; } + ostringstream vv; + vv << "Objective = " << objective << " (ITERATION=" << iteration << ")"; + const string svv = vv.str(); + w.WriteToFile(fname, true, &svv); + } // rank == 0 + int cint = converged; +#ifdef HAVE_MPI + MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0); + MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0); + MPI::COMM_WORLD.Barrier(); +#endif + converged = cint; + } +#ifdef HAVE_MPI + MPI::Finalize(); +#endif + return 0; +} |