From c5ec52ded3f14271e25e97cefc8bac03b176f297 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 16 Apr 2012 00:18:20 -0400 Subject: rst algorithm --- rst_parser/arc_factored.h | 4 +++- rst_parser/mst_train.cc | 21 +++++++++++++++------ rst_parser/rst.cc | 45 ++++++++++++++++++++++++++++++++++++++++++++- rst_parser/rst.h | 9 +++++++-- 4 files changed, 69 insertions(+), 10 deletions(-) (limited to 'rst_parser') diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index a95f8230..d9a0bb24 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -28,10 +28,12 @@ struct ArcFeatureFunction; class ArcFactoredForest { public: ArcFactoredForest() : num_words_() {} - explicit ArcFactoredForest(short num_words) { + explicit ArcFactoredForest(short num_words) : num_words_(num_words) { resize(num_words); } + unsigned size() const { return num_words_; } + void resize(unsigned num_words) { num_words_ = num_words; root_edges_.clear(); diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index def23edb..b5114726 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -13,6 +13,7 @@ #include "picojson.h" #include "optimize.h" #include "weights.h" +#include "rst.h" using namespace std; namespace po = boost::program_options; @@ -173,12 +174,13 @@ int main(int argc, char** argv) { double obj = -empirical.dot(weights); // SparseVector mfm; //DE for (int i = 0; i < corpus.size(); ++i) { + const int num_words = corpus[i].ts.words.size(); forests[i].Reweight(weights); - double logz; - forests[i].EdgeMarginals(&logz); - //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -logz << " OO= " << (-corpus[i].features.dot(weights) - logz) << endl; - obj -= logz; - int num_words = corpus[i].ts.words.size(); + double lz; + forests[i].EdgeMarginals(&lz); + obj -= lz; + //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << " OO= " << (-corpus[i].features.dot(weights) - lz) << endl; + //cerr << " ZZ = " << zz << endl; for (int h = -1; h < num_words; ++h) { for (int m = 0; m < num_words; ++m) { if (h == m) continue; @@ -198,13 +200,20 @@ int main(int argc, char** argv) { double gnorm = 0; for (int i = 0; i < g.size(); ++i) gnorm += g[i]*g[i]; - cerr << "OBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm) << endl; + ostringstream ll; + ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm); + cerr << endl << ll.str() << endl; obj += r; assert(obj >= 0); o->Optimize(obj, g, &weights); Weights::ShowLargestFeatures(weights); + string sl = ll.str(); + Weights::WriteToFile(o->HasConverged() ? "weights.final.gz" : "weights.cur.gz", weights, true, &sl); if (o->HasConverged()) { cerr << "CONVERGED\n"; break; } } + forests[0].Reweight(weights); + TreeSampler ts(forests[0]); + EdgeSubset tt; ts.SampleRandomSpanningTree(&tt); return 0; } diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc index f6b295b3..c4ce898e 100644 --- a/rst_parser/rst.cc +++ b/rst_parser/rst.cc @@ -2,6 +2,49 @@ using namespace std; -StochasticForest::StochasticForest(const ArcFactoredForest& af) { +// David B. Wilson. Generating Random Spanning Trees More Quickly than the Cover Time. + +TreeSampler::TreeSampler(const ArcFactoredForest& af) : forest(af), usucc(af.size() + 1) { + // edges are directed from modifiers to heads, to the root + for (int m = 1; m <= forest.size(); ++m) { + SampleSet& ss = usucc[m]; + for (int h = 0; h <= forest.size(); ++h) + ss.add(forest(h-1,m-1).edge_prob.as_float()); + } } +void TreeSampler::SampleRandomSpanningTree(EdgeSubset* tree) { + MT19937 rng; + const int r = 0; + bool success = false; + while (!success) { + int roots = 0; + vector next(forest.size() + 1, -1); + vector in_tree(forest.size() + 1, 0); + in_tree[r] = 1; + for (int i = 0; i < forest.size(); ++i) { + int u = i; + if (in_tree[u]) continue; + while(!in_tree[u]) { + next[u] = rng.SelectSample(usucc[u]); + u = next[u]; + } + u = i; + cerr << (u-1); + while(!in_tree[u]) { + in_tree[u] = true; + u = next[u]; + cerr << " > " << (u-1); + if (u == r) { ++roots; } + } + cerr << endl; + } + assert(roots > 0); + if (roots > 1) { + cerr << "FAILURE\n"; + } else { + success = true; + } + } +}; + diff --git a/rst_parser/rst.h b/rst_parser/rst.h index 865871eb..a269ff9b 100644 --- a/rst_parser/rst.h +++ b/rst_parser/rst.h @@ -1,10 +1,15 @@ #ifndef _RST_H_ #define _RST_H_ +#include +#include "sampler.h" #include "arc_factored.h" -struct StochasticForest { - explicit StochasticForest(const ArcFactoredForest& af); +struct TreeSampler { + explicit TreeSampler(const ArcFactoredForest& af); + void SampleRandomSpanningTree(EdgeSubset* tree); + const ArcFactoredForest& forest; + std::vector > usucc; }; #endif -- cgit v1.2.3