From 8d51973c21337a1633e559cd09a649265600cc4c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 15 Apr 2012 17:28:08 -0400 Subject: crf training of arc-factored dep parser --- rst_parser/Makefile.am | 6 +- rst_parser/arc_factored.cc | 29 +++-- rst_parser/arc_factored.h | 53 ++++++---- rst_parser/arc_factored_marginals.cc | 10 +- rst_parser/arc_ff.cc | 64 +++++++++++ rst_parser/arc_ff.h | 43 ++++++++ rst_parser/mst_train.cc | 200 ++++++++++++++++++++++++++++++++++- rst_parser/rst_test.cc | 18 ++-- 8 files changed, 379 insertions(+), 44 deletions(-) create mode 100644 rst_parser/arc_ff.cc create mode 100644 rst_parser/arc_ff.h (limited to 'rst_parser') diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index b61a20dd..2b64b43a 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -8,12 +8,12 @@ TESTS = rst_test noinst_LIBRARIES = librst.a -librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc +librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc arc_ff.cc mst_train_SOURCES = mst_train.cc -mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a ../training/optimize.o -lz rst_test_SOURCES = rst_test.cc rst_test_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/training -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc index b2c2c427..44e769b8 100644 --- a/rst_parser/arc_factored.cc +++ b/rst_parser/arc_factored.cc @@ -6,23 +6,38 @@ #include #include +#include "arc_ff.h" + using namespace std; using namespace std::tr1; using namespace boost; +void ArcFactoredForest::ExtractFeatures(const TaggedSentence& sentence, + const std::vector >& ffs) { + for (int i = 0; i < ffs.size(); ++i) { + const ArcFeatureFunction& ff = *ffs[i]; + for (int m = 0; m < num_words_; ++m) { + for (int h = 0; h < num_words_; ++h) { + ff.EgdeFeatures(sentence, h, m, &edges_(h,m).features); + } + ff.EgdeFeatures(sentence, -1, m, &root_edges_[m].features); + } + } +} + void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const { - for (int m = 1; m <= num_words_; ++m) { - int best_head = -1; + for (int m = 0; m < num_words_; ++m) { + int best_head = -2; prob_t best_score; - for (int h = 0; h <= num_words_; ++h) { + for (int h = -1; h < num_words_; ++h) { const Edge& edge = (*this)(h,m); - if (best_head < 0 || edge.edge_prob > best_score) { + if (best_head < -1 || edge.edge_prob > best_score) { best_score = edge.edge_prob; best_head = h; } } - assert(best_head >= 0); - if (best_head) + assert(best_head >= -1); + if (best_head >= 0) st->h_m_pairs.push_back(make_pair(best_head, m)); else st->roots.push_back(m); @@ -56,7 +71,7 @@ struct PriorityQueue { }; // based on Trajan 1977 -void ArcFactoredForest::MaximumEdgeSubset(EdgeSubset* st) const { +void ArcFactoredForest::MaximumSpanningTree(EdgeSubset* st) const { typedef disjoint_sets_with_storage DisjointSet; DisjointSet strongly(num_words_ + 1); diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index 3003a86e..a95f8230 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -5,37 +5,52 @@ #include #include #include +#include #include "array2d.h" #include "sparse_vector.h" #include "prob.h" #include "weights.h" +#include "wordid.h" + +struct TaggedSentence { + std::vector words; + std::vector pos; +}; struct EdgeSubset { EdgeSubset() {} std::vector roots; // unless multiroot trees are supported, this // will have a single member - std::vector > h_m_pairs; // h,m start at *1* + std::vector > h_m_pairs; // h,m start at 0 }; +struct ArcFeatureFunction; class ArcFactoredForest { public: - explicit ArcFactoredForest(short num_words) : - num_words_(num_words), - root_edges_(num_words), - edges_(num_words, num_words) { + ArcFactoredForest() : num_words_() {} + explicit ArcFactoredForest(short num_words) { + resize(num_words); + } + + void resize(unsigned num_words) { + num_words_ = num_words; + root_edges_.clear(); + edges_.clear(); + root_edges_.resize(num_words); + edges_.resize(num_words, num_words); for (int h = 0; h < num_words; ++h) { for (int m = 0; m < num_words; ++m) { - edges_(h, m).h = h + 1; - edges_(h, m).m = m + 1; + edges_(h, m).h = h; + edges_(h, m).m = m; } - root_edges_[h].h = 0; - root_edges_[h].m = h + 1; + root_edges_[h].h = -1; + root_edges_[h].m = h; } } // compute the maximum spanning tree based on the current weighting // using the O(n^2) CLE algorithm - void MaximumEdgeSubset(EdgeSubset* st) const; + void MaximumSpanningTree(EdgeSubset* st) const; // Reweight edges so that edge_prob is the edge's marginals // optionally returns log partition @@ -52,20 +67,16 @@ class ArcFactoredForest { prob_t edge_prob; }; + // set eges_[*].features + void ExtractFeatures(const TaggedSentence& sentence, + const std::vector >& ffs); + const Edge& operator()(short h, short m) const { - assert(m > 0); - assert(m <= num_words_); - assert(h >= 0); - assert(h <= num_words_); - return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; + return h >= 0 ? edges_(h, m) : root_edges_[m]; } Edge& operator()(short h, short m) { - assert(m > 0); - assert(m <= num_words_); - assert(h >= 0); - assert(h <= num_words_); - return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; + return h >= 0 ? edges_(h, m) : root_edges_[m]; } float Weight(short h, short m) const { @@ -87,7 +98,7 @@ class ArcFactoredForest { } private: - unsigned num_words_; + int num_words_; std::vector root_edges_; Array2D edges_; }; diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc index 9851b59a..16360b0d 100644 --- a/rst_parser/arc_factored_marginals.cc +++ b/rst_parser/arc_factored_marginals.cc @@ -31,14 +31,18 @@ void ArcFactoredForest::EdgeMarginals(double *plog_z) { ArcMatrix Linv = L.inverse(); if (plog_z) *plog_z = log(Linv.determinant()); RootVector rootMarginals = r.cwiseProduct(Linv.col(0)); +// ArcMatrix T = Linv; for (int h = 0; h < num_words_; ++h) { for (int m = 0; m < num_words_; ++m) { - edges_(h,m).edge_prob = prob_t((m == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,m) - - (h == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,h)); + const double marginal = (m == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,m) - + (h == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,h); + edges_(h,m).edge_prob = prob_t(marginal); +// T(h,m) = marginal; } root_edges_[h].edge_prob = prob_t(rootMarginals(h)); } - // cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl; +// cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl; +// cerr << "M:\n" << T << endl; } #else diff --git a/rst_parser/arc_ff.cc b/rst_parser/arc_ff.cc new file mode 100644 index 00000000..f9effbda --- /dev/null +++ b/rst_parser/arc_ff.cc @@ -0,0 +1,64 @@ +#include "arc_ff.h" + +#include "tdict.h" +#include "fdict.h" +#include "sentence_metadata.h" + +using namespace std; + +ArcFeatureFunction::~ArcFeatureFunction() {} + +void ArcFeatureFunction::PrepareForInput(const TaggedSentence&) {} + +DistancePenalty::DistancePenalty(const string&) : fidw_(FD::Convert("Distance")), fidr_(FD::Convert("RootDistance")) {} + +void DistancePenalty::EdgeFeaturesImpl(const TaggedSentence& sent, + short h, + short m, + SparseVector* features) const { + const bool dir = m < h; + const bool is_root = (h == -1); + int v = m - h; + if (v < 0) { + v= -1 - int(log(-v) / log(2)); + } else { + v= int(log(v) / log(2)); + } + static map lenmap; + int& lenfid = lenmap[v]; + if (!lenfid) { + ostringstream os; + if (v < 0) os << "LenL" << -v; else os << "LenR" << v; + lenfid = FD::Convert(os.str()); + } + features->set_value(lenfid, 1.0); + const string& lenstr = FD::Convert(lenfid); + if (!is_root) { + static int modl = FD::Convert("ModLeft"); + static int modr = FD::Convert("ModRight"); + if (dir) features->set_value(modl, 1); + else features->set_value(modr, 1); + } + if (is_root) { + ostringstream os; + os << "ROOT:" << TD::Convert(sent.pos[m]); + features->set_value(FD::Convert(os.str()), 1.0); + os << "_" << lenstr; + features->set_value(FD::Convert(os.str()), 1.0); + } else { // not root + ostringstream os; + os << "HM:" << TD::Convert(sent.pos[h]) << '_' << TD::Convert(sent.pos[m]); + features->set_value(FD::Convert(os.str()), 1.0); + os << '_' << dir; + features->set_value(FD::Convert(os.str()), 1.0); + os << '_' << lenstr; + features->set_value(FD::Convert(os.str()), 1.0); + ostringstream os2; + os2 << "LexHM:" << TD::Convert(sent.words[h]) << '_' << TD::Convert(sent.words[m]); + features->set_value(FD::Convert(os2.str()), 1.0); + os2 << '_' << dir; + features->set_value(FD::Convert(os2.str()), 1.0); + os2 << '_' << lenstr; + features->set_value(FD::Convert(os2.str()), 1.0); + } +} diff --git a/rst_parser/arc_ff.h b/rst_parser/arc_ff.h new file mode 100644 index 00000000..bc51fef4 --- /dev/null +++ b/rst_parser/arc_ff.h @@ -0,0 +1,43 @@ +#ifndef _ARC_FF_H_ +#define _ARC_FF_H_ + +#include +#include "sparse_vector.h" +#include "weights.h" +#include "arc_factored.h" + +struct TaggedSentence; +class ArcFeatureFunction { + public: + virtual ~ArcFeatureFunction(); + + // called once, per input, before any calls to EdgeFeatures + // used to initialize sentence-specific data structures + virtual void PrepareForInput(const TaggedSentence& sentence); + + inline void EgdeFeatures(const TaggedSentence& sentence, + short h, + short m, + SparseVector* features) const { + EdgeFeaturesImpl(sentence, h, m, features); + } + protected: + virtual void EdgeFeaturesImpl(const TaggedSentence& sentence, + short h, + short m, + SparseVector* features) const = 0; +}; + +class DistancePenalty : public ArcFeatureFunction { + public: + DistancePenalty(const std::string& param); + protected: + virtual void EdgeFeaturesImpl(const TaggedSentence& sentence, + short h, + short m, + SparseVector* features) const; + private: + const int fidw_, fidr_; +}; + +#endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index 7b5af4c1..def23edb 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -1,12 +1,210 @@ #include "arc_factored.h" +#include #include +#include +#include + +#include "arc_ff.h" +#include "arc_ff_factory.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "picojson.h" +#include "optimize.h" +#include "weights.h" using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + string cfg_file; + opts.add_options() + ("training_data,t",po::value()->default_value("-"), "File containing training data (jsent format)") + ("feature_function,F",po::value >()->composing(), "feature function") + ("regularization_strength,C",po::value()->default_value(1.0), "Regularization strength") + ("correction_buffers,m", po::value()->default_value(10), "LBFGS correction buffers"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value(&cfg_file), "Configuration file") + ("help,?", "Print this help message and exit"); + + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(dconfig_options).add(clo); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (cfg_file.size() > 0) { + ReadFile rf(cfg_file); + po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); + } + if (conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +struct TrainingInstance { + TaggedSentence ts; + EdgeSubset tree; + SparseVector features; +}; + +void ReadTraining(const string& fname, vector* corpus, int rank = 0, int size = 1) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + string err; + int lc = 0; + bool flag = false; + while(getline(in, line)) { + ++lc; + if ((lc-1) % size != rank) continue; + if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; } + if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; } + size_t pos = line.rfind('\t'); + assert(pos != string::npos); + picojson::value obj; + picojson::parse(obj, line.begin() + pos, line.end(), &err); + if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } + corpus->push_back(TrainingInstance()); + TrainingInstance& cur = corpus->back(); + TaggedSentence& ts = cur.ts; + EdgeSubset& tree = cur.tree; + assert(obj.is()); + const picojson::object& d = obj.get(); + const picojson::array& ta = d.find("tokens")->second.get(); + for (unsigned i = 0; i < ta.size(); ++i) { + ts.words.push_back(TD::Convert(ta[i].get()[0].get())); + ts.pos.push_back(TD::Convert(ta[i].get()[1].get())); + } + const picojson::array& da = d.find("deps")->second.get(); + for (unsigned i = 0; i < da.size(); ++i) { + const picojson::array& thm = da[i].get(); + // get dep type here + short h = thm[2].get(); + short m = thm[1].get(); + if (h < 0) + tree.roots.push_back(m); + else + tree.h_m_pairs.push_back(make_pair(h,m)); + } + //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; + } + if (flag) cerr << "\nRead " << lc << " training instances\n"; +} + +void AddFeatures(double prob, const SparseVector& fmap, vector* g) { + for (SparseVector::const_iterator it = fmap.begin(); it != fmap.end(); ++it) + (*g)[it->first] += it->second * prob; +} + +double ApplyRegularizationTerms(const double C, + const vector& weights, + vector* g) { + assert(weights.size() == g->size()); + double reg = 0; + for (size_t i = 0; i < weights.size(); ++i) { +// const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); + const double& w_i = weights[i]; + double& g_i = (*g)[i]; + reg += C * w_i * w_i; + g_i += 2 * C * w_i; + +// reg += T * (w_i - prev_w_i) * (w_i - prev_w_i); +// g_i += 2 * T * (w_i - prev_w_i); + } + return reg; +} int main(int argc, char** argv) { + int rank = 0; + int size = 1; + po::variables_map conf; + InitCommandLine(argc, argv, &conf); ArcFactoredForest af(5); - cerr << af(0,3) << endl; + ArcFFRegistry reg; + reg.Register("DistancePenalty", new ArcFFFactory); + vector corpus; + vector > ffs; + ffs.push_back(boost::shared_ptr(new DistancePenalty(""))); + ReadTraining(conf["training_data"].as(), &corpus, rank, size); + vector forests(corpus.size()); + SparseVector empirical; + bool flag = false; + for (int i = 0; i < corpus.size(); ++i) { + TrainingInstance& cur = corpus[i]; + if (rank == 0 && (i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } + if (rank == 0 && (i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } + for (int fi = 0; fi < ffs.size(); ++fi) { + ArcFeatureFunction& ff = *ffs[fi]; + ff.PrepareForInput(cur.ts); + SparseVector efmap; + for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { + efmap.clear(); + ff.EgdeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, + cur.tree.h_m_pairs[j].second, + &efmap); + cur.features += efmap; + } + for (int j = 0; j < cur.tree.roots.size(); ++j) { + efmap.clear(); + ff.EgdeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); + cur.features += efmap; + } + } + empirical += cur.features; + forests[i].resize(cur.ts.words.size()); + forests[i].ExtractFeatures(cur.ts, ffs); + } + if (flag) cerr << endl; + //cerr << "EMP: " << empirical << endl; //DE + vector weights(FD::NumFeats(), 0.0); + vector g(FD::NumFeats(), 0.0); + cerr << "features initialized\noptimizing...\n"; + boost::shared_ptr o; + o.reset(new LBFGSOptimizer(g.size(), conf["correction_buffers"].as())); + int iterations = 1000; + for (int iter = 0; iter < iterations; ++iter) { + cerr << "ITERATION " << iter << " " << flush; + fill(g.begin(), g.end(), 0.0); + for (SparseVector::const_iterator it = empirical.begin(); it != empirical.end(); ++it) + g[it->first] = -it->second; + double obj = -empirical.dot(weights); + // SparseVector mfm; //DE + for (int i = 0; i < corpus.size(); ++i) { + forests[i].Reweight(weights); + double logz; + forests[i].EdgeMarginals(&logz); + //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -logz << " OO= " << (-corpus[i].features.dot(weights) - logz) << endl; + obj -= logz; + int num_words = corpus[i].ts.words.size(); + for (int h = -1; h < num_words; ++h) { + for (int m = 0; m < num_words; ++m) { + if (h == m) continue; + const ArcFactoredForest::Edge& edge = forests[i](h,m); + const SparseVector& fmap = edge.features; + double prob = edge.edge_prob.as_float(); + if (prob < -0.000001) { cerr << "Prob < 0: " << prob << endl; prob = 0; } + if (prob > 1.000001) { cerr << "Prob > 1: " << prob << endl; prob = 1; } + AddFeatures(prob, fmap, &g); + //mfm += fmap * prob; // DE + } + } + } + //cerr << endl << "E: " << empirical << endl; // DE + //cerr << "M: " << mfm << endl; // DE + double r = ApplyRegularizationTerms(conf["regularization_strength"].as(), weights, &g); + double gnorm = 0; + for (int i = 0; i < g.size(); ++i) + gnorm += g[i]*g[i]; + cerr << "OBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm) << endl; + obj += r; + assert(obj >= 0); + o->Optimize(obj, g, &weights); + Weights::ShowLargestFeatures(weights); + if (o->HasConverged()) { cerr << "CONVERGED\n"; break; } + } return 0; } diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc index 8995515f..7e6fb2c1 100644 --- a/rst_parser/rst_test.cc +++ b/rst_parser/rst_test.cc @@ -17,15 +17,15 @@ int main(int argc, char** argv) { // (0, 1) 9 // (0, 3) 9 ArcFactoredForest af(3); - af(1,2).edge_prob.logeq(20); - af(1,3).edge_prob.logeq(3); - af(2,1).edge_prob.logeq(20); - af(2,3).edge_prob.logeq(30); - af(3,2).edge_prob.logeq(0); - af(3,1).edge_prob.logeq(11); - af(0,2).edge_prob.logeq(10); - af(0,1).edge_prob.logeq(9); - af(0,3).edge_prob.logeq(9); + af(0,1).edge_prob.logeq(20); + af(0,2).edge_prob.logeq(3); + af(1,0).edge_prob.logeq(20); + af(1,2).edge_prob.logeq(30); + af(2,1).edge_prob.logeq(0); + af(2,0).edge_prob.logeq(11); + af(-1,1).edge_prob.logeq(10); + af(-1,0).edge_prob.logeq(9); + af(-1,2).edge_prob.logeq(9); EdgeSubset tree; // af.MaximumEdgeSubset(&tree); double lz; -- cgit v1.2.3