diff options
author | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-04-07 16:58:55 +0200 |
---|---|---|
committer | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-04-07 16:58:55 +0200 |
commit | e91553ae70907e243a554e4a549c53df57b78478 (patch) | |
tree | a4d044093f5937d0152b573c99914746b5a2b8ef /rst_parser | |
parent | fb714888562845a8ae10fd4411cf199961193833 (diff) | |
parent | 2fe4323cbfc34de906a2869f98c017b41e4ccae7 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'rst_parser')
-rw-r--r-- | rst_parser/Makefile.am | 19 | ||||
-rw-r--r-- | rst_parser/arc_factored.cc | 31 | ||||
-rw-r--r-- | rst_parser/arc_factored.h | 88 | ||||
-rw-r--r-- | rst_parser/mst_train.cc | 12 | ||||
-rw-r--r-- | rst_parser/rst.cc | 7 | ||||
-rw-r--r-- | rst_parser/rst.h | 10 | ||||
-rw-r--r-- | rst_parser/rst_test.cc | 33 |
7 files changed, 200 insertions, 0 deletions
diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am new file mode 100644 index 00000000..e97ab5c5 --- /dev/null +++ b/rst_parser/Makefile.am @@ -0,0 +1,19 @@ +bin_PROGRAMS = \ + mst_train + +noinst_PROGRAMS = \ + rst_test + +TESTS = rst_test + +noinst_LIBRARIES = librst.a + +librst_a_SOURCES = arc_factored.cc rst.cc + +mst_train_SOURCES = mst_train.cc +mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + +rst_test_SOURCES = rst_test.cc +rst_test_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc new file mode 100644 index 00000000..1e75600b --- /dev/null +++ b/rst_parser/arc_factored.cc @@ -0,0 +1,31 @@ +#include "arc_factored.h" + +#include <set> + +#include <boost/pending/disjoint_sets.hpp> + +using namespace std; +using namespace boost; + +// based on Trajan 1977 +void ArcFactoredForest::MaximumSpanningTree(SpanningTree* st) const { + typedef disjoint_sets_with_storage<identity_property_map, identity_property_map, + find_with_full_path_compression> DisjointSet; + DisjointSet strongly(num_words_ + 1); + DisjointSet weakly(num_words_ + 1); + set<unsigned> roots, h, rset; + vector<pair<short, short> > enter(num_words_ + 1); + for (unsigned i = 0; i <= num_words_; ++i) { + strongly.make_set(i); + weakly.make_set(i); + roots.insert(i); + } + while(!roots.empty()) { + set<unsigned>::iterator it = roots.begin(); + const unsigned k = *it; + roots.erase(it); + cerr << "k=" << k << endl; + pair<short,short> ij; // TODO = Max(k); + } +} + diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h new file mode 100644 index 00000000..e99be482 --- /dev/null +++ b/rst_parser/arc_factored.h @@ -0,0 +1,88 @@ +#ifndef _ARC_FACTORED_H_ +#define _ARC_FACTORED_H_ + +#include <iostream> +#include <cassert> +#include <vector> +#include <utility> +#include "array2d.h" +#include "sparse_vector.h" +#include "prob.h" +#include "weights.h" + +struct SpanningTree { + SpanningTree() : roots(1, -1) {} + std::vector<short> roots; // unless multiroot trees are supported, this + // will have a single member + std::vector<std::pair<short, short> > h_m_pairs; +}; + +class ArcFactoredForest { + public: + explicit ArcFactoredForest(short num_words) : + num_words_(num_words), + root_edges_(num_words), + edges_(num_words, num_words) { + for (int h = 0; h < num_words; ++h) { + for (int m = 0; m < num_words; ++m) { + edges_(h, m).h = h + 1; + edges_(h, m).m = m + 1; + } + root_edges_[h].h = 0; + root_edges_[h].m = h + 1; + } + } + + // compute the maximum spanning tree based on the current weighting + // using the O(n^2) CLE algorithm + void MaximumSpanningTree(SpanningTree* st) const; + + struct Edge { + Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} + short h; + short m; + SparseVector<weight_t> features; + prob_t edge_prob; + }; + + const Edge& operator()(short h, short m) const { + assert(m > 0); + assert(m <= num_words_); + assert(h >= 0); + assert(h <= num_words_); + return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; + } + + Edge& operator()(short h, short m) { + assert(m > 0); + assert(m <= num_words_); + assert(h >= 0); + assert(h <= num_words_); + return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; + } + + template <class V> + void Reweight(const V& weights) { + for (int m = 0; m < num_words_; ++m) { + for (int h = 0; h < num_words_; ++h) { + if (h != m) { + Edge& e = edges_(h, m); + e.edge_prob.logeq(e.features.dot(weights)); + } + } + Edge& e = root_edges_[m]; + e.edge_prob.logeq(e.features.dot(weights)); + } + } + + private: + unsigned num_words_; + std::vector<Edge> root_edges_; + Array2D<Edge> edges_; +}; + +inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& edge) { + return os << "(" << edge.h << " < " << edge.m << ")"; +} + +#endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc new file mode 100644 index 00000000..7b5af4c1 --- /dev/null +++ b/rst_parser/mst_train.cc @@ -0,0 +1,12 @@ +#include "arc_factored.h" + +#include <iostream> + +using namespace std; + +int main(int argc, char** argv) { + ArcFactoredForest af(5); + cerr << af(0,3) << endl; + return 0; +} + diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc new file mode 100644 index 00000000..f6b295b3 --- /dev/null +++ b/rst_parser/rst.cc @@ -0,0 +1,7 @@ +#include "rst.h" + +using namespace std; + +StochasticForest::StochasticForest(const ArcFactoredForest& af) { +} + diff --git a/rst_parser/rst.h b/rst_parser/rst.h new file mode 100644 index 00000000..865871eb --- /dev/null +++ b/rst_parser/rst.h @@ -0,0 +1,10 @@ +#ifndef _RST_H_ +#define _RST_H_ + +#include "arc_factored.h" + +struct StochasticForest { + explicit StochasticForest(const ArcFactoredForest& af); +}; + +#endif diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc new file mode 100644 index 00000000..e8fe706e --- /dev/null +++ b/rst_parser/rst_test.cc @@ -0,0 +1,33 @@ +#include "arc_factored.h" + +#include <iostream> + +using namespace std; + +int main(int argc, char** argv) { + // John saw Mary + // (H -> M) + // (1 -> 2) 20 + // (1 -> 3) 3 + // (2 -> 1) 20 + // (2 -> 3) 30 + // (3 -> 2) 0 + // (3 -> 1) 11 + // (0, 2) 10 + // (0, 1) 9 + // (0, 3) 9 + ArcFactoredForest af(3); + af(1,2).edge_prob.logeq(20); + af(1,3).edge_prob.logeq(3); + af(2,1).edge_prob.logeq(20); + af(2,3).edge_prob.logeq(30); + af(3,2).edge_prob.logeq(0); + af(3,1).edge_prob.logeq(11); + af(0,2).edge_prob.logeq(10); + af(0,1).edge_prob.logeq(9); + af(0,3).edge_prob.logeq(9); + SpanningTree tree; + af.MaximumSpanningTree(&tree); + return 0; +} + |