diff options
Diffstat (limited to 'rst_parser/arc_factored.h')
-rw-r--r-- | rst_parser/arc_factored.h | 82 |
1 files changed, 59 insertions, 23 deletions
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index e99be482..c5481d80 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -5,37 +5,65 @@ #include <cassert> #include <vector> #include <utility> +#include <boost/shared_ptr.hpp> #include "array2d.h" #include "sparse_vector.h" #include "prob.h" #include "weights.h" +#include "wordid.h" -struct SpanningTree { - SpanningTree() : roots(1, -1) {} +struct TaggedSentence { + std::vector<WordID> words; + std::vector<WordID> pos; +}; + +struct ArcFeatureFunctions; +struct EdgeSubset { + EdgeSubset() {} std::vector<short> roots; // unless multiroot trees are supported, this // will have a single member - std::vector<std::pair<short, short> > h_m_pairs; + std::vector<std::pair<short, short> > h_m_pairs; // h,m start at 0 + // assumes ArcFeatureFunction::PrepareForInput has already been called + void ExtractFeatures(const TaggedSentence& sentence, + const ArcFeatureFunctions& ffs, + SparseVector<double>* features) const; }; class ArcFactoredForest { public: - explicit ArcFactoredForest(short num_words) : - num_words_(num_words), - root_edges_(num_words), - edges_(num_words, num_words) { + ArcFactoredForest() : num_words_() {} + explicit ArcFactoredForest(short num_words) : num_words_(num_words) { + resize(num_words); + } + + unsigned size() const { return num_words_; } + + void resize(unsigned num_words) { + num_words_ = num_words; + root_edges_.clear(); + edges_.clear(); + root_edges_.resize(num_words); + edges_.resize(num_words, num_words); for (int h = 0; h < num_words; ++h) { for (int m = 0; m < num_words; ++m) { - edges_(h, m).h = h + 1; - edges_(h, m).m = m + 1; + edges_(h, m).h = h; + edges_(h, m).m = m; } - root_edges_[h].h = 0; - root_edges_[h].m = h + 1; + root_edges_[h].h = -1; + root_edges_[h].m = h; } } // compute the maximum spanning tree based on the current weighting // using the O(n^2) CLE algorithm - void MaximumSpanningTree(SpanningTree* st) const; + void MaximumSpanningTree(EdgeSubset* st) const; + + // Reweight edges so that edge_prob is the edge's marginals + // optionally returns log partition + void EdgeMarginals(prob_t* p_log_z = NULL); + + // This may not return a tree + void PickBestParentForEachWord(EdgeSubset* st) const; struct Edge { Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} @@ -45,20 +73,20 @@ class ArcFactoredForest { prob_t edge_prob; }; + // set eges_[*].features + void ExtractFeatures(const TaggedSentence& sentence, + const ArcFeatureFunctions& ffs); + const Edge& operator()(short h, short m) const { - assert(m > 0); - assert(m <= num_words_); - assert(h >= 0); - assert(h <= num_words_); - return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; + return h >= 0 ? edges_(h, m) : root_edges_[m]; } Edge& operator()(short h, short m) { - assert(m > 0); - assert(m <= num_words_); - assert(h >= 0); - assert(h <= num_words_); - return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; + return h >= 0 ? edges_(h, m) : root_edges_[m]; + } + + float Weight(short h, short m) const { + return log((*this)(h,m).edge_prob); } template <class V> @@ -76,7 +104,7 @@ class ArcFactoredForest { } private: - unsigned num_words_; + int num_words_; std::vector<Edge> root_edges_; Array2D<Edge> edges_; }; @@ -85,4 +113,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& return os << "(" << edge.h << " < " << edge.m << ")"; } +inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) { + for (unsigned i = 0; i < ss.roots.size(); ++i) + os << "ROOT < " << ss.roots[i] << std::endl; + for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i) + os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl; + return os; +} + #endif |