diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-04-14 01:52:14 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-04-14 01:52:14 -0400 |
commit | c22e9248a1fa24b0255a55d21afb94a9ed3ddc22 (patch) | |
tree | 8d6fc97f4e41bf78397c48f8f479a74b75a1b602 /rst_parser/arc_factored.h | |
parent | c294227a928672bf108eed81106063a194c872ca (diff) |
matrix tree theorem stuff
Diffstat (limited to 'rst_parser/arc_factored.h')
-rw-r--r-- | rst_parser/arc_factored.h | 27 |
1 files changed, 23 insertions, 4 deletions
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index e99be482..3003a86e 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -10,11 +10,11 @@ #include "prob.h" #include "weights.h" -struct SpanningTree { - SpanningTree() : roots(1, -1) {} +struct EdgeSubset { + EdgeSubset() {} std::vector<short> roots; // unless multiroot trees are supported, this // will have a single member - std::vector<std::pair<short, short> > h_m_pairs; + std::vector<std::pair<short, short> > h_m_pairs; // h,m start at *1* }; class ArcFactoredForest { @@ -35,7 +35,14 @@ class ArcFactoredForest { // compute the maximum spanning tree based on the current weighting // using the O(n^2) CLE algorithm - void MaximumSpanningTree(SpanningTree* st) const; + void MaximumEdgeSubset(EdgeSubset* st) const; + + // Reweight edges so that edge_prob is the edge's marginals + // optionally returns log partition + void EdgeMarginals(double* p_log_z = NULL); + + // This may not return a tree + void PickBestParentForEachWord(EdgeSubset* st) const; struct Edge { Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} @@ -61,6 +68,10 @@ class ArcFactoredForest { return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; } + float Weight(short h, short m) const { + return log((*this)(h,m).edge_prob); + } + template <class V> void Reweight(const V& weights) { for (int m = 0; m < num_words_; ++m) { @@ -85,4 +96,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& return os << "(" << edge.h << " < " << edge.m << ")"; } +inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) { + for (unsigned i = 0; i < ss.roots.size(); ++i) + os << "ROOT < " << ss.roots[i] << std::endl; + for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i) + os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl; + return os; +} + #endif |