summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/Makefile.am4
-rw-r--r--training/mpi_em_optimize.cc389
2 files changed, 393 insertions, 0 deletions
diff --git a/training/Makefile.am b/training/Makefile.am
index b046c698..5f3b9bc4 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -11,6 +11,7 @@ bin_PROGRAMS = \
cllh_filter_grammar \
mpi_online_optimize \
mpi_batch_optimize \
+ mpi_em_optimize \
augment_grammar
noinst_PROGRAMS = \
@@ -25,6 +26,9 @@ mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval
mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc
mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+mpi_em_optimize_SOURCES = mpi_em_optimize.cc optimize.cc
+mpi_em_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+
if MPI
bin_PROGRAMS += compute_cllh
diff --git a/training/mpi_em_optimize.cc b/training/mpi_em_optimize.cc
new file mode 100644
index 00000000..48683b15
--- /dev/null
+++ b/training/mpi_em_optimize.cc
@@ -0,0 +1,389 @@
+#include <sstream>
+#include <iostream>
+#include <vector>
+#include <cassert>
+#include <cmath>
+
+#ifdef HAVE_MPI
+#include <mpi.h>
+#endif
+
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "verbose.h"
+#include "hg.h"
+#include "prob.h"
+#include "inside_outside.h"
+#include "ff_register.h"
+#include "decoder.h"
+#include "filelib.h"
+#include "optimize.h"
+#include "fdict.h"
+#include "weights.h"
+#include "sparse_vector.h"
+
+using namespace std;
+using boost::shared_ptr;
+namespace po = boost::program_options;
+
+void SanityCheck(const vector<double>& w) {
+ for (int i = 0; i < w.size(); ++i) {
+ assert(!isnan(w[i]));
+ assert(!isinf(w[i]));
+ }
+}
+
+struct FComp {
+ const vector<double>& w_;
+ FComp(const vector<double>& w) : w_(w) {}
+ bool operator()(int a, int b) const {
+ return fabs(w_[a]) > fabs(w_[b]);
+ }
+};
+
+void ShowLargestFeatures(const vector<double>& w) {
+ vector<int> fnums(w.size());
+ for (int i = 0; i < w.size(); ++i)
+ fnums[i] = i;
+ vector<int>::iterator mid = fnums.begin();
+ mid += (w.size() > 10 ? 10 : w.size());
+ partial_sort(fnums.begin(), mid, fnums.end(), FComp(w));
+ cerr << "TOP FEATURES:";
+ for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) {
+ cerr << ' ' << FD::Convert(*i) << '=' << w[*i];
+ }
+ cerr << endl;
+}
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("input_weights,w",po::value<string>(),"Input feature weights file")
+ ("training_data,t",po::value<string>(),"Training data")
+ ("decoder_config,c",po::value<string>(),"Decoder configuration file")
+ ("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || !(conf->count("training_data")) || !conf->count("decoder_config")) {
+ cerr << dcmdline_options << endl;
+#ifdef HAVE_MPI
+ MPI::Finalize();
+#endif
+ exit(1);
+ }
+}
+
+void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) {
+ ReadFile rf(fname);
+ istream& in = *rf.stream();
+ string line;
+ int lc = 0;
+ while(in) {
+ getline(in, line);
+ if (!in) break;
+ if (lc % size == rank) c->push_back(line);
+ ++lc;
+ }
+}
+
+static const double kMINUS_EPSILON = -1e-6;
+
+struct TrainingObserver : public DecoderObserver {
+ void Reset() {
+ total_complete = 0;
+ cur_obj = 0;
+ tot_obj = 0;
+ tot.clear();
+ }
+
+ void SetLocalGradientAndObjective(SparseVector<double>* g, double* o) const {
+ *o = tot_obj;
+ *g = tot;
+ }
+
+ virtual void NotifyDecodingStart(const SentenceMetadata& smeta) {
+ cur_obj = 0;
+ state = 1;
+ }
+
+ void ExtractExpectedCounts(Hypergraph* hg) {
+ vector<prob_t> posts;
+ cur.clear();
+ const prob_t z = hg->ComputeEdgePosteriors(1.0, &posts);
+ cur_obj = log(z);
+ for (int i = 0; i < posts.size(); ++i) {
+ const SparseVector<double>& efeats = hg->edges_[i].feature_values_;
+ const double post = static_cast<double>(posts[i] / z);
+ for (SparseVector<double>::const_iterator j = efeats.begin(); j != efeats.end(); ++j)
+ cur.add_value(j->first, post);
+ }
+ }
+
+ // compute model expectations, denominator of objective
+ virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+ assert(state == 1);
+ state = 2;
+ ExtractExpectedCounts(hg);
+ }
+
+ // replace translation forest, since we're doing EM training (we don't know which)
+ virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+ assert(state == 2);
+ state = 3;
+ ExtractExpectedCounts(hg);
+ }
+
+ virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) {
+ ++total_complete;
+ tot_obj += cur_obj;
+ tot += cur;
+ }
+
+ int total_complete;
+ double cur_obj;
+ double tot_obj;
+ SparseVector<double> cur, tot;
+ int state;
+};
+
+void ReadConfig(const string& ini, vector<string>* out) {
+ ReadFile rf(ini);
+ istream& in = *rf.stream();
+ while(in) {
+ string line;
+ getline(in, line);
+ if (!in) continue;
+ out->push_back(line);
+ }
+}
+
+void StoreConfig(const vector<string>& cfg, istringstream* o) {
+ ostringstream os;
+ for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; }
+ o->str(os.str());
+}
+
+struct OptimizableMultinomialFamily {
+ struct CPD {
+ CPD() : z() {}
+ double z;
+ map<WordID, double> c2counts;
+ };
+ map<WordID, CPD> counts;
+ double Value(WordID conditioning, WordID generated) const {
+ map<WordID, CPD>::const_iterator it = counts.find(conditioning);
+ assert(it != counts.end());
+ map<WordID,double>::const_iterator r = it->second.c2counts.find(generated);
+ if (r == it->second.c2counts.end()) return 0;
+ return r->second;
+ }
+ void Increment(WordID conditioning, WordID generated, double count) {
+ CPD& cc = counts[conditioning];
+ cc.z += count;
+ cc.c2counts[generated] += count;
+ }
+ void Optimize() {
+ for (map<WordID, CPD>::iterator i = counts.begin(); i != counts.end(); ++i) {
+ CPD& cpd = i->second;
+ for (map<WordID, double>::iterator j = cpd.c2counts.begin(); j != cpd.c2counts.end(); ++j) {
+ j->second /= cpd.z;
+ // cerr << "P(" << TD::Convert(j->first) << " | " << TD::Convert(i->first) << " ) = " << j->second << endl;
+ }
+ }
+ }
+ void Clear() {
+ counts.clear();
+ }
+};
+
+struct CountManager {
+ CountManager(size_t num_types) : oms_(num_types) {}
+ virtual ~CountManager();
+ virtual void AddCounts(const SparseVector<double>& c) = 0;
+ void Optimize(SparseVector<double>* weights) {
+ for (int i = 0; i < oms_.size(); ++i) {
+ oms_[i].Optimize();
+ }
+ GetOptimalValues(weights);
+ for (int i = 0; i < oms_.size(); ++i) {
+ oms_[i].Clear();
+ }
+ }
+ virtual void GetOptimalValues(SparseVector<double>* wv) const = 0;
+ vector<OptimizableMultinomialFamily> oms_;
+};
+CountManager::~CountManager() {}
+
+struct TaggerCountManager : public CountManager {
+ // 0 = transitions, 2 = emissions
+ TaggerCountManager() : CountManager(2) {}
+ void AddCounts(const SparseVector<double>& c);
+ void GetOptimalValues(SparseVector<double>* wv) const {
+ for (set<int>::const_iterator it = fids_.begin(); it != fids_.end(); ++it) {
+ int ftype;
+ WordID cond, gen;
+ bool is_optimized = TaggerCountManager::GetFeature(*it, &ftype, &cond, &gen);
+ assert(is_optimized);
+ wv->set_value(*it, log(oms_[ftype].Value(cond, gen)));
+ }
+ }
+ // Id:0:a=1 Bi:a_b=1 Bi:b_c=1 Bi:c_d=1 Uni:a=1 Uni:b=1 Uni:c=1 Uni:d=1 Id:1:b=1 Bi:BOS_a=1 Id:2:c=1
+ static bool GetFeature(const int fid, int* feature_type, WordID* cond, WordID* gen) {
+ const string& feat = FD::Convert(fid);
+ if (feat.size() > 5 && feat[0] == 'I' && feat[1] == 'd' && feat[2] == ':') {
+ // emission
+ const size_t p = feat.rfind(':');
+ assert(p != string::npos);
+ *cond = TD::Convert(feat.substr(p+1));
+ *gen = TD::Convert(feat.substr(3, p - 3));
+ *feature_type = 1;
+ return true;
+ } else if (feat[0] == 'B' && feat.size() > 5 && feat[2] == ':' && feat[1] == 'i') {
+ // transition
+ const size_t p = feat.rfind('_');
+ assert(p != string::npos);
+ *gen = TD::Convert(feat.substr(p+1));
+ *cond = TD::Convert(feat.substr(3, p - 3));
+ *feature_type = 0;
+ return true;
+ } else if (feat[0] == 'U' && feat.size() > 4 && feat[1] == 'n' && feat[2] == 'i' && feat[3] == ':') {
+ // ignore
+ return false;
+ } else {
+ cerr << "Don't know how to deal with feature of type: " << feat << endl;
+ abort();
+ }
+ }
+ set<int> fids_;
+};
+
+void TaggerCountManager::AddCounts(const SparseVector<double>& c) {
+ for (SparseVector<double>::const_iterator it = c.begin(); it != c.end(); ++it) {
+ const double& val = it->second;
+ int ftype;
+ WordID cond, gen;
+ if (GetFeature(it->first, &ftype, &cond, &gen)) {
+ oms_[ftype].Increment(cond, gen, val);
+ fids_.insert(it->first);
+ }
+ }
+}
+
+int main(int argc, char** argv) {
+#ifdef HAVE_MPI
+ MPI::Init(argc, argv);
+ const int size = MPI::COMM_WORLD.Get_size();
+ const int rank = MPI::COMM_WORLD.Get_rank();
+#else
+ const int size = 1;
+ const int rank = 0;
+#endif
+ SetSilent(true); // turn off verbose decoder output
+ register_feature_functions();
+
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+
+ TaggerCountManager tcm;
+
+ // load cdec.ini and set up decoder
+ vector<string> cdec_ini;
+ ReadConfig(conf["decoder_config"].as<string>(), &cdec_ini);
+ istringstream ini;
+ StoreConfig(cdec_ini, &ini);
+ if (rank == 0) cerr << "Loading grammar...\n";
+ Decoder* decoder = new Decoder(&ini);
+ if (decoder->GetConf()["input"].as<string>() != "-") {
+ cerr << "cdec.ini must not set an input file\n";
+#ifdef HAVE_MPI
+ MPI::COMM_WORLD.Abort(1);
+#endif
+ }
+ if (rank == 0) cerr << "Done loading grammar!\n";
+ Weights w;
+ if (conf.count("input_weights"))
+ w.InitFromFile(conf["input_weights"].as<string>());
+
+ double objective = 0;
+ bool converged = false;
+
+ vector<double> lambdas;
+ w.InitVector(&lambdas);
+ vector<string> corpus;
+ ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus);
+ assert(corpus.size() > 0);
+
+ int iteration = 0;
+ TrainingObserver observer;
+ while (!converged) {
+ ++iteration;
+ observer.Reset();
+ if (rank == 0) {
+ cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n";
+ }
+ decoder->SetWeights(lambdas);
+ for (int i = 0; i < corpus.size(); ++i)
+ decoder->Decode(corpus[i], &observer);
+
+ SparseVector<double> x;
+ observer.SetLocalGradientAndObjective(&x, &objective);
+ cerr << "COUNTS = " << x << endl;
+ cerr << " OBJ = " << objective << endl;
+ tcm.AddCounts(x);
+
+#if 0
+#ifdef HAVE_MPI
+ MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0);
+ MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0);
+ swap(gradient, rcv_grad);
+ objective = to;
+#endif
+#endif
+
+ if (rank == 0) {
+ SparseVector<double> wsv;
+ tcm.Optimize(&wsv);
+
+ w.InitFromVector(wsv);
+ w.InitVector(&lambdas);
+
+ ShowLargestFeatures(lambdas);
+
+ converged = iteration > 100;
+ if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; }
+
+ string fname = "weights.cur.gz";
+ if (converged) { fname = "weights.final.gz"; }
+ ostringstream vv;
+ vv << "Objective = " << objective << " (ITERATION=" << iteration << ")";
+ const string svv = vv.str();
+ w.WriteToFile(fname, true, &svv);
+ } // rank == 0
+ int cint = converged;
+#ifdef HAVE_MPI
+ MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0);
+ MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0);
+ MPI::COMM_WORLD.Barrier();
+#endif
+ converged = cint;
+ }
+#ifdef HAVE_MPI
+ MPI::Finalize();
+#endif
+ return 0;
+}