diff options
-rw-r--r-- | decoder/grammar.cc | 24 | ||||
-rw-r--r-- | decoder/grammar.h | 2 | ||||
-rw-r--r-- | rst_parser/Makefile.am | 16 | ||||
-rw-r--r-- | rst_parser/arc_factored.h | 58 | ||||
-rw-r--r-- | rst_parser/mst_train.cc | 11 | ||||
-rw-r--r-- | rst_parser/rst.cc | 2 | ||||
-rw-r--r-- | rst_parser/rst.h | 7 |
7 files changed, 107 insertions, 13 deletions
diff --git a/decoder/grammar.cc b/decoder/grammar.cc index 9e4065a6..714390f0 100644 --- a/decoder/grammar.cc +++ b/decoder/grammar.cc @@ -3,12 +3,14 @@ #include <algorithm> #include <utility> #include <map> +#include <tr1/unordered_map> #include "rule_lexer.h" #include "filelib.h" #include "tdict.h" using namespace std; +using namespace std::tr1; const vector<TRulePtr> Grammar::NO_RULES; @@ -148,24 +150,24 @@ bool GlueGrammar::HasRuleForSpan(int i, int /* j */, int /* distance */) const { return (i == 0); } -PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) : - has_rule_(input.size() + 1) { +PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) { + unordered_set<WordID> ss; for (int i = 0; i < input.size(); ++i) { const vector<LatticeArc>& alts = input[i]; for (int k = 0; k < alts.size(); ++k) { const int j = alts[k].dist2next + i; - has_rule_[i].insert(j); const string& src = TD::Convert(alts[k].label); - TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1")); - pt->a_.push_back(AlignmentPoint(0,0)); - AddRule(pt); - RefineRule(pt, ctf_level); + if (ss.count(alts[k].label) == 0) { + TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1")); + pt->a_.push_back(AlignmentPoint(0,0)); + AddRule(pt); + RefineRule(pt, ctf_level); + ss.insert(alts[k].label); + } } } } -bool PassThroughGrammar::HasRuleForSpan(int i, int j, int /* distance */) const { - const set<int>& hr = has_rule_[i]; - if (i == j) { return !hr.empty(); } - return (hr.find(j) != hr.end()); +bool PassThroughGrammar::HasRuleForSpan(int, int, int distance) const { + return (distance < 2); } diff --git a/decoder/grammar.h b/decoder/grammar.h index f5d00817..e6a15a69 100644 --- a/decoder/grammar.h +++ b/decoder/grammar.h @@ -91,8 +91,6 @@ struct GlueGrammar : public TextGrammar { struct PassThroughGrammar : public TextGrammar { PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0); virtual bool HasRuleForSpan(int i, int j, int distance) const; - private: - std::vector<std::set<int> > has_rule_; // index by [i][j] }; void RefineRule(TRulePtr pt, const unsigned int ctf_level); diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am new file mode 100644 index 00000000..fef1c1a2 --- /dev/null +++ b/rst_parser/Makefile.am @@ -0,0 +1,16 @@ +bin_PROGRAMS = \ + mst_train + +noinst_PROGRAMS = \ + rst_test + +TESTS = rst_test + +noinst_LIBRARIES = librst.a + +librst_a_SOURCES = rst.cc + +mst_train_SOURCES = mst_train.cc +mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h new file mode 100644 index 00000000..312d7d67 --- /dev/null +++ b/rst_parser/arc_factored.h @@ -0,0 +1,58 @@ +#ifndef _ARC_FACTORED_H_ +#define _ARC_FACTORED_H_ + +#include <vector> +#include <cassert> +#include "array2d.h" +#include "sparse_vector.h" + +class ArcFactoredForest { + public: + explicit ArcFactoredForest(short num_words) : + num_words_(num_words), + root_edges_(num_words), + edges_(num_words, num_words) {} + + struct Edge { + Edge() : features(), edge_prob(prob_t::Zero()) {} + SparseVector<weight_t> features; + prob_t edge_prob; + }; + + template <class V> + void Reweight(const V& weights) { + for (int m = 0; m < num_words_; ++m) { + for (int h = 0; h < num_words_; ++h) { + if (h != m) { + Edge& e = edges_(h, m); + e.edge_prob.logeq(e.features.dot(weights)); + } + } + if (m) { + Edge& e = root_edges_[m]; + e.edge_prob.logeq(e.features.dot(weights)); + } + } + } + + const Edge& operator()(short h, short m) const { + assert(m > 0); + assert(m <= num_words_); + assert(h >= 0); + assert(h <= num_words_); + return h ? edges_(h - 1, m - 1) : root_edges[m - 1]; + } + Edge& operator()(short h, short m) { + assert(m > 0); + assert(m <= num_words_); + assert(h >= 0); + assert(h <= num_words_); + return h ? edges_(h - 1, m - 1) : root_edges[m - 1]; + } + private: + unsigned num_words_; + std::vector<Edge> root_edges_; + Array2D<Edge> edges_; +}; + +#endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc new file mode 100644 index 00000000..1bceaff5 --- /dev/null +++ b/rst_parser/mst_train.cc @@ -0,0 +1,11 @@ +#include "arc_factored.h" + +#include <iostream> + +using namespace std; + +int main(int argc, char** argv) { + ArcFactoredForest af(5); + return 0; +} + diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc new file mode 100644 index 00000000..0ab3e296 --- /dev/null +++ b/rst_parser/rst.cc @@ -0,0 +1,2 @@ +#include "rst.h" + diff --git a/rst_parser/rst.h b/rst_parser/rst.h new file mode 100644 index 00000000..30a1f8a4 --- /dev/null +++ b/rst_parser/rst.h @@ -0,0 +1,7 @@ +#ifndef _RST_H_ +#define _RST_H_ + +struct RandomSpanningTree { +}; + +#endif |