summaryrefslogtreecommitdiff
path: root/rst_parser/arc_factored.cc
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/arc_factored.cc')
-rw-r--r--rst_parser/arc_factored.cc29
1 files changed, 22 insertions, 7 deletions
diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc
index b2c2c427..44e769b8 100644
--- a/rst_parser/arc_factored.cc
+++ b/rst_parser/arc_factored.cc
@@ -6,23 +6,38 @@
#include <boost/pending/disjoint_sets.hpp>
#include <boost/functional/hash.hpp>
+#include "arc_ff.h"
+
using namespace std;
using namespace std::tr1;
using namespace boost;
+void ArcFactoredForest::ExtractFeatures(const TaggedSentence& sentence,
+ const std::vector<boost::shared_ptr<ArcFeatureFunction> >& ffs) {
+ for (int i = 0; i < ffs.size(); ++i) {
+ const ArcFeatureFunction& ff = *ffs[i];
+ for (int m = 0; m < num_words_; ++m) {
+ for (int h = 0; h < num_words_; ++h) {
+ ff.EgdeFeatures(sentence, h, m, &edges_(h,m).features);
+ }
+ ff.EgdeFeatures(sentence, -1, m, &root_edges_[m].features);
+ }
+ }
+}
+
void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const {
- for (int m = 1; m <= num_words_; ++m) {
- int best_head = -1;
+ for (int m = 0; m < num_words_; ++m) {
+ int best_head = -2;
prob_t best_score;
- for (int h = 0; h <= num_words_; ++h) {
+ for (int h = -1; h < num_words_; ++h) {
const Edge& edge = (*this)(h,m);
- if (best_head < 0 || edge.edge_prob > best_score) {
+ if (best_head < -1 || edge.edge_prob > best_score) {
best_score = edge.edge_prob;
best_head = h;
}
}
- assert(best_head >= 0);
- if (best_head)
+ assert(best_head >= -1);
+ if (best_head >= 0)
st->h_m_pairs.push_back(make_pair<short,short>(best_head, m));
else
st->roots.push_back(m);
@@ -56,7 +71,7 @@ struct PriorityQueue {
};
// based on Trajan 1977
-void ArcFactoredForest::MaximumEdgeSubset(EdgeSubset* st) const {
+void ArcFactoredForest::MaximumSpanningTree(EdgeSubset* st) const {
typedef disjoint_sets_with_storage<identity_property_map, identity_property_map,
find_with_full_path_compression> DisjointSet;
DisjointSet strongly(num_words_ + 1);