summaryrefslogtreecommitdiff
path: root/rst_parser/arc_factored.h
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/arc_factored.h')
-rw-r--r--rst_parser/arc_factored.h82
1 files changed, 59 insertions, 23 deletions
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h
index e99be482..c5481d80 100644
--- a/rst_parser/arc_factored.h
+++ b/rst_parser/arc_factored.h
@@ -5,37 +5,65 @@
#include <cassert>
#include <vector>
#include <utility>
+#include <boost/shared_ptr.hpp>
#include "array2d.h"
#include "sparse_vector.h"
#include "prob.h"
#include "weights.h"
+#include "wordid.h"
-struct SpanningTree {
- SpanningTree() : roots(1, -1) {}
+struct TaggedSentence {
+ std::vector<WordID> words;
+ std::vector<WordID> pos;
+};
+
+struct ArcFeatureFunctions;
+struct EdgeSubset {
+ EdgeSubset() {}
std::vector<short> roots; // unless multiroot trees are supported, this
// will have a single member
- std::vector<std::pair<short, short> > h_m_pairs;
+ std::vector<std::pair<short, short> > h_m_pairs; // h,m start at 0
+ // assumes ArcFeatureFunction::PrepareForInput has already been called
+ void ExtractFeatures(const TaggedSentence& sentence,
+ const ArcFeatureFunctions& ffs,
+ SparseVector<double>* features) const;
};
class ArcFactoredForest {
public:
- explicit ArcFactoredForest(short num_words) :
- num_words_(num_words),
- root_edges_(num_words),
- edges_(num_words, num_words) {
+ ArcFactoredForest() : num_words_() {}
+ explicit ArcFactoredForest(short num_words) : num_words_(num_words) {
+ resize(num_words);
+ }
+
+ unsigned size() const { return num_words_; }
+
+ void resize(unsigned num_words) {
+ num_words_ = num_words;
+ root_edges_.clear();
+ edges_.clear();
+ root_edges_.resize(num_words);
+ edges_.resize(num_words, num_words);
for (int h = 0; h < num_words; ++h) {
for (int m = 0; m < num_words; ++m) {
- edges_(h, m).h = h + 1;
- edges_(h, m).m = m + 1;
+ edges_(h, m).h = h;
+ edges_(h, m).m = m;
}
- root_edges_[h].h = 0;
- root_edges_[h].m = h + 1;
+ root_edges_[h].h = -1;
+ root_edges_[h].m = h;
}
}
// compute the maximum spanning tree based on the current weighting
// using the O(n^2) CLE algorithm
- void MaximumSpanningTree(SpanningTree* st) const;
+ void MaximumSpanningTree(EdgeSubset* st) const;
+
+ // Reweight edges so that edge_prob is the edge's marginals
+ // optionally returns log partition
+ void EdgeMarginals(prob_t* p_log_z = NULL);
+
+ // This may not return a tree
+ void PickBestParentForEachWord(EdgeSubset* st) const;
struct Edge {
Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {}
@@ -45,20 +73,20 @@ class ArcFactoredForest {
prob_t edge_prob;
};
+ // set eges_[*].features
+ void ExtractFeatures(const TaggedSentence& sentence,
+ const ArcFeatureFunctions& ffs);
+
const Edge& operator()(short h, short m) const {
- assert(m > 0);
- assert(m <= num_words_);
- assert(h >= 0);
- assert(h <= num_words_);
- return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];
+ return h >= 0 ? edges_(h, m) : root_edges_[m];
}
Edge& operator()(short h, short m) {
- assert(m > 0);
- assert(m <= num_words_);
- assert(h >= 0);
- assert(h <= num_words_);
- return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];
+ return h >= 0 ? edges_(h, m) : root_edges_[m];
+ }
+
+ float Weight(short h, short m) const {
+ return log((*this)(h,m).edge_prob);
}
template <class V>
@@ -76,7 +104,7 @@ class ArcFactoredForest {
}
private:
- unsigned num_words_;
+ int num_words_;
std::vector<Edge> root_edges_;
Array2D<Edge> edges_;
};
@@ -85,4 +113,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge&
return os << "(" << edge.h << " < " << edge.m << ")";
}
+inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) {
+ for (unsigned i = 0; i < ss.roots.size(); ++i)
+ os << "ROOT < " << ss.roots[i] << std::endl;
+ for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i)
+ os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl;
+ return os;
+}
+
#endif