summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rst_parser/Makefile.am5
-rw-r--r--rst_parser/arc_factored.cc31
-rw-r--r--rst_parser/arc_factored.h72
-rw-r--r--rst_parser/mst_train.cc1
-rw-r--r--rst_parser/rst.cc5
-rw-r--r--rst_parser/rst.h5
-rw-r--r--rst_parser/rst_test.cc33
7 files changed, 129 insertions, 23 deletions
diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am
index fef1c1a2..e97ab5c5 100644
--- a/rst_parser/Makefile.am
+++ b/rst_parser/Makefile.am
@@ -8,9 +8,12 @@ TESTS = rst_test
noinst_LIBRARIES = librst.a
-librst_a_SOURCES = rst.cc
+librst_a_SOURCES = arc_factored.cc rst.cc
mst_train_SOURCES = mst_train.cc
mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+rst_test_SOURCES = rst_test.cc
+rst_test_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+
AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm
diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc
new file mode 100644
index 00000000..1e75600b
--- /dev/null
+++ b/rst_parser/arc_factored.cc
@@ -0,0 +1,31 @@
+#include "arc_factored.h"
+
+#include <set>
+
+#include <boost/pending/disjoint_sets.hpp>
+
+using namespace std;
+using namespace boost;
+
+// based on Trajan 1977
+void ArcFactoredForest::MaximumSpanningTree(SpanningTree* st) const {
+ typedef disjoint_sets_with_storage<identity_property_map, identity_property_map,
+ find_with_full_path_compression> DisjointSet;
+ DisjointSet strongly(num_words_ + 1);
+ DisjointSet weakly(num_words_ + 1);
+ set<unsigned> roots, h, rset;
+ vector<pair<short, short> > enter(num_words_ + 1);
+ for (unsigned i = 0; i <= num_words_; ++i) {
+ strongly.make_set(i);
+ weakly.make_set(i);
+ roots.insert(i);
+ }
+ while(!roots.empty()) {
+ set<unsigned>::iterator it = roots.begin();
+ const unsigned k = *it;
+ roots.erase(it);
+ cerr << "k=" << k << endl;
+ pair<short,short> ij; // TODO = Max(k);
+ }
+}
+
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h
index 312d7d67..e99be482 100644
--- a/rst_parser/arc_factored.h
+++ b/rst_parser/arc_factored.h
@@ -1,58 +1,88 @@
#ifndef _ARC_FACTORED_H_
#define _ARC_FACTORED_H_
-#include <vector>
+#include <iostream>
#include <cassert>
+#include <vector>
+#include <utility>
#include "array2d.h"
#include "sparse_vector.h"
+#include "prob.h"
+#include "weights.h"
+
+struct SpanningTree {
+ SpanningTree() : roots(1, -1) {}
+ std::vector<short> roots; // unless multiroot trees are supported, this
+ // will have a single member
+ std::vector<std::pair<short, short> > h_m_pairs;
+};
class ArcFactoredForest {
public:
explicit ArcFactoredForest(short num_words) :
num_words_(num_words),
root_edges_(num_words),
- edges_(num_words, num_words) {}
+ edges_(num_words, num_words) {
+ for (int h = 0; h < num_words; ++h) {
+ for (int m = 0; m < num_words; ++m) {
+ edges_(h, m).h = h + 1;
+ edges_(h, m).m = m + 1;
+ }
+ root_edges_[h].h = 0;
+ root_edges_[h].m = h + 1;
+ }
+ }
+
+ // compute the maximum spanning tree based on the current weighting
+ // using the O(n^2) CLE algorithm
+ void MaximumSpanningTree(SpanningTree* st) const;
struct Edge {
- Edge() : features(), edge_prob(prob_t::Zero()) {}
+ Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {}
+ short h;
+ short m;
SparseVector<weight_t> features;
prob_t edge_prob;
};
- template <class V>
- void Reweight(const V& weights) {
- for (int m = 0; m < num_words_; ++m) {
- for (int h = 0; h < num_words_; ++h) {
- if (h != m) {
- Edge& e = edges_(h, m);
- e.edge_prob.logeq(e.features.dot(weights));
- }
- }
- if (m) {
- Edge& e = root_edges_[m];
- e.edge_prob.logeq(e.features.dot(weights));
- }
- }
- }
-
const Edge& operator()(short h, short m) const {
assert(m > 0);
assert(m <= num_words_);
assert(h >= 0);
assert(h <= num_words_);
- return h ? edges_(h - 1, m - 1) : root_edges[m - 1];
+ return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];
}
+
Edge& operator()(short h, short m) {
assert(m > 0);
assert(m <= num_words_);
assert(h >= 0);
assert(h <= num_words_);
- return h ? edges_(h - 1, m - 1) : root_edges[m - 1];
+ return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];
+ }
+
+ template <class V>
+ void Reweight(const V& weights) {
+ for (int m = 0; m < num_words_; ++m) {
+ for (int h = 0; h < num_words_; ++h) {
+ if (h != m) {
+ Edge& e = edges_(h, m);
+ e.edge_prob.logeq(e.features.dot(weights));
+ }
+ }
+ Edge& e = root_edges_[m];
+ e.edge_prob.logeq(e.features.dot(weights));
+ }
}
+
private:
unsigned num_words_;
std::vector<Edge> root_edges_;
Array2D<Edge> edges_;
};
+inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& edge) {
+ return os << "(" << edge.h << " < " << edge.m << ")";
+}
+
#endif
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
index 1bceaff5..7b5af4c1 100644
--- a/rst_parser/mst_train.cc
+++ b/rst_parser/mst_train.cc
@@ -6,6 +6,7 @@ using namespace std;
int main(int argc, char** argv) {
ArcFactoredForest af(5);
+ cerr << af(0,3) << endl;
return 0;
}
diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc
index 0ab3e296..f6b295b3 100644
--- a/rst_parser/rst.cc
+++ b/rst_parser/rst.cc
@@ -1,2 +1,7 @@
#include "rst.h"
+using namespace std;
+
+StochasticForest::StochasticForest(const ArcFactoredForest& af) {
+}
+
diff --git a/rst_parser/rst.h b/rst_parser/rst.h
index 30a1f8a4..865871eb 100644
--- a/rst_parser/rst.h
+++ b/rst_parser/rst.h
@@ -1,7 +1,10 @@
#ifndef _RST_H_
#define _RST_H_
-struct RandomSpanningTree {
+#include "arc_factored.h"
+
+struct StochasticForest {
+ explicit StochasticForest(const ArcFactoredForest& af);
};
#endif
diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc
new file mode 100644
index 00000000..e8fe706e
--- /dev/null
+++ b/rst_parser/rst_test.cc
@@ -0,0 +1,33 @@
+#include "arc_factored.h"
+
+#include <iostream>
+
+using namespace std;
+
+int main(int argc, char** argv) {
+ // John saw Mary
+ // (H -> M)
+ // (1 -> 2) 20
+ // (1 -> 3) 3
+ // (2 -> 1) 20
+ // (2 -> 3) 30
+ // (3 -> 2) 0
+ // (3 -> 1) 11
+ // (0, 2) 10
+ // (0, 1) 9
+ // (0, 3) 9
+ ArcFactoredForest af(3);
+ af(1,2).edge_prob.logeq(20);
+ af(1,3).edge_prob.logeq(3);
+ af(2,1).edge_prob.logeq(20);
+ af(2,3).edge_prob.logeq(30);
+ af(3,2).edge_prob.logeq(0);
+ af(3,1).edge_prob.logeq(11);
+ af(0,2).edge_prob.logeq(10);
+ af(0,1).edge_prob.logeq(9);
+ af(0,3).edge_prob.logeq(9);
+ SpanningTree tree;
+ af.MaximumSpanningTree(&tree);
+ return 0;
+}
+