diff options
Diffstat (limited to 'rst_parser')
| -rw-r--r-- | rst_parser/Makefile.am | 2 | ||||
| -rw-r--r-- | rst_parser/arc_factored.cc | 98 | ||||
| -rw-r--r-- | rst_parser/arc_factored.h | 27 | ||||
| -rw-r--r-- | rst_parser/arc_factored_marginals.cc | 52 | ||||
| -rw-r--r-- | rst_parser/rst_test.cc | 9 | 
5 files changed, 177 insertions, 11 deletions
diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index e97ab5c5..b61a20dd 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -8,7 +8,7 @@ TESTS = rst_test  noinst_LIBRARIES = librst.a -librst_a_SOURCES = arc_factored.cc rst.cc +librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc  mst_train_SOURCES = mst_train.cc  mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc index 1e75600b..b2c2c427 100644 --- a/rst_parser/arc_factored.cc +++ b/rst_parser/arc_factored.cc @@ -1,31 +1,121 @@  #include "arc_factored.h"  #include <set> +#include <tr1/unordered_set>  #include <boost/pending/disjoint_sets.hpp> +#include <boost/functional/hash.hpp>  using namespace std; +using namespace std::tr1;  using namespace boost; +void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const { +  for (int m = 1; m <= num_words_; ++m) { +    int best_head = -1; +    prob_t best_score; +    for (int h = 0; h <= num_words_; ++h) { +      const Edge& edge = (*this)(h,m); +      if (best_head < 0 || edge.edge_prob > best_score) { +        best_score = edge.edge_prob; +        best_head = h; +      } +    } +    assert(best_head >= 0);  +    if (best_head) +      st->h_m_pairs.push_back(make_pair<short,short>(best_head, m)); +    else +      st->roots.push_back(m); +  } +} + +struct WeightedEdge { +  WeightedEdge() : h(), m(), weight() {} +  WeightedEdge(short hh, short mm, float w) : h(hh), m(mm), weight(w) {} +  short h, m; +  float weight; +  inline bool operator==(const WeightedEdge& o) const { +    return h == o.h && m == o.m && weight == o.weight; +  } +  inline bool operator!=(const WeightedEdge& o) const { +    return h != o.h || m != o.m || weight != o.weight; +  } +}; +inline bool operator<(const WeightedEdge& l, const WeightedEdge& o) { return l.weight < o.weight; } +inline size_t hash_value(const WeightedEdge& e) { return reinterpret_cast<const size_t&>(e); } + + +struct PriorityQueue { +  void push(const WeightedEdge& e) {} +  const WeightedEdge& top() const { +    static WeightedEdge w(1,2,3); +    return w; +  } +  void pop() {} +  void increment_all(float p) {} +}; +  // based on Trajan 1977 -void ArcFactoredForest::MaximumSpanningTree(SpanningTree* st) const { +void ArcFactoredForest::MaximumEdgeSubset(EdgeSubset* st) const {    typedef disjoint_sets_with_storage<identity_property_map, identity_property_map,        find_with_full_path_compression> DisjointSet;    DisjointSet strongly(num_words_ + 1);    DisjointSet weakly(num_words_ + 1); -  set<unsigned> roots, h, rset; -  vector<pair<short, short> > enter(num_words_ + 1); +  set<unsigned> roots, rset; +  unordered_set<WeightedEdge, boost::hash<WeightedEdge> > h; +  vector<PriorityQueue> qs(num_words_ + 1); +  vector<WeightedEdge> enter(num_words_ + 1); +  vector<unsigned> mins(num_words_ + 1); +  const WeightedEdge kDUMMY(0,0,0.0f);    for (unsigned i = 0; i <= num_words_; ++i) { +    if (i > 0) { +      // I(i) incidence on i -- all incoming edges +      for (unsigned j = 0; j <= num_words_; ++j) { +        qs[i].push(WeightedEdge(j, i, Weight(j,i))); +      } +    }      strongly.make_set(i);      weakly.make_set(i);      roots.insert(i); +    enter[i] = kDUMMY; +    mins[i] = i;    }    while(!roots.empty()) {      set<unsigned>::iterator it = roots.begin();      const unsigned k = *it;      roots.erase(it);      cerr << "k=" << k << endl; -    pair<short,short> ij; // TODO = Max(k); +    WeightedEdge ij = qs[k].top();  // MAX(k) +    qs[k].pop(); +    if (ij.weight <= 0) { +      rset.insert(k); +    } else { +      if (strongly.find_set(ij.h) == k) { +        roots.insert(k); +      } else { +        h.insert(ij); +        if (weakly.find_set(ij.h) != weakly.find_set(ij.m)) { +          weakly.union_set(ij.h, ij.m); +          enter[k] = ij; +        } else { +          unsigned vertex = 0; +          float val = 99999999999; +          WeightedEdge xy = ij; +          while(xy != kDUMMY) { +            if (xy.weight < val) { +              val = xy.weight; +              vertex = strongly.find_set(xy.m); +            } +            xy = enter[strongly.find_set(xy.h)]; +          } +          qs[k].increment_all(val - ij.weight); +          mins[k] = mins[vertex]; +          xy = enter[strongly.find_set(ij.h)]; +          while (xy != kDUMMY) { +          } +        } +      } +    }    }  } diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index e99be482..3003a86e 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -10,11 +10,11 @@  #include "prob.h"  #include "weights.h" -struct SpanningTree { -  SpanningTree() : roots(1, -1) {} +struct EdgeSubset { +  EdgeSubset() {}    std::vector<short> roots; // unless multiroot trees are supported, this                              // will have a single member -  std::vector<std::pair<short, short> > h_m_pairs; +  std::vector<std::pair<short, short> > h_m_pairs; // h,m start at *1*  };  class ArcFactoredForest { @@ -35,7 +35,14 @@ class ArcFactoredForest {    // compute the maximum spanning tree based on the current weighting    // using the O(n^2) CLE algorithm -  void MaximumSpanningTree(SpanningTree* st) const; +  void MaximumEdgeSubset(EdgeSubset* st) const; + +  // Reweight edges so that edge_prob is the edge's marginals +  // optionally returns log partition +  void EdgeMarginals(double* p_log_z = NULL); + +  // This may not return a tree +  void PickBestParentForEachWord(EdgeSubset* st) const;    struct Edge {      Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} @@ -61,6 +68,10 @@ class ArcFactoredForest {      return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];    } +  float Weight(short h, short m) const { +    return log((*this)(h,m).edge_prob); +  } +    template <class V>    void Reweight(const V& weights) {      for (int m = 0; m < num_words_; ++m) { @@ -85,4 +96,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge&    return os << "(" << edge.h << " < " << edge.m << ")";  } +inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) { +  for (unsigned i = 0; i < ss.roots.size(); ++i) +    os << "ROOT < " << ss.roots[i] << std::endl; +  for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i) +    os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl; +  return os; +} +  #endif diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc new file mode 100644 index 00000000..9851b59a --- /dev/null +++ b/rst_parser/arc_factored_marginals.cc @@ -0,0 +1,52 @@ +#include "arc_factored.h" + +#include <iostream> + +#include "config.h" + +using namespace std; + +#if HAVE_EIGEN + +#include <Eigen/Dense> +typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix; +typedef Eigen::Matrix<double, Eigen::Dynamic, 1> RootVector; + +void ArcFactoredForest::EdgeMarginals(double *plog_z) { +  ArcMatrix A(num_words_,num_words_); +  RootVector r(num_words_); +  for (int h = 0; h < num_words_; ++h) { +    for (int m = 0; m < num_words_; ++m) { +      if (h != m) +        A(h,m) = edges_(h,m).edge_prob.as_float(); +      else +        A(h,m) = 0; +    } +    r(h) = root_edges_[h].edge_prob.as_float(); +  } + +  ArcMatrix L = -A; +  L.diagonal() = A.colwise().sum(); +  L.row(0) = r; +  ArcMatrix Linv = L.inverse(); +  if (plog_z) *plog_z = log(Linv.determinant()); +  RootVector rootMarginals = r.cwiseProduct(Linv.col(0)); +  for (int h = 0; h < num_words_; ++h) { +    for (int m = 0; m < num_words_; ++m) { +      edges_(h,m).edge_prob = prob_t((m == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,m) - +                                     (h == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,h)); +    } +    root_edges_[h].edge_prob = prob_t(rootMarginals(h)); +  } +  // cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl; +} + +#else + +void ArcFactoredForest::EdgeMarginals(double*) { +  cerr << "EdgeMarginals() requires --with-eigen!\n"; +  abort(); +} + +#endif + diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc index e8fe706e..8995515f 100644 --- a/rst_parser/rst_test.cc +++ b/rst_parser/rst_test.cc @@ -26,8 +26,13 @@ int main(int argc, char** argv) {    af(0,2).edge_prob.logeq(10);    af(0,1).edge_prob.logeq(9);    af(0,3).edge_prob.logeq(9); -  SpanningTree tree; -  af.MaximumSpanningTree(&tree); +  EdgeSubset tree; +//  af.MaximumEdgeSubset(&tree); +  double lz; +  af.EdgeMarginals(&lz); +  cerr << "Z = " << lz << endl; +  af.PickBestParentForEachWord(&tree); +  cerr << tree << endl;    return 0;  }  | 
