diff options
Diffstat (limited to 'rst_parser')
| -rw-r--r-- | rst_parser/Makefile.am | 5 | ||||
| -rw-r--r-- | rst_parser/arc_factored.cc | 31 | ||||
| -rw-r--r-- | rst_parser/arc_factored.h | 72 | ||||
| -rw-r--r-- | rst_parser/mst_train.cc | 1 | ||||
| -rw-r--r-- | rst_parser/rst.cc | 5 | ||||
| -rw-r--r-- | rst_parser/rst.h | 5 | ||||
| -rw-r--r-- | rst_parser/rst_test.cc | 33 | 
7 files changed, 129 insertions, 23 deletions
| diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index fef1c1a2..e97ab5c5 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -8,9 +8,12 @@ TESTS = rst_test  noinst_LIBRARIES = librst.a -librst_a_SOURCES = rst.cc +librst_a_SOURCES = arc_factored.cc rst.cc  mst_train_SOURCES = mst_train.cc  mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +rst_test_SOURCES = rst_test.cc +rst_test_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +  AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc new file mode 100644 index 00000000..1e75600b --- /dev/null +++ b/rst_parser/arc_factored.cc @@ -0,0 +1,31 @@ +#include "arc_factored.h" + +#include <set> + +#include <boost/pending/disjoint_sets.hpp> + +using namespace std; +using namespace boost; + +// based on Trajan 1977 +void ArcFactoredForest::MaximumSpanningTree(SpanningTree* st) const { +  typedef disjoint_sets_with_storage<identity_property_map, identity_property_map, +      find_with_full_path_compression> DisjointSet; +  DisjointSet strongly(num_words_ + 1); +  DisjointSet weakly(num_words_ + 1); +  set<unsigned> roots, h, rset; +  vector<pair<short, short> > enter(num_words_ + 1); +  for (unsigned i = 0; i <= num_words_; ++i) { +    strongly.make_set(i); +    weakly.make_set(i); +    roots.insert(i); +  } +  while(!roots.empty()) { +    set<unsigned>::iterator it = roots.begin(); +    const unsigned k = *it; +    roots.erase(it); +    cerr << "k=" << k << endl; +    pair<short,short> ij; // TODO = Max(k); +  } +} + diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index 312d7d67..e99be482 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -1,58 +1,88 @@  #ifndef _ARC_FACTORED_H_  #define _ARC_FACTORED_H_ -#include <vector> +#include <iostream>  #include <cassert> +#include <vector> +#include <utility>  #include "array2d.h"  #include "sparse_vector.h" +#include "prob.h" +#include "weights.h" + +struct SpanningTree { +  SpanningTree() : roots(1, -1) {} +  std::vector<short> roots; // unless multiroot trees are supported, this +                            // will have a single member +  std::vector<std::pair<short, short> > h_m_pairs; +};  class ArcFactoredForest {   public:    explicit ArcFactoredForest(short num_words) :        num_words_(num_words),        root_edges_(num_words), -      edges_(num_words, num_words) {} +      edges_(num_words, num_words) { +    for (int h = 0; h < num_words; ++h) { +      for (int m = 0; m < num_words; ++m) { +        edges_(h, m).h = h + 1; +        edges_(h, m).m = m + 1; +      } +      root_edges_[h].h = 0; +      root_edges_[h].m = h + 1; +    } +  } + +  // compute the maximum spanning tree based on the current weighting +  // using the O(n^2) CLE algorithm +  void MaximumSpanningTree(SpanningTree* st) const;    struct Edge { -    Edge() : features(), edge_prob(prob_t::Zero()) {} +    Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} +    short h; +    short m;      SparseVector<weight_t> features;      prob_t edge_prob;    }; -  template <class V> -  void Reweight(const V& weights) { -    for (int m = 0; m < num_words_; ++m) { -      for (int h = 0; h < num_words_; ++h) { -        if (h != m) { -          Edge& e = edges_(h, m); -          e.edge_prob.logeq(e.features.dot(weights)); -        } -      } -      if (m) { -        Edge& e = root_edges_[m]; -        e.edge_prob.logeq(e.features.dot(weights)); -      } -    } -  } -    const Edge& operator()(short h, short m) const {      assert(m > 0);      assert(m <= num_words_);      assert(h >= 0);      assert(h <= num_words_); -    return h ? edges_(h - 1, m - 1) : root_edges[m - 1]; +    return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];    } +    Edge& operator()(short h, short m) {      assert(m > 0);      assert(m <= num_words_);      assert(h >= 0);      assert(h <= num_words_); -    return h ? edges_(h - 1, m - 1) : root_edges[m - 1]; +    return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; +  } + +  template <class V> +  void Reweight(const V& weights) { +    for (int m = 0; m < num_words_; ++m) { +      for (int h = 0; h < num_words_; ++h) { +        if (h != m) { +          Edge& e = edges_(h, m); +          e.edge_prob.logeq(e.features.dot(weights)); +        } +      } +      Edge& e = root_edges_[m]; +      e.edge_prob.logeq(e.features.dot(weights)); +    }    } +   private:    unsigned num_words_;    std::vector<Edge> root_edges_;    Array2D<Edge> edges_;  }; +inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& edge) { +  return os << "(" << edge.h << " < " << edge.m << ")"; +} +  #endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index 1bceaff5..7b5af4c1 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -6,6 +6,7 @@ using namespace std;  int main(int argc, char** argv) {    ArcFactoredForest af(5); +  cerr << af(0,3) << endl;    return 0;  } diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc index 0ab3e296..f6b295b3 100644 --- a/rst_parser/rst.cc +++ b/rst_parser/rst.cc @@ -1,2 +1,7 @@  #include "rst.h" +using namespace std; + +StochasticForest::StochasticForest(const ArcFactoredForest& af) { +} + diff --git a/rst_parser/rst.h b/rst_parser/rst.h index 30a1f8a4..865871eb 100644 --- a/rst_parser/rst.h +++ b/rst_parser/rst.h @@ -1,7 +1,10 @@  #ifndef _RST_H_  #define _RST_H_ -struct RandomSpanningTree { +#include "arc_factored.h" + +struct StochasticForest { +  explicit StochasticForest(const ArcFactoredForest& af);  };  #endif diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc new file mode 100644 index 00000000..e8fe706e --- /dev/null +++ b/rst_parser/rst_test.cc @@ -0,0 +1,33 @@ +#include "arc_factored.h" + +#include <iostream> + +using namespace std; + +int main(int argc, char** argv) { +  // John saw Mary +  //   (H -> M) +  //   (1 -> 2) 20 +  //   (1 -> 3) 3 +  //   (2 -> 1) 20 +  //   (2 -> 3) 30 +  //   (3 -> 2) 0 +  //   (3 -> 1) 11 +  //   (0, 2) 10 +  //   (0, 1) 9 +  //   (0, 3) 9 +  ArcFactoredForest af(3); +  af(1,2).edge_prob.logeq(20); +  af(1,3).edge_prob.logeq(3); +  af(2,1).edge_prob.logeq(20); +  af(2,3).edge_prob.logeq(30); +  af(3,2).edge_prob.logeq(0); +  af(3,1).edge_prob.logeq(11); +  af(0,2).edge_prob.logeq(10); +  af(0,1).edge_prob.logeq(9); +  af(0,3).edge_prob.logeq(9); +  SpanningTree tree; +  af.MaximumSpanningTree(&tree); +  return 0; +} + | 
