diff options
| author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-04-16 00:18:20 -0400 | 
|---|---|---|
| committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-04-16 00:18:20 -0400 | 
| commit | fa47b549e5ac7c16dce9e40d52328ffd51b60dc6 (patch) | |
| tree | 037edacd471b3a91427db2708af1533bb6116a65 | |
| parent | daa182defda1a97cb66b45b4ebf2a223948d950b (diff) | |
rst algorithm
| -rw-r--r-- | rst_parser/arc_factored.h | 4 | ||||
| -rw-r--r-- | rst_parser/mst_train.cc | 21 | ||||
| -rw-r--r-- | rst_parser/rst.cc | 45 | ||||
| -rw-r--r-- | rst_parser/rst.h | 9 | 
4 files changed, 69 insertions, 10 deletions
| diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index a95f8230..d9a0bb24 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -28,10 +28,12 @@ struct ArcFeatureFunction;  class ArcFactoredForest {   public:    ArcFactoredForest() : num_words_() {} -  explicit ArcFactoredForest(short num_words) { +  explicit ArcFactoredForest(short num_words) : num_words_(num_words) {      resize(num_words);    } +  unsigned size() const { return num_words_; } +    void resize(unsigned num_words) {      num_words_ = num_words;      root_edges_.clear(); diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index def23edb..b5114726 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -13,6 +13,7 @@  #include "picojson.h"  #include "optimize.h"  #include "weights.h" +#include "rst.h"  using namespace std;  namespace po = boost::program_options; @@ -173,12 +174,13 @@ int main(int argc, char** argv) {      double obj = -empirical.dot(weights);      // SparseVector<double> mfm;  //DE      for (int i = 0; i < corpus.size(); ++i) { +      const int num_words = corpus[i].ts.words.size();        forests[i].Reweight(weights); -      double logz; -      forests[i].EdgeMarginals(&logz); -      //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -logz << "  OO= " << (-corpus[i].features.dot(weights) - logz) << endl; -      obj -= logz; -      int num_words = corpus[i].ts.words.size(); +      double lz; +      forests[i].EdgeMarginals(&lz); +      obj -= lz; +      //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << "  OO= " << (-corpus[i].features.dot(weights) - lz) << endl; +      //cerr << " ZZ = " << zz << endl;        for (int h = -1; h < num_words; ++h) {          for (int m = 0; m < num_words; ++m) {            if (h == m) continue; @@ -198,13 +200,20 @@ int main(int argc, char** argv) {      double gnorm = 0;      for (int i = 0; i < g.size(); ++i)        gnorm += g[i]*g[i]; -    cerr << "OBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm) << endl; +    ostringstream ll; +    ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm); +    cerr << endl << ll.str() << endl;      obj += r;      assert(obj >= 0);      o->Optimize(obj, g, &weights);      Weights::ShowLargestFeatures(weights); +    string sl = ll.str(); +    Weights::WriteToFile(o->HasConverged() ? "weights.final.gz" : "weights.cur.gz", weights, true, &sl);      if (o->HasConverged()) { cerr << "CONVERGED\n"; break; }    } +  forests[0].Reweight(weights); +  TreeSampler ts(forests[0]); +  EdgeSubset tt; ts.SampleRandomSpanningTree(&tt);    return 0;  } diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc index f6b295b3..c4ce898e 100644 --- a/rst_parser/rst.cc +++ b/rst_parser/rst.cc @@ -2,6 +2,49 @@  using namespace std; -StochasticForest::StochasticForest(const ArcFactoredForest& af) { +// David B. Wilson. Generating Random Spanning Trees More Quickly than the Cover Time. + +TreeSampler::TreeSampler(const ArcFactoredForest& af) : forest(af), usucc(af.size() + 1) { +  // edges are directed from modifiers to heads, to the root +  for (int m = 1; m <= forest.size(); ++m) { +    SampleSet<double>& ss = usucc[m]; +    for (int h = 0; h <= forest.size(); ++h) +      ss.add(forest(h-1,m-1).edge_prob.as_float()); +  }  } +void TreeSampler::SampleRandomSpanningTree(EdgeSubset* tree) { +  MT19937 rng; +  const int r = 0; +  bool success = false; +  while (!success) { +    int roots = 0; +    vector<int> next(forest.size() + 1, -1); +    vector<char> in_tree(forest.size() + 1, 0); +    in_tree[r] = 1; +    for (int i = 0; i < forest.size(); ++i) { +      int u = i; +      if (in_tree[u]) continue; +      while(!in_tree[u]) { +        next[u] = rng.SelectSample(usucc[u]); +        u = next[u]; +      } +      u = i; +      cerr << (u-1); +      while(!in_tree[u]) { +        in_tree[u] = true; +        u = next[u]; +        cerr << " > " << (u-1); +        if (u == r) { ++roots; } +      } +      cerr << endl; +    } +    assert(roots > 0); +    if (roots > 1) { +      cerr << "FAILURE\n"; +    } else { +      success = true; +    } +  } +}; + diff --git a/rst_parser/rst.h b/rst_parser/rst.h index 865871eb..a269ff9b 100644 --- a/rst_parser/rst.h +++ b/rst_parser/rst.h @@ -1,10 +1,15 @@  #ifndef _RST_H_  #define _RST_H_ +#include <vector> +#include "sampler.h"  #include "arc_factored.h" -struct StochasticForest { -  explicit StochasticForest(const ArcFactoredForest& af); +struct TreeSampler { +  explicit TreeSampler(const ArcFactoredForest& af); +  void SampleRandomSpanningTree(EdgeSubset* tree); +  const ArcFactoredForest& forest; +  std::vector<SampleSet<double> > usucc;  };  #endif | 
