summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/grammar.cc24
-rw-r--r--decoder/grammar.h2
-rw-r--r--rst_parser/Makefile.am16
-rw-r--r--rst_parser/arc_factored.h58
-rw-r--r--rst_parser/mst_train.cc11
-rw-r--r--rst_parser/rst.cc2
-rw-r--r--rst_parser/rst.h7
7 files changed, 107 insertions, 13 deletions
diff --git a/decoder/grammar.cc b/decoder/grammar.cc
index 9e4065a6..714390f0 100644
--- a/decoder/grammar.cc
+++ b/decoder/grammar.cc
@@ -3,12 +3,14 @@
#include <algorithm>
#include <utility>
#include <map>
+#include <tr1/unordered_map>
#include "rule_lexer.h"
#include "filelib.h"
#include "tdict.h"
using namespace std;
+using namespace std::tr1;
const vector<TRulePtr> Grammar::NO_RULES;
@@ -148,24 +150,24 @@ bool GlueGrammar::HasRuleForSpan(int i, int /* j */, int /* distance */) const {
return (i == 0);
}
-PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) :
- has_rule_(input.size() + 1) {
+PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) {
+ unordered_set<WordID> ss;
for (int i = 0; i < input.size(); ++i) {
const vector<LatticeArc>& alts = input[i];
for (int k = 0; k < alts.size(); ++k) {
const int j = alts[k].dist2next + i;
- has_rule_[i].insert(j);
const string& src = TD::Convert(alts[k].label);
- TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1"));
- pt->a_.push_back(AlignmentPoint(0,0));
- AddRule(pt);
- RefineRule(pt, ctf_level);
+ if (ss.count(alts[k].label) == 0) {
+ TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1"));
+ pt->a_.push_back(AlignmentPoint(0,0));
+ AddRule(pt);
+ RefineRule(pt, ctf_level);
+ ss.insert(alts[k].label);
+ }
}
}
}
-bool PassThroughGrammar::HasRuleForSpan(int i, int j, int /* distance */) const {
- const set<int>& hr = has_rule_[i];
- if (i == j) { return !hr.empty(); }
- return (hr.find(j) != hr.end());
+bool PassThroughGrammar::HasRuleForSpan(int, int, int distance) const {
+ return (distance < 2);
}
diff --git a/decoder/grammar.h b/decoder/grammar.h
index f5d00817..e6a15a69 100644
--- a/decoder/grammar.h
+++ b/decoder/grammar.h
@@ -91,8 +91,6 @@ struct GlueGrammar : public TextGrammar {
struct PassThroughGrammar : public TextGrammar {
PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0);
virtual bool HasRuleForSpan(int i, int j, int distance) const;
- private:
- std::vector<std::set<int> > has_rule_; // index by [i][j]
};
void RefineRule(TRulePtr pt, const unsigned int ctf_level);
diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am
new file mode 100644
index 00000000..fef1c1a2
--- /dev/null
+++ b/rst_parser/Makefile.am
@@ -0,0 +1,16 @@
+bin_PROGRAMS = \
+ mst_train
+
+noinst_PROGRAMS = \
+ rst_test
+
+TESTS = rst_test
+
+noinst_LIBRARIES = librst.a
+
+librst_a_SOURCES = rst.cc
+
+mst_train_SOURCES = mst_train.cc
+mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h
new file mode 100644
index 00000000..312d7d67
--- /dev/null
+++ b/rst_parser/arc_factored.h
@@ -0,0 +1,58 @@
+#ifndef _ARC_FACTORED_H_
+#define _ARC_FACTORED_H_
+
+#include <vector>
+#include <cassert>
+#include "array2d.h"
+#include "sparse_vector.h"
+
+class ArcFactoredForest {
+ public:
+ explicit ArcFactoredForest(short num_words) :
+ num_words_(num_words),
+ root_edges_(num_words),
+ edges_(num_words, num_words) {}
+
+ struct Edge {
+ Edge() : features(), edge_prob(prob_t::Zero()) {}
+ SparseVector<weight_t> features;
+ prob_t edge_prob;
+ };
+
+ template <class V>
+ void Reweight(const V& weights) {
+ for (int m = 0; m < num_words_; ++m) {
+ for (int h = 0; h < num_words_; ++h) {
+ if (h != m) {
+ Edge& e = edges_(h, m);
+ e.edge_prob.logeq(e.features.dot(weights));
+ }
+ }
+ if (m) {
+ Edge& e = root_edges_[m];
+ e.edge_prob.logeq(e.features.dot(weights));
+ }
+ }
+ }
+
+ const Edge& operator()(short h, short m) const {
+ assert(m > 0);
+ assert(m <= num_words_);
+ assert(h >= 0);
+ assert(h <= num_words_);
+ return h ? edges_(h - 1, m - 1) : root_edges[m - 1];
+ }
+ Edge& operator()(short h, short m) {
+ assert(m > 0);
+ assert(m <= num_words_);
+ assert(h >= 0);
+ assert(h <= num_words_);
+ return h ? edges_(h - 1, m - 1) : root_edges[m - 1];
+ }
+ private:
+ unsigned num_words_;
+ std::vector<Edge> root_edges_;
+ Array2D<Edge> edges_;
+};
+
+#endif
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
new file mode 100644
index 00000000..1bceaff5
--- /dev/null
+++ b/rst_parser/mst_train.cc
@@ -0,0 +1,11 @@
+#include "arc_factored.h"
+
+#include <iostream>
+
+using namespace std;
+
+int main(int argc, char** argv) {
+ ArcFactoredForest af(5);
+ return 0;
+}
+
diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc
new file mode 100644
index 00000000..0ab3e296
--- /dev/null
+++ b/rst_parser/rst.cc
@@ -0,0 +1,2 @@
+#include "rst.h"
+
diff --git a/rst_parser/rst.h b/rst_parser/rst.h
new file mode 100644
index 00000000..30a1f8a4
--- /dev/null
+++ b/rst_parser/rst.h
@@ -0,0 +1,7 @@
+#ifndef _RST_H_
+#define _RST_H_
+
+struct RandomSpanningTree {
+};
+
+#endif