summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-04-14 01:52:14 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-04-14 01:52:14 -0400
commit19147c5f45b40eac1e0ae1bc8bc8ccf90d1ea56c (patch)
tree36cb3175a60cd2bebf711ca3ec61fcd2dbec8760
parent8b7872d6e72ad87cdc8411b0deff92ff9b4c2a95 (diff)
matrix tree theorem stuff
-rw-r--r--rst_parser/Makefile.am2
-rw-r--r--rst_parser/arc_factored.cc98
-rw-r--r--rst_parser/arc_factored.h27
-rw-r--r--rst_parser/arc_factored_marginals.cc52
-rw-r--r--rst_parser/rst_test.cc9
5 files changed, 177 insertions, 11 deletions
diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am
index e97ab5c5..b61a20dd 100644
--- a/rst_parser/Makefile.am
+++ b/rst_parser/Makefile.am
@@ -8,7 +8,7 @@ TESTS = rst_test
noinst_LIBRARIES = librst.a
-librst_a_SOURCES = arc_factored.cc rst.cc
+librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc
mst_train_SOURCES = mst_train.cc
mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc
index 1e75600b..b2c2c427 100644
--- a/rst_parser/arc_factored.cc
+++ b/rst_parser/arc_factored.cc
@@ -1,31 +1,121 @@
#include "arc_factored.h"
#include <set>
+#include <tr1/unordered_set>
#include <boost/pending/disjoint_sets.hpp>
+#include <boost/functional/hash.hpp>
using namespace std;
+using namespace std::tr1;
using namespace boost;
+void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const {
+ for (int m = 1; m <= num_words_; ++m) {
+ int best_head = -1;
+ prob_t best_score;
+ for (int h = 0; h <= num_words_; ++h) {
+ const Edge& edge = (*this)(h,m);
+ if (best_head < 0 || edge.edge_prob > best_score) {
+ best_score = edge.edge_prob;
+ best_head = h;
+ }
+ }
+ assert(best_head >= 0);
+ if (best_head)
+ st->h_m_pairs.push_back(make_pair<short,short>(best_head, m));
+ else
+ st->roots.push_back(m);
+ }
+}
+
+struct WeightedEdge {
+ WeightedEdge() : h(), m(), weight() {}
+ WeightedEdge(short hh, short mm, float w) : h(hh), m(mm), weight(w) {}
+ short h, m;
+ float weight;
+ inline bool operator==(const WeightedEdge& o) const {
+ return h == o.h && m == o.m && weight == o.weight;
+ }
+ inline bool operator!=(const WeightedEdge& o) const {
+ return h != o.h || m != o.m || weight != o.weight;
+ }
+};
+inline bool operator<(const WeightedEdge& l, const WeightedEdge& o) { return l.weight < o.weight; }
+inline size_t hash_value(const WeightedEdge& e) { return reinterpret_cast<const size_t&>(e); }
+
+
+struct PriorityQueue {
+ void push(const WeightedEdge& e) {}
+ const WeightedEdge& top() const {
+ static WeightedEdge w(1,2,3);
+ return w;
+ }
+ void pop() {}
+ void increment_all(float p) {}
+};
+
// based on Trajan 1977
-void ArcFactoredForest::MaximumSpanningTree(SpanningTree* st) const {
+void ArcFactoredForest::MaximumEdgeSubset(EdgeSubset* st) const {
typedef disjoint_sets_with_storage<identity_property_map, identity_property_map,
find_with_full_path_compression> DisjointSet;
DisjointSet strongly(num_words_ + 1);
DisjointSet weakly(num_words_ + 1);
- set<unsigned> roots, h, rset;
- vector<pair<short, short> > enter(num_words_ + 1);
+ set<unsigned> roots, rset;
+ unordered_set<WeightedEdge, boost::hash<WeightedEdge> > h;
+ vector<PriorityQueue> qs(num_words_ + 1);
+ vector<WeightedEdge> enter(num_words_ + 1);
+ vector<unsigned> mins(num_words_ + 1);
+ const WeightedEdge kDUMMY(0,0,0.0f);
for (unsigned i = 0; i <= num_words_; ++i) {
+ if (i > 0) {
+ // I(i) incidence on i -- all incoming edges
+ for (unsigned j = 0; j <= num_words_; ++j) {
+ qs[i].push(WeightedEdge(j, i, Weight(j,i)));
+ }
+ }
strongly.make_set(i);
weakly.make_set(i);
roots.insert(i);
+ enter[i] = kDUMMY;
+ mins[i] = i;
}
while(!roots.empty()) {
set<unsigned>::iterator it = roots.begin();
const unsigned k = *it;
roots.erase(it);
cerr << "k=" << k << endl;
- pair<short,short> ij; // TODO = Max(k);
+ WeightedEdge ij = qs[k].top(); // MAX(k)
+ qs[k].pop();
+ if (ij.weight <= 0) {
+ rset.insert(k);
+ } else {
+ if (strongly.find_set(ij.h) == k) {
+ roots.insert(k);
+ } else {
+ h.insert(ij);
+ if (weakly.find_set(ij.h) != weakly.find_set(ij.m)) {
+ weakly.union_set(ij.h, ij.m);
+ enter[k] = ij;
+ } else {
+ unsigned vertex = 0;
+ float val = 99999999999;
+ WeightedEdge xy = ij;
+ while(xy != kDUMMY) {
+ if (xy.weight < val) {
+ val = xy.weight;
+ vertex = strongly.find_set(xy.m);
+ }
+ xy = enter[strongly.find_set(xy.h)];
+ }
+ qs[k].increment_all(val - ij.weight);
+ mins[k] = mins[vertex];
+ xy = enter[strongly.find_set(ij.h)];
+ while (xy != kDUMMY) {
+ }
+ }
+ }
+ }
}
}
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h
index e99be482..3003a86e 100644
--- a/rst_parser/arc_factored.h
+++ b/rst_parser/arc_factored.h
@@ -10,11 +10,11 @@
#include "prob.h"
#include "weights.h"
-struct SpanningTree {
- SpanningTree() : roots(1, -1) {}
+struct EdgeSubset {
+ EdgeSubset() {}
std::vector<short> roots; // unless multiroot trees are supported, this
// will have a single member
- std::vector<std::pair<short, short> > h_m_pairs;
+ std::vector<std::pair<short, short> > h_m_pairs; // h,m start at *1*
};
class ArcFactoredForest {
@@ -35,7 +35,14 @@ class ArcFactoredForest {
// compute the maximum spanning tree based on the current weighting
// using the O(n^2) CLE algorithm
- void MaximumSpanningTree(SpanningTree* st) const;
+ void MaximumEdgeSubset(EdgeSubset* st) const;
+
+ // Reweight edges so that edge_prob is the edge's marginals
+ // optionally returns log partition
+ void EdgeMarginals(double* p_log_z = NULL);
+
+ // This may not return a tree
+ void PickBestParentForEachWord(EdgeSubset* st) const;
struct Edge {
Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {}
@@ -61,6 +68,10 @@ class ArcFactoredForest {
return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];
}
+ float Weight(short h, short m) const {
+ return log((*this)(h,m).edge_prob);
+ }
+
template <class V>
void Reweight(const V& weights) {
for (int m = 0; m < num_words_; ++m) {
@@ -85,4 +96,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge&
return os << "(" << edge.h << " < " << edge.m << ")";
}
+inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) {
+ for (unsigned i = 0; i < ss.roots.size(); ++i)
+ os << "ROOT < " << ss.roots[i] << std::endl;
+ for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i)
+ os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl;
+ return os;
+}
+
#endif
diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc
new file mode 100644
index 00000000..9851b59a
--- /dev/null
+++ b/rst_parser/arc_factored_marginals.cc
@@ -0,0 +1,52 @@
+#include "arc_factored.h"
+
+#include <iostream>
+
+#include "config.h"
+
+using namespace std;
+
+#if HAVE_EIGEN
+
+#include <Eigen/Dense>
+typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix;
+typedef Eigen::Matrix<double, Eigen::Dynamic, 1> RootVector;
+
+void ArcFactoredForest::EdgeMarginals(double *plog_z) {
+ ArcMatrix A(num_words_,num_words_);
+ RootVector r(num_words_);
+ for (int h = 0; h < num_words_; ++h) {
+ for (int m = 0; m < num_words_; ++m) {
+ if (h != m)
+ A(h,m) = edges_(h,m).edge_prob.as_float();
+ else
+ A(h,m) = 0;
+ }
+ r(h) = root_edges_[h].edge_prob.as_float();
+ }
+
+ ArcMatrix L = -A;
+ L.diagonal() = A.colwise().sum();
+ L.row(0) = r;
+ ArcMatrix Linv = L.inverse();
+ if (plog_z) *plog_z = log(Linv.determinant());
+ RootVector rootMarginals = r.cwiseProduct(Linv.col(0));
+ for (int h = 0; h < num_words_; ++h) {
+ for (int m = 0; m < num_words_; ++m) {
+ edges_(h,m).edge_prob = prob_t((m == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,m) -
+ (h == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,h));
+ }
+ root_edges_[h].edge_prob = prob_t(rootMarginals(h));
+ }
+ // cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl;
+}
+
+#else
+
+void ArcFactoredForest::EdgeMarginals(double*) {
+ cerr << "EdgeMarginals() requires --with-eigen!\n";
+ abort();
+}
+
+#endif
+
diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc
index e8fe706e..8995515f 100644
--- a/rst_parser/rst_test.cc
+++ b/rst_parser/rst_test.cc
@@ -26,8 +26,13 @@ int main(int argc, char** argv) {
af(0,2).edge_prob.logeq(10);
af(0,1).edge_prob.logeq(9);
af(0,3).edge_prob.logeq(9);
- SpanningTree tree;
- af.MaximumSpanningTree(&tree);
+ EdgeSubset tree;
+// af.MaximumEdgeSubset(&tree);
+ double lz;
+ af.EdgeMarginals(&lz);
+ cerr << "Z = " << lz << endl;
+ af.PickBestParentForEachWord(&tree);
+ cerr << tree << endl;
return 0;
}