From 4b38556c88c739de82b9c298261a262ec620280e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 16 Apr 2012 18:20:33 -0400 Subject: rst sampler --- rst_parser/Makefile.am | 7 ++- rst_parser/dep_training.cc | 56 ++++++++++++++++++++ rst_parser/dep_training.h | 17 ++++++ rst_parser/mst_train.cc | 58 +-------------------- rst_parser/rst.cc | 56 +++++++++++++++----- rst_parser/rst.h | 8 ++- rst_parser/rst_parse.cc | 126 +++++++++++++++++++++++++++++++++++++++++++++ utils/weights.cc | 4 +- 8 files changed, 260 insertions(+), 72 deletions(-) create mode 100644 rst_parser/dep_training.cc create mode 100644 rst_parser/dep_training.h create mode 100644 rst_parser/rst_parse.cc diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index 2b64b43a..6e884f53 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -1,5 +1,5 @@ bin_PROGRAMS = \ - mst_train + mst_train rst_parse noinst_PROGRAMS = \ rst_test @@ -8,11 +8,14 @@ TESTS = rst_test noinst_LIBRARIES = librst.a -librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc arc_ff.cc +librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc arc_ff.cc dep_training.cc mst_train_SOURCES = mst_train.cc mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a ../training/optimize.o -lz +rst_parse_SOURCES = rst_parse.cc +rst_parse_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + rst_test_SOURCES = rst_test.cc rst_test_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc new file mode 100644 index 00000000..de431ebc --- /dev/null +++ b/rst_parser/dep_training.cc @@ -0,0 +1,56 @@ +#include "dep_training.h" + +#include +#include + +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "picojson.h" + +using namespace std; + +void TrainingInstance::ReadTraining(const string& fname, vector* corpus, int rank, int size) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + string err; + int lc = 0; + bool flag = false; + while(getline(in, line)) { + ++lc; + if ((lc-1) % size != rank) continue; + if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; } + if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; } + size_t pos = line.rfind('\t'); + assert(pos != string::npos); + picojson::value obj; + picojson::parse(obj, line.begin() + pos, line.end(), &err); + if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } + corpus->push_back(TrainingInstance()); + TrainingInstance& cur = corpus->back(); + TaggedSentence& ts = cur.ts; + EdgeSubset& tree = cur.tree; + assert(obj.is()); + const picojson::object& d = obj.get(); + const picojson::array& ta = d.find("tokens")->second.get(); + for (unsigned i = 0; i < ta.size(); ++i) { + ts.words.push_back(TD::Convert(ta[i].get()[0].get())); + ts.pos.push_back(TD::Convert(ta[i].get()[1].get())); + } + const picojson::array& da = d.find("deps")->second.get(); + for (unsigned i = 0; i < da.size(); ++i) { + const picojson::array& thm = da[i].get(); + // get dep type here + short h = thm[2].get(); + short m = thm[1].get(); + if (h < 0) + tree.roots.push_back(m); + else + tree.h_m_pairs.push_back(make_pair(h,m)); + } + //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; + } + if (flag) cerr << "\nRead " << lc << " training instances\n"; +} + diff --git a/rst_parser/dep_training.h b/rst_parser/dep_training.h new file mode 100644 index 00000000..73ffd298 --- /dev/null +++ b/rst_parser/dep_training.h @@ -0,0 +1,17 @@ +#ifndef _DEP_TRAINING_H_ +#define _DEP_TRAINING_H_ + +#include +#include +#include "arc_factored.h" +#include "weights.h" + +struct TrainingInstance { + TaggedSentence ts; + EdgeSubset tree; + SparseVector features; + // reads a "Jsent" formatted dependency file + static void ReadTraining(const std::string& fname, std::vector* corpus, int rank = 0, int size = 1); +}; + +#endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index c5cab6ec..f0403d7e 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -10,10 +10,9 @@ #include "stringlib.h" #include "filelib.h" #include "tdict.h" -#include "picojson.h" +#include "dep_training.h" #include "optimize.h" #include "weights.h" -#include "rst.h" using namespace std; namespace po = boost::program_options; @@ -47,56 +46,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -struct TrainingInstance { - TaggedSentence ts; - EdgeSubset tree; - SparseVector features; -}; - -void ReadTraining(const string& fname, vector* corpus, int rank = 0, int size = 1) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - string err; - int lc = 0; - bool flag = false; - while(getline(in, line)) { - ++lc; - if ((lc-1) % size != rank) continue; - if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; } - if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; } - size_t pos = line.rfind('\t'); - assert(pos != string::npos); - picojson::value obj; - picojson::parse(obj, line.begin() + pos, line.end(), &err); - if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } - corpus->push_back(TrainingInstance()); - TrainingInstance& cur = corpus->back(); - TaggedSentence& ts = cur.ts; - EdgeSubset& tree = cur.tree; - assert(obj.is()); - const picojson::object& d = obj.get(); - const picojson::array& ta = d.find("tokens")->second.get(); - for (unsigned i = 0; i < ta.size(); ++i) { - ts.words.push_back(TD::Convert(ta[i].get()[0].get())); - ts.pos.push_back(TD::Convert(ta[i].get()[1].get())); - } - const picojson::array& da = d.find("deps")->second.get(); - for (unsigned i = 0; i < da.size(); ++i) { - const picojson::array& thm = da[i].get(); - // get dep type here - short h = thm[2].get(); - short m = thm[1].get(); - if (h < 0) - tree.roots.push_back(m); - else - tree.h_m_pairs.push_back(make_pair(h,m)); - } - //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; - } - if (flag) cerr << "\nRead " << lc << " training instances\n"; -} - void AddFeatures(double prob, const SparseVector& fmap, vector* g) { for (SparseVector::const_iterator it = fmap.begin(); it != fmap.end(); ++it) (*g)[it->first] += it->second * prob; @@ -131,7 +80,7 @@ int main(int argc, char** argv) { vector corpus; vector > ffs; ffs.push_back(boost::shared_ptr(new DistancePenalty(""))); - ReadTraining(conf["training_data"].as(), &corpus, rank, size); + TrainingInstance::ReadTraining(conf["training_data"].as(), &corpus, rank, size); vector forests(corpus.size()); SparseVector empirical; bool flag = false; @@ -224,9 +173,6 @@ int main(int argc, char** argv) { } if (converged) { cerr << "CONVERGED\n"; break; } } - forests[0].Reweight(weights); - TreeSampler ts(forests[0]); - EdgeSubset tt; ts.SampleRandomSpanningTree(&tt); return 0; } diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc index c4ce898e..bc91330b 100644 --- a/rst_parser/rst.cc +++ b/rst_parser/rst.cc @@ -3,45 +3,77 @@ using namespace std; // David B. Wilson. Generating Random Spanning Trees More Quickly than the Cover Time. - +// this is an awesome algorithm TreeSampler::TreeSampler(const ArcFactoredForest& af) : forest(af), usucc(af.size() + 1) { - // edges are directed from modifiers to heads, to the root + // edges are directed from modifiers to heads, and finally to the root + vector p; for (int m = 1; m <= forest.size(); ++m) { +#if USE_ALIAS_SAMPLER + p.clear(); +#else SampleSet& ss = usucc[m]; - for (int h = 0; h <= forest.size(); ++h) - ss.add(forest(h-1,m-1).edge_prob.as_float()); +#endif + double z = 0; + for (int h = 0; h <= forest.size(); ++h) { + double u = forest(h-1,m-1).edge_prob.as_float(); + z += u; +#if USE_ALIAS_SAMPLER + p.push_back(u); +#else + ss.add(u); +#endif + } +#if USE_ALIAS_SAMPLER + for (int i = 0; i < p.size(); ++i) { p[i] /= z; } + usucc[m].Init(p); +#endif } } -void TreeSampler::SampleRandomSpanningTree(EdgeSubset* tree) { - MT19937 rng; +void TreeSampler::SampleRandomSpanningTree(EdgeSubset* tree, MT19937* prng) { + MT19937& rng = *prng; const int r = 0; bool success = false; while (!success) { int roots = 0; + tree->h_m_pairs.clear(); + tree->roots.clear(); vector next(forest.size() + 1, -1); vector in_tree(forest.size() + 1, 0); in_tree[r] = 1; - for (int i = 0; i < forest.size(); ++i) { + //cerr << "Forest size: " << forest.size() << endl; + for (int i = 0; i <= forest.size(); ++i) { + //cerr << "Sampling starting at u=" << i << endl; int u = i; if (in_tree[u]) continue; while(!in_tree[u]) { +#if USE_ALIAS_SAMPLER + next[u] = usucc[u].Draw(rng); +#else next[u] = rng.SelectSample(usucc[u]); +#endif u = next[u]; } u = i; - cerr << (u-1); + //cerr << (u-1); + int prev = u-1; while(!in_tree[u]) { in_tree[u] = true; u = next[u]; - cerr << " > " << (u-1); - if (u == r) { ++roots; } + //cerr << " > " << (u-1); + if (u == r) { + ++roots; + tree->roots.push_back(prev); + } else { + tree->h_m_pairs.push_back(make_pair(u-1,prev)); + } + prev = u-1; } - cerr << endl; + //cerr << endl; } assert(roots > 0); if (roots > 1) { - cerr << "FAILURE\n"; + //cerr << "FAILURE\n"; } else { success = true; } diff --git a/rst_parser/rst.h b/rst_parser/rst.h index a269ff9b..8bf389f7 100644 --- a/rst_parser/rst.h +++ b/rst_parser/rst.h @@ -4,12 +4,18 @@ #include #include "sampler.h" #include "arc_factored.h" +#include "alias_sampler.h" struct TreeSampler { explicit TreeSampler(const ArcFactoredForest& af); - void SampleRandomSpanningTree(EdgeSubset* tree); + void SampleRandomSpanningTree(EdgeSubset* tree, MT19937* rng); const ArcFactoredForest& forest; +#define USE_ALIAS_SAMPLER 1 +#if USE_ALIAS_SAMPLER + std::vector usucc; +#else std::vector > usucc; +#endif }; #endif diff --git a/rst_parser/rst_parse.cc b/rst_parser/rst_parse.cc new file mode 100644 index 00000000..9cc1359a --- /dev/null +++ b/rst_parser/rst_parse.cc @@ -0,0 +1,126 @@ +#include "arc_factored.h" + +#include +#include +#include +#include + +#include "timing_stats.h" +#include "arc_ff.h" +#include "arc_ff_factory.h" +#include "dep_training.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "weights.h" +#include "rst.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + string cfg_file; + opts.add_options() + ("training_data,t",po::value()->default_value("-"), "File containing training data (jsent format)") + ("feature_function,F",po::value >()->composing(), "feature function (multiple permitted)") + ("q_weights,q",po::value(), "Arc-factored weights for proposal distribution") + ("samples,n",po::value()->default_value(1000), "Number of samples"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value(&cfg_file), "Configuration file") + ("help,?", "Print this help message and exit"); + + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(dconfig_options).add(clo); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (cfg_file.size() > 0) { + ReadFile rf(cfg_file); + po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); + } + if (conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + ArcFactoredForest af(5); + ArcFFRegistry reg; + reg.Register("DistancePenalty", new ArcFFFactory); + vector corpus; + vector > ffs; + ffs.push_back(boost::shared_ptr(new DistancePenalty(""))); + TrainingInstance::ReadTraining(conf["training_data"].as(), &corpus); + vector forests(corpus.size()); + SparseVector empirical; + bool flag = false; + for (int i = 0; i < corpus.size(); ++i) { + TrainingInstance& cur = corpus[i]; + if ((i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } + if ((i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } + for (int fi = 0; fi < ffs.size(); ++fi) { + ArcFeatureFunction& ff = *ffs[fi]; + ff.PrepareForInput(cur.ts); + SparseVector efmap; + for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { + efmap.clear(); + ff.EgdeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, + cur.tree.h_m_pairs[j].second, + &efmap); + cur.features += efmap; + } + for (int j = 0; j < cur.tree.roots.size(); ++j) { + efmap.clear(); + ff.EgdeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); + cur.features += efmap; + } + } + empirical += cur.features; + forests[i].resize(cur.ts.words.size()); + forests[i].ExtractFeatures(cur.ts, ffs); + } + if (flag) cerr << endl; + vector weights(FD::NumFeats(), 0.0); + Weights::InitFromFile(conf["q_weights"].as(), &weights); + MT19937 rng; + SparseVector model_exp; + SparseVector sampled_exp; + int samples = conf["samples"].as(); + for (int i = 0; i < corpus.size(); ++i) { + const int num_words = corpus[i].ts.words.size(); + forests[i].Reweight(weights); + forests[i].EdgeMarginals(); + model_exp.clear(); + for (int h = -1; h < num_words; ++h) { + for (int m = 0; m < num_words; ++m) { + if (h == m) continue; + const ArcFactoredForest::Edge& edge = forests[i](h,m); + const SparseVector& fmap = edge.features; + double prob = edge.edge_prob.as_float(); + model_exp += fmap * prob; + } + } + //cerr << "TRUE EXP: " << model_exp << endl; + + forests[i].Reweight(weights); + TreeSampler ts(forests[i]); + sampled_exp.clear(); + //ostringstream os; os << "Samples_" << samples; + //Timer t(os.str()); + for (int n = 0; n < samples; ++n) { + EdgeSubset tree; + ts.SampleRandomSpanningTree(&tree, &rng); + SparseVector feats; + tree.ExtractFeatures(corpus[i].ts, ffs, &feats); + sampled_exp += feats; + } + sampled_exp /= samples; + cerr << "L2 norm of diff @ " << samples << " samples: " << (model_exp - sampled_exp).l2norm() << endl; + } + return 0; +} + diff --git a/utils/weights.cc b/utils/weights.cc index ac407dfb..39c18474 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -144,8 +144,10 @@ void Weights::ShowLargestFeatures(const vector& w) { vector fnums(w.size()); for (int i = 0; i < w.size(); ++i) fnums[i] = i; + int nf = FD::NumFeats(); + if (nf > 10) nf = 10; vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); + mid += nf; partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); cerr << "TOP FEATURES:"; for (vector::iterator i = fnums.begin(); i != mid; ++i) { -- cgit v1.2.3