summaryrefslogtreecommitdiff
path: root/rst_parser/arc_factored.h
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/arc_factored.h')
-rw-r--r--rst_parser/arc_factored.h27
1 files changed, 23 insertions, 4 deletions
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h
index e99be482..3003a86e 100644
--- a/rst_parser/arc_factored.h
+++ b/rst_parser/arc_factored.h
@@ -10,11 +10,11 @@
#include "prob.h"
#include "weights.h"
-struct SpanningTree {
- SpanningTree() : roots(1, -1) {}
+struct EdgeSubset {
+ EdgeSubset() {}
std::vector<short> roots; // unless multiroot trees are supported, this
// will have a single member
- std::vector<std::pair<short, short> > h_m_pairs;
+ std::vector<std::pair<short, short> > h_m_pairs; // h,m start at *1*
};
class ArcFactoredForest {
@@ -35,7 +35,14 @@ class ArcFactoredForest {
// compute the maximum spanning tree based on the current weighting
// using the O(n^2) CLE algorithm
- void MaximumSpanningTree(SpanningTree* st) const;
+ void MaximumEdgeSubset(EdgeSubset* st) const;
+
+ // Reweight edges so that edge_prob is the edge's marginals
+ // optionally returns log partition
+ void EdgeMarginals(double* p_log_z = NULL);
+
+ // This may not return a tree
+ void PickBestParentForEachWord(EdgeSubset* st) const;
struct Edge {
Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {}
@@ -61,6 +68,10 @@ class ArcFactoredForest {
return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];
}
+ float Weight(short h, short m) const {
+ return log((*this)(h,m).edge_prob);
+ }
+
template <class V>
void Reweight(const V& weights) {
for (int m = 0; m < num_words_; ++m) {
@@ -85,4 +96,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge&
return os << "(" << edge.h << " < " << edge.m << ")";
}
+inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) {
+ for (unsigned i = 0; i < ss.roots.size(); ++i)
+ os << "ROOT < " << ss.roots[i] << std::endl;
+ for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i)
+ os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl;
+ return os;
+}
+
#endif