diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-04-14 01:52:14 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-04-14 01:52:14 -0400 |
commit | 19147c5f45b40eac1e0ae1bc8bc8ccf90d1ea56c (patch) | |
tree | 36cb3175a60cd2bebf711ca3ec61fcd2dbec8760 | |
parent | 8b7872d6e72ad87cdc8411b0deff92ff9b4c2a95 (diff) |
matrix tree theorem stuff
-rw-r--r-- | rst_parser/Makefile.am | 2 | ||||
-rw-r--r-- | rst_parser/arc_factored.cc | 98 | ||||
-rw-r--r-- | rst_parser/arc_factored.h | 27 | ||||
-rw-r--r-- | rst_parser/arc_factored_marginals.cc | 52 | ||||
-rw-r--r-- | rst_parser/rst_test.cc | 9 |
5 files changed, 177 insertions, 11 deletions
diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index e97ab5c5..b61a20dd 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -8,7 +8,7 @@ TESTS = rst_test noinst_LIBRARIES = librst.a -librst_a_SOURCES = arc_factored.cc rst.cc +librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc mst_train_SOURCES = mst_train.cc mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc index 1e75600b..b2c2c427 100644 --- a/rst_parser/arc_factored.cc +++ b/rst_parser/arc_factored.cc @@ -1,31 +1,121 @@ #include "arc_factored.h" #include <set> +#include <tr1/unordered_set> #include <boost/pending/disjoint_sets.hpp> +#include <boost/functional/hash.hpp> using namespace std; +using namespace std::tr1; using namespace boost; +void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const { + for (int m = 1; m <= num_words_; ++m) { + int best_head = -1; + prob_t best_score; + for (int h = 0; h <= num_words_; ++h) { + const Edge& edge = (*this)(h,m); + if (best_head < 0 || edge.edge_prob > best_score) { + best_score = edge.edge_prob; + best_head = h; + } + } + assert(best_head >= 0); + if (best_head) + st->h_m_pairs.push_back(make_pair<short,short>(best_head, m)); + else + st->roots.push_back(m); + } +} + +struct WeightedEdge { + WeightedEdge() : h(), m(), weight() {} + WeightedEdge(short hh, short mm, float w) : h(hh), m(mm), weight(w) {} + short h, m; + float weight; + inline bool operator==(const WeightedEdge& o) const { + return h == o.h && m == o.m && weight == o.weight; + } + inline bool operator!=(const WeightedEdge& o) const { + return h != o.h || m != o.m || weight != o.weight; + } +}; +inline bool operator<(const WeightedEdge& l, const WeightedEdge& o) { return l.weight < o.weight; } +inline size_t hash_value(const WeightedEdge& e) { return reinterpret_cast<const size_t&>(e); } + + +struct PriorityQueue { + void push(const WeightedEdge& e) {} + const WeightedEdge& top() const { + static WeightedEdge w(1,2,3); + return w; + } + void pop() {} + void increment_all(float p) {} +}; + // based on Trajan 1977 -void ArcFactoredForest::MaximumSpanningTree(SpanningTree* st) const { +void ArcFactoredForest::MaximumEdgeSubset(EdgeSubset* st) const { typedef disjoint_sets_with_storage<identity_property_map, identity_property_map, find_with_full_path_compression> DisjointSet; DisjointSet strongly(num_words_ + 1); DisjointSet weakly(num_words_ + 1); - set<unsigned> roots, h, rset; - vector<pair<short, short> > enter(num_words_ + 1); + set<unsigned> roots, rset; + unordered_set<WeightedEdge, boost::hash<WeightedEdge> > h; + vector<PriorityQueue> qs(num_words_ + 1); + vector<WeightedEdge> enter(num_words_ + 1); + vector<unsigned> mins(num_words_ + 1); + const WeightedEdge kDUMMY(0,0,0.0f); for (unsigned i = 0; i <= num_words_; ++i) { + if (i > 0) { + // I(i) incidence on i -- all incoming edges + for (unsigned j = 0; j <= num_words_; ++j) { + qs[i].push(WeightedEdge(j, i, Weight(j,i))); + } + } strongly.make_set(i); weakly.make_set(i); roots.insert(i); + enter[i] = kDUMMY; + mins[i] = i; } while(!roots.empty()) { set<unsigned>::iterator it = roots.begin(); const unsigned k = *it; roots.erase(it); cerr << "k=" << k << endl; - pair<short,short> ij; // TODO = Max(k); + WeightedEdge ij = qs[k].top(); // MAX(k) + qs[k].pop(); + if (ij.weight <= 0) { + rset.insert(k); + } else { + if (strongly.find_set(ij.h) == k) { + roots.insert(k); + } else { + h.insert(ij); + if (weakly.find_set(ij.h) != weakly.find_set(ij.m)) { + weakly.union_set(ij.h, ij.m); + enter[k] = ij; + } else { + unsigned vertex = 0; + float val = 99999999999; + WeightedEdge xy = ij; + while(xy != kDUMMY) { + if (xy.weight < val) { + val = xy.weight; + vertex = strongly.find_set(xy.m); + } + xy = enter[strongly.find_set(xy.h)]; + } + qs[k].increment_all(val - ij.weight); + mins[k] = mins[vertex]; + xy = enter[strongly.find_set(ij.h)]; + while (xy != kDUMMY) { + } + } + } + } } } diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index e99be482..3003a86e 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -10,11 +10,11 @@ #include "prob.h" #include "weights.h" -struct SpanningTree { - SpanningTree() : roots(1, -1) {} +struct EdgeSubset { + EdgeSubset() {} std::vector<short> roots; // unless multiroot trees are supported, this // will have a single member - std::vector<std::pair<short, short> > h_m_pairs; + std::vector<std::pair<short, short> > h_m_pairs; // h,m start at *1* }; class ArcFactoredForest { @@ -35,7 +35,14 @@ class ArcFactoredForest { // compute the maximum spanning tree based on the current weighting // using the O(n^2) CLE algorithm - void MaximumSpanningTree(SpanningTree* st) const; + void MaximumEdgeSubset(EdgeSubset* st) const; + + // Reweight edges so that edge_prob is the edge's marginals + // optionally returns log partition + void EdgeMarginals(double* p_log_z = NULL); + + // This may not return a tree + void PickBestParentForEachWord(EdgeSubset* st) const; struct Edge { Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} @@ -61,6 +68,10 @@ class ArcFactoredForest { return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; } + float Weight(short h, short m) const { + return log((*this)(h,m).edge_prob); + } + template <class V> void Reweight(const V& weights) { for (int m = 0; m < num_words_; ++m) { @@ -85,4 +96,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& return os << "(" << edge.h << " < " << edge.m << ")"; } +inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) { + for (unsigned i = 0; i < ss.roots.size(); ++i) + os << "ROOT < " << ss.roots[i] << std::endl; + for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i) + os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl; + return os; +} + #endif diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc new file mode 100644 index 00000000..9851b59a --- /dev/null +++ b/rst_parser/arc_factored_marginals.cc @@ -0,0 +1,52 @@ +#include "arc_factored.h" + +#include <iostream> + +#include "config.h" + +using namespace std; + +#if HAVE_EIGEN + +#include <Eigen/Dense> +typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix; +typedef Eigen::Matrix<double, Eigen::Dynamic, 1> RootVector; + +void ArcFactoredForest::EdgeMarginals(double *plog_z) { + ArcMatrix A(num_words_,num_words_); + RootVector r(num_words_); + for (int h = 0; h < num_words_; ++h) { + for (int m = 0; m < num_words_; ++m) { + if (h != m) + A(h,m) = edges_(h,m).edge_prob.as_float(); + else + A(h,m) = 0; + } + r(h) = root_edges_[h].edge_prob.as_float(); + } + + ArcMatrix L = -A; + L.diagonal() = A.colwise().sum(); + L.row(0) = r; + ArcMatrix Linv = L.inverse(); + if (plog_z) *plog_z = log(Linv.determinant()); + RootVector rootMarginals = r.cwiseProduct(Linv.col(0)); + for (int h = 0; h < num_words_; ++h) { + for (int m = 0; m < num_words_; ++m) { + edges_(h,m).edge_prob = prob_t((m == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,m) - + (h == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,h)); + } + root_edges_[h].edge_prob = prob_t(rootMarginals(h)); + } + // cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl; +} + +#else + +void ArcFactoredForest::EdgeMarginals(double*) { + cerr << "EdgeMarginals() requires --with-eigen!\n"; + abort(); +} + +#endif + diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc index e8fe706e..8995515f 100644 --- a/rst_parser/rst_test.cc +++ b/rst_parser/rst_test.cc @@ -26,8 +26,13 @@ int main(int argc, char** argv) { af(0,2).edge_prob.logeq(10); af(0,1).edge_prob.logeq(9); af(0,3).edge_prob.logeq(9); - SpanningTree tree; - af.MaximumSpanningTree(&tree); + EdgeSubset tree; +// af.MaximumEdgeSubset(&tree); + double lz; + af.EdgeMarginals(&lz); + cerr << "Z = " << lz << endl; + af.PickBestParentForEachWord(&tree); + cerr << tree << endl; return 0; } |