From 19147c5f45b40eac1e0ae1bc8bc8ccf90d1ea56c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 14 Apr 2012 01:52:14 -0400 Subject: matrix tree theorem stuff --- rst_parser/arc_factored.h | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) (limited to 'rst_parser/arc_factored.h') diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index e99be482..3003a86e 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -10,11 +10,11 @@ #include "prob.h" #include "weights.h" -struct SpanningTree { - SpanningTree() : roots(1, -1) {} +struct EdgeSubset { + EdgeSubset() {} std::vector roots; // unless multiroot trees are supported, this // will have a single member - std::vector > h_m_pairs; + std::vector > h_m_pairs; // h,m start at *1* }; class ArcFactoredForest { @@ -35,7 +35,14 @@ class ArcFactoredForest { // compute the maximum spanning tree based on the current weighting // using the O(n^2) CLE algorithm - void MaximumSpanningTree(SpanningTree* st) const; + void MaximumEdgeSubset(EdgeSubset* st) const; + + // Reweight edges so that edge_prob is the edge's marginals + // optionally returns log partition + void EdgeMarginals(double* p_log_z = NULL); + + // This may not return a tree + void PickBestParentForEachWord(EdgeSubset* st) const; struct Edge { Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} @@ -61,6 +68,10 @@ class ArcFactoredForest { return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; } + float Weight(short h, short m) const { + return log((*this)(h,m).edge_prob); + } + template void Reweight(const V& weights) { for (int m = 0; m < num_words_; ++m) { @@ -85,4 +96,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& return os << "(" << edge.h << " < " << edge.m << ")"; } +inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) { + for (unsigned i = 0; i < ss.roots.size(); ++i) + os << "ROOT < " << ss.roots[i] << std::endl; + for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i) + os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl; + return os; +} + #endif -- cgit v1.2.3