From 34785db78a0ad12f0fe74d98924acc20a8cab79a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 27 Mar 2014 00:07:41 -0400 Subject: breadth first iterator for tree fragment --- decoder/tree2string_translator.cc | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index ac9c0d74..1c249836 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -13,25 +13,6 @@ using namespace std; -// root: S -// A implication: (S [A] *INCOMPLETE* -// B implication: (S [A] [B] *INCOMPLETE* -// *0* implication: (S _[A] [B]) -// a implication: (S (A a *INCOMPLETE* [B]) -// a implication: (S (A a a *INCOMPLETE* [B]) -// *0* implication: (S (A a a) _[B]) -// D implication: (S (A a a) (B [D] *INCOMPLETE*) -// *0* implication: (S (A a a) (B _[D])) -// d implication: (S (A a a) (B (D d *INCOMPLETE*)) -// *0* implication: (S (A a a) (B (D d))) -// --there are no further outgoing links possible-- - -// root: S -// A implication: (S [A] *INCOMPLETE* -// B implication: (S [A] [B] *INCOMPLETE* -// *0* implication: (S _[A] [B]) -// *0* implication: (S [A] _[B]) -// b implication: (S [A] (B b *INCOMPLETE*)) struct Tree2StringGrammarNode { map next; string rules; @@ -44,8 +25,19 @@ void ReadTree2StringGrammar(istream* in, unordered_map 3); - if (line[pos - 1] == ' ') --pos; + unsigned xc = 0; + while (line[pos - 1] == ' ') { --pos; xc++; } cdec::TreeFragment rule_src(line.substr(0, pos), true); + Tree2StringGrammarNode* cur = &roots[rule_src.root]; + for (auto sym : rule_src) + cur = &cur->next[sym]; + pos += 3 + xc; + while(line[pos] == ' ') { ++pos; } + size_t pos2 = line.find("|||", pos); + assert(pos2 != string::npos); + while (line[pos2 - 1] == ' ') { --pos2; } + cur->rules = line.substr(pos, pos2 - pos); + cerr << "OUTPUT = '" << cur->rules << "'\n"; } } -- cgit v1.2.3 From 8372086f2fc4bd765fdd05e8cf95faeb147a6587 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 30 Mar 2014 23:50:17 -0400 Subject: almost complete tree to string translator --- decoder/Makefile.am | 3 + decoder/decoder.cc | 6 +- decoder/t2s_test.cc | 110 ++++++++++++++++++++++++++++++++++ decoder/tree2string_translator.cc | 120 +++++++++++++++++++++++++++++++++----- decoder/tree_fragment.cc | 14 +++-- decoder/tree_fragment.h | 109 ++++++++++++++++++++++++---------- 6 files changed, 311 insertions(+), 51 deletions(-) create mode 100644 decoder/t2s_test.cc (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 7481192b..5c91fe65 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -4,9 +4,12 @@ noinst_PROGRAMS = \ trule_test \ hg_test \ parser_test \ + t2s_test \ grammar_test TESTS = trule_test parser_test grammar_test hg_test +t2s_test_SOURCES = t2s_test.cc +t2s_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a parser_test_SOURCES = parser_test.cc parser_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a grammar_test_SOURCES = grammar_test.cc diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 31049216..43e2640d 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -490,8 +490,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } formalism = LowercaseString(str("formalism",conf)); - if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; + if (formalism != "t2s" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -626,6 +626,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream // set up translation back end if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); + else if (formalism == "t2s") + translator.reset(new Tree2StringTranslator(conf)); else if (formalism == "fst") translator.reset(new FSTTranslator(conf)); else if (formalism == "pb") diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc new file mode 100644 index 00000000..3c46ea89 --- /dev/null +++ b/decoder/t2s_test.cc @@ -0,0 +1,110 @@ +#include "tree_fragment.h" + +#define BOOST_TEST_MODULE T2STest +#include +#include +#include +#include "tdict.h" + +using namespace std; + +BOOST_AUTO_TEST_CASE(TestTreeFragments) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + cdec::TreeFragment tree2("(S (NP (DT a) (NN cat)) (VP (V ate) (NP (DT the) (NN cake pie))))"); + vector a, b; + vector aw, bw; + cerr << "TREE1: " << tree << endl; + cerr << "TREE2: " << tree2 << endl; + for (auto& sym : tree) + if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym); + for (auto& sym : tree2) + if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym); + BOOST_CHECK_EQUAL(a.size(), b.size()); + BOOST_CHECK_EQUAL(aw.size() + 1, bw.size()); + BOOST_CHECK_EQUAL(aw.size(), 5); + BOOST_CHECK_EQUAL(TD::GetString(aw), "the boy saw a cat"); + BOOST_CHECK_EQUAL(TD::GetString(bw), "a cat ate the cake pie"); + if (a != b) { + BOOST_CHECK_EQUAL(1,2); + } + + string nts; + for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + if (cdec::IsNT(*it)) { + if (cdec::IsRHS(*it)) it.truncate(); + if (nts.size()) nts += " "; + if (cdec::IsLHS(*it)) nts += "("; + nts += TD::Convert(*it & cdec::ALL_MASK); + if (cdec::IsFrontier(*it)) nts += "*"; + } + } + BOOST_CHECK_EQUAL(nts, "(S NP* VP*"); + + nts.clear(); + int ntc = 0; + for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + if (cdec::IsNT(*it)) { + if (cdec::IsRHS(*it)) { + ++ntc; + if (ntc > 1) it.truncate(); + } + if (nts.size()) nts += " "; + if (cdec::IsLHS(*it)) nts += "("; + nts += TD::Convert(*it & cdec::ALL_MASK); + if (cdec::IsFrontier(*it)) nts += "*"; + } + } + BOOST_CHECK_EQUAL(nts, "(S NP VP* (NP DT* NN*"); +} + +BOOST_AUTO_TEST_CASE(TestSharing) { + cdec::TreeFragment rule1("(S [NP] [VP])", true); + cdec::TreeFragment rule2("(S [NP] (VP [V] [NP]))", true); + string r1,r2; + for (auto sym : rule1) { + if (r1.size()) r1 += " "; + if (cdec::IsLHS(sym)) r1 += "("; + r1 += TD::Convert(sym & cdec::ALL_MASK); + if (cdec::IsFrontier(sym)) r1 += "*"; + } + for (auto sym : rule2) { + if (r2.size()) r2 += " "; + if (cdec::IsLHS(sym)) r2 += "("; + r2 += TD::Convert(sym & cdec::ALL_MASK); + if (cdec::IsFrontier(sym)) r2 += "*"; + } + cerr << rule1 << endl; + cerr << r1 << endl; + cerr << rule2 << endl; + cerr << r2 << endl; + BOOST_CHECK_EQUAL(r1, "(S NP* VP*"); + BOOST_CHECK_EQUAL(r2, "(S NP* VP (VP V* NP*"); +} + +BOOST_AUTO_TEST_CASE(TestEndInvariants) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + BOOST_CHECK(tree.end().at_end()); + BOOST_CHECK(!tree.begin().at_end()); +} + +BOOST_AUTO_TEST_CASE(TestBegins) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + for (auto it = tree.begin(1); it != tree.end(); ++it) { + cerr << TD::Convert(*it & cdec::ALL_MASK) << endl; + } +} + +BOOST_AUTO_TEST_CASE(TestRemainder) { + cdec::TreeFragment tree("(S (A a) (B b))"); + auto it = tree.begin(); + ++it; + BOOST_CHECK(cdec::IsRHS(*it)); + cerr << tree << endl; + auto itr = it.remainder(); + while(itr != tree.end()) { + cerr << TD::Convert(*itr & cdec::ALL_MASK) << endl; + ++itr; + } +} + + diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 1c249836..cd6ee550 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include #include "tree_fragment.h" @@ -15,11 +16,10 @@ using namespace std; struct Tree2StringGrammarNode { map next; - string rules; + vector rules; }; -void ReadTree2StringGrammar(istream* in, unordered_map* proots) { - unordered_map& roots = *proots; +void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); @@ -28,32 +28,124 @@ void ReadTree2StringGrammar(istream* in, unordered_map frhs; + for (auto sym : rule_src) { cur = &cur->next[sym]; + if (sym) { + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); + } + } + } + os << '[' << TD::Convert(-lhs) << "] |||"; + for (auto x : frhs) { + os << ' '; + if (x < 0) + os << '[' << TD::Convert(-x) << ']'; + else + os << TD::Convert(x); + } pos += 3 + xc; while(line[pos] == ' ') { ++pos; } - size_t pos2 = line.find("|||", pos); - assert(pos2 != string::npos); - while (line[pos2 - 1] == ' ') { --pos2; } - cur->rules = line.substr(pos, pos2 - pos); - cerr << "OUTPUT = '" << cur->rules << "'\n"; + os << " ||| " << line.substr(pos); + TRulePtr rule(new TRule(os.str())); + cur->rules.push_back(rule); } } +struct ParserState { + ParserState() : in_iter(), node() {} + cdec::TreeFragment::iterator in_iter; + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) : + in_iter(it), + root_type(rt), + node(n) {} + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : + in_iter(it), + future_work(p.future_work), + root_type(p.root_type), + node(n) {} + vector future_work; + int root_type; // lhs of top level NT + Tree2StringGrammarNode* node; +}; + struct Tree2StringTranslatorImpl { - unordered_map roots; // root['S'] gives rule network for S rules + Tree2StringGrammarNode root; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { ReadFile rf(conf["grammar"].as>()[0]); - ReadTree2StringGrammar(rf.stream(), &roots); + ReadTree2StringGrammar(rf.stream(), &root); } bool Translate(const string& input, SentenceMetadata* smeta, const vector& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); - cerr << "Tree2StringTranslatorImpl: please implement this!\n"; - return false; + const int kS = -TD::Convert("S"); + Hypergraph hg; + queue q; + q.push(ParserState(input_tree.begin(), &root, kS)); + while(!q.empty()) { + ParserState& s = q.front(); + + if (s.in_iter.at_end()) { // completed a traversal of a subtree + cerr << "I traversed a subtree of the input...\n"; + if (s.node->rules.size()) { + // TODO: build hypergraph + for (auto& r : s.node->rules) + cerr << "I can build: " << r->AsString() << endl; + for (auto& w : s.future_work) + q.push(w); + } else { + cerr << "I can't build anything :(\n"; + } + } else { // more input tree to match + unsigned sym = *s.in_iter; + if (cdec::IsLHS(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + q.push(ParserState(++s.in_iter, &nit->second, s)); + } + } else if (cdec::IsRHS(sym)) { + //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + cdec::TreeFragment::iterator var = s.in_iter; + var.truncate(); + auto nit1 = s.node->next.find(sym); + auto nit2 = s.node->next.find(*var); + if (nit2 != s.node->next.end()) { + //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + ParserState new_s(++var, &nit2->second, s); + ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK)); + new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q + q.push(new_s); + } + if (nit1 != s.node->next.end()) { + //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + q.push(ParserState(++s.in_iter, &nit1->second, s)); + } + } else if (cdec::IsTerminal(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; + q.push(ParserState(++s.in_iter, &nit->second, s)); + } + } else { + cerr << "This can never happen!\n"; abort(); + } + } + q.pop(); + } + minus_lm_forest->swap(hg); } }; diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 93aad64e..78a993b8 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -36,7 +36,7 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const { *out << ' '; if (IsFrontier(x)) { *out << '[' << TD::Convert(x & ALL_MASK) << ']'; - } else if (IsInternalNT(x)) { + } else if (IsRHS(x)) { DebugRec(x & ALL_MASK, out); } else { // must be terminal *out << TD::Convert(x); @@ -66,7 +66,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned // recursively call parser to deal with constituent ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np); unsigned ind = np - 1; - rhs.push_back(ind | NT_BIT); + rhs.push_back(ind | RHS_BIT); } else { // deal with terminal / nonterminal substitution ++symp; assert(tree[cp] != ' '); @@ -95,7 +95,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned } // continuent has completed, cp is at ), build node const unsigned j = symp; // span from (i,j) // add an internal non-terminal symbol - const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT; + const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT; nodes[np] = TreeFragmentProduction(nt, rhs); //cerr << np << " production(" << i << "," << j << ")= " << TD::Convert(nt & ALL_MASK) << " -->"; //for (auto& x : rhs) { @@ -113,11 +113,15 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned } BreadthFirstIterator TreeFragment::begin() const { - return BreadthFirstIterator(this); + return BreadthFirstIterator(this, nodes.size() - 1); +} + +BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const { + return BreadthFirstIterator(this, node_idx); } BreadthFirstIterator TreeFragment::end() const { - return BreadthFirstIterator(this, 0); + return BreadthFirstIterator(this); } } diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index a38dbdfa..b83afc27 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -1,7 +1,7 @@ #ifndef TREE_FRAGMENT #define TREE_FRAGMENT -#include +#include #include #include #include @@ -12,18 +12,32 @@ namespace cdec { class BreadthFirstIterator; -static const unsigned NT_BIT = 0x40000000u; -static const unsigned FRONTIER_BIT = 0x80000000u; -static const unsigned ALL_MASK = 0x0FFFFFFFu; +static const unsigned LHS_BIT = 0x10000000u; +static const unsigned RHS_BIT = 0x20000000u; +static const unsigned FRONTIER_BIT = 0x40000000u; +static const unsigned RESERVED_BIT = 0x80000000u; +static const unsigned ALL_MASK = 0x0FFFFFFFu; -inline bool IsInternalNT(unsigned x) { - return (x & NT_BIT); +inline bool IsNT(unsigned x) { + return (x & (LHS_BIT | RHS_BIT | FRONTIER_BIT)); +} + +inline bool IsLHS(unsigned x) { + return (x & LHS_BIT); +} + +inline bool IsRHS(unsigned x) { + return (x & RHS_BIT); } inline bool IsFrontier(unsigned x) { return (x & FRONTIER_BIT); } +inline bool IsTerminal(unsigned x) { + return (x & ALL_MASK) == x; +} + struct TreeFragmentProduction { TreeFragmentProduction() {} TreeFragmentProduction(int nttype, const std::vector& r) : lhs(nttype), rhs(r) {} @@ -46,6 +60,7 @@ class TreeFragment { typedef const unsigned & reference; iterator begin() const; + iterator begin(unsigned node_idx) const; iterator end() const; private: @@ -62,24 +77,28 @@ class TreeFragment { }; struct TFIState { - TFIState() : node(), rhspos() {} - TFIState(unsigned n, unsigned p) : node(n), rhspos(p) {} - bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos; } - bool operator!=(const TFIState& o) const { return node != o.node && rhspos != o.rhspos; } + TFIState() : node(), rhspos(), state() {} + TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {} + bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; } + bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; } unsigned short node; unsigned short rhspos; + unsigned char state; }; class BreadthFirstIterator : public std::iterator { const TreeFragment* tf_; - std::queue q_; + std::deque q_; unsigned sym; public: - explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) { - q_.push(TFIState(tf->nodes.size() - 1, 0)); + BreadthFirstIterator() : tf_(), sym() {} + // used for begin + explicit BreadthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { + q_.push_back(TFIState(node_idx, 0, 0)); Stage(); } - BreadthFirstIterator(const TreeFragment* tf, int) : tf_(tf) {} + // used for end + explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {} const unsigned& operator*() const { return sym; } const unsigned* operator->() const { return &sym; } bool operator==(const BreadthFirstIterator& other) const { @@ -88,26 +107,20 @@ class BreadthFirstIterator : public std::iteratornodes[s.node].rhs[s.rhspos]; - if (IsInternalNT(sym)) { - q_.push(TFIState(sym & ALL_MASK, 0)); - sym = tf_->nodes[sym & ALL_MASK].lhs; - } - } const BreadthFirstIterator& operator++() { TFIState& s = q_.front(); - const unsigned len = tf_->nodes[s.node].rhs.size(); - s.rhspos++; - if (s.rhspos > len) { - q_.pop(); + if (s.state == 0) { + s.state++; Stage(); - } else if (s.rhspos == len) { - sym = 0; } else { - Stage(); + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos >= len) { + q_.pop_front(); + Stage(); + } else { + Stage(); + } } return *this; } @@ -116,6 +129,42 @@ class BreadthFirstIterator : public std::iteratornodes[s.node].lhs & ALL_MASK) | LHS_BIT; + } else { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsRHS(sym)) { + q_.push_back(TFIState(sym & ALL_MASK, 0, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; + } + } + } + + // used by remainder + BreadthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { + q_.push_back(s); + Stage(); + } }; inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) { -- cgit v1.2.3 From 8dc828ac79e14179e90280b4255449f620550e63 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 1 Apr 2014 00:19:19 -0400 Subject: minimally tested t2s translator --- decoder/tree2string_translator.cc | 48 +++++++++++++++++++++++++++++---------- decoder/tree_fragment.h | 1 + 2 files changed, 37 insertions(+), 12 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index cd6ee550..09eca147 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -65,17 +65,17 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { struct ParserState { ParserState() : in_iter(), node() {} cdec::TreeFragment::iterator in_iter; - ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) : + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n) : in_iter(it), - root_type(rt), + input_node_idx(it.node_idx()), node(n) {} ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : in_iter(it), future_work(p.future_work), - root_type(p.root_type), + input_node_idx(p.input_node_idx), node(n) {} vector future_work; - int root_type; // lhs of top level NT + int input_node_idx; // lhs of top level NT Tree2StringGrammarNode* node; }; @@ -90,23 +90,40 @@ struct Tree2StringTranslatorImpl { const vector& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); - const int kS = -TD::Convert("S"); Hypergraph hg; + hg.ReserveNodes(input_tree.nodes.size()); + vector tree2hg(input_tree.nodes.size() + 1, -1); queue q; - q.push(ParserState(input_tree.begin(), &root, kS)); + q.push(ParserState(input_tree.begin(), &root)); + unsigned tree_top = q.front().input_node_idx; while(!q.empty()) { ParserState& s = q.front(); if (s.in_iter.at_end()) { // completed a traversal of a subtree - cerr << "I traversed a subtree of the input...\n"; + //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << + // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { - // TODO: build hypergraph - for (auto& r : s.node->rules) - cerr << "I can build: " << r->AsString() << endl; + TailNodeVector tail; + int& node_id = tree2hg[s.input_node_idx]; + if (node_id < 0) + node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_; + for (auto& n : s.future_work) { + int& nix = tree2hg[n.input_node_idx]; + if (nix < 0) + nix = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK))->id_; + tail.push_back(nix); + } + for (auto& r : s.node->rules) { + assert(tail.size() == r->Arity()); + HG::Edge* new_edge = hg.AddEdge(r, tail); + new_edge->feature_values_ = r->GetFeatureValues(); + // TODO: set i and j + hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); + } for (auto& w : s.future_work) q.push(w); } else { - cerr << "I can't build anything :(\n"; + //cerr << "I can't build anything :(\n"; } } else { // more input tree to match unsigned sym = *s.in_iter; @@ -125,7 +142,7 @@ struct Tree2StringTranslatorImpl { if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; ParserState new_s(++var, &nit2->second, s); - ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK)); + ParserState new_work(s.in_iter.remainder(), &root); new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q q.push(new_s); } @@ -145,7 +162,14 @@ struct Tree2StringTranslatorImpl { } q.pop(); } + int goal = tree2hg[tree_top]; + if (goal < 0) return false; + //cerr << "Goal node: " << goal << endl; + hg.TopologicallySortNodesAndEdges(goal); + hg.Reweight(weights); + //hg.PrintGraphviz(); minus_lm_forest->swap(hg); + return true; } }; diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index b83afc27..ceb7fa60 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -107,6 +107,7 @@ class BreadthFirstIterator : public std::iterator Date: Tue, 1 Apr 2014 00:53:00 -0400 Subject: tree2string test, fix for edge case --- decoder/tree2string_translator.cc | 6 ++++++ tests/system_tests/t2s/cdec.ini | 2 ++ tests/system_tests/t2s/gold.statistics | 3 +++ tests/system_tests/t2s/gold.stdout | 1 + tests/system_tests/t2s/grammar.t2s | 8 ++++++++ tests/system_tests/t2s/input.txt | 1 + tests/system_tests/t2s/weights | 6 ++++++ 7 files changed, 27 insertions(+) create mode 100644 tests/system_tests/t2s/cdec.ini create mode 100644 tests/system_tests/t2s/gold.statistics create mode 100644 tests/system_tests/t2s/gold.stdout create mode 100644 tests/system_tests/t2s/grammar.t2s create mode 100644 tests/system_tests/t2s/input.txt create mode 100644 tests/system_tests/t2s/weights (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 09eca147..7bc49132 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -167,6 +167,12 @@ struct Tree2StringTranslatorImpl { //cerr << "Goal node: " << goal << endl; hg.TopologicallySortNodesAndEdges(goal); hg.Reweight(weights); + + // there might be nodes that cannot be derived + // the following takes care of them + vector prune(hg.edges_.size(), false); + hg.PruneEdges(prune, true); + //hg.PrintGraphviz(); minus_lm_forest->swap(hg); return true; diff --git a/tests/system_tests/t2s/cdec.ini b/tests/system_tests/t2s/cdec.ini new file mode 100644 index 00000000..ad83438f --- /dev/null +++ b/tests/system_tests/t2s/cdec.ini @@ -0,0 +1,2 @@ +formalism=t2s +grammar=grammar.t2s diff --git a/tests/system_tests/t2s/gold.statistics b/tests/system_tests/t2s/gold.statistics new file mode 100644 index 00000000..452cc93e --- /dev/null +++ b/tests/system_tests/t2s/gold.statistics @@ -0,0 +1,3 @@ +-lm_nodes 6 +-lm_edges 8 +-lm_paths 4 diff --git a/tests/system_tests/t2s/gold.stdout b/tests/system_tests/t2s/gold.stdout new file mode 100644 index 00000000..afb11818 --- /dev/null +++ b/tests/system_tests/t2s/gold.stdout @@ -0,0 +1 @@ +qiangshou bei jingfang jibi . diff --git a/tests/system_tests/t2s/grammar.t2s b/tests/system_tests/t2s/grammar.t2s new file mode 100644 index 00000000..2e6cf68c --- /dev/null +++ b/tests/system_tests/t2s/grammar.t2s @@ -0,0 +1,8 @@ +(S [NP-C] [VP] (PUNC .)) ||| [1] [2] . ||| R1=1 +(NP-C (DT the) (NN gunman)) ||| qiangshou ||| R2=1 +(NP-C (DT the) [NN]) ||| [1] ||| R2a=1 +(NN gunman) ||| qiangshou ||| R2b=1 +(VP (VBD was) (VP-C [VBN] (PP (IN by) [NP-C]))) ||| bei [2] [1] ||| R3=1 +(NP-C (DT the) (NN police)) ||| jingfang ||| R4=1 +(VBN killed) ||| jibi ||| R5=1 +(VBN killed) ||| killed' ||| R6=1 diff --git a/tests/system_tests/t2s/input.txt b/tests/system_tests/t2s/input.txt new file mode 100644 index 00000000..b8fe314e --- /dev/null +++ b/tests/system_tests/t2s/input.txt @@ -0,0 +1 @@ +(S (NP-C (DT the) (NN gunman)) (VP (VBD was) (VP-C (VBN killed) (PP (IN by) (NP-C (DT the) (NN police))))) (PUNC .)) diff --git a/tests/system_tests/t2s/weights b/tests/system_tests/t2s/weights new file mode 100644 index 00000000..4980db45 --- /dev/null +++ b/tests/system_tests/t2s/weights @@ -0,0 +1,6 @@ +R1 1 +R2a 1 +R2b 1 +R3 1 +R5 1 +R4 1 -- cgit v1.2.3 From 241a9932588563f7952f7d758e3f77d8c499443c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 1 Apr 2014 18:47:20 -0400 Subject: deal with multiple grammars in t2s --- decoder/tree2string_translator.cc | 80 +++++++++++++++++++++++++++++---------- 1 file changed, 61 insertions(+), 19 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 7bc49132..6966ccf8 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,8 +1,10 @@ #include #include #include +#include +#include +#include #include -#include #include "tree_fragment.h" #include "translator.h" #include "hg.h" @@ -74,16 +76,43 @@ struct ParserState { future_work(p.future_work), input_node_idx(p.input_node_idx), node(n) {} - vector future_work; + bool operator==(const ParserState& o) const { + return node == o.node && input_node_idx == o.input_node_idx && + future_work == o.future_work && in_iter == o.in_iter; + } + vector future_work; int input_node_idx; // lhs of top level NT Tree2StringGrammarNode* node; }; +namespace std { + template<> + struct hash { + size_t operator()(const ParserState& s) const { + size_t h = boost::hash_range(s.future_work.begin(), s.future_work.end()); + boost::hash_combine(h, boost::hash_value(s.node)); + boost::hash_combine(h, boost::hash_value(s.input_node_idx)); + //boost::hash_combine(h, ); + return h; + } + }; +}; + struct Tree2StringTranslatorImpl { - Tree2StringGrammarNode root; - Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { - ReadFile rf(conf["grammar"].as>()[0]); - ReadTree2StringGrammar(rf.stream(), &root); + vector> root; + bool add_pass_through_rules; + Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) : + add_pass_through_rules(conf.count("add_pass_through_rules")) { + if (conf.count("grammar")) { + const vector gf = conf["grammar"].as>(); + root.resize(gf.size()); + unsigned gc = 0; + for (auto& f : gf) { + ReadFile rf(f); + root[gc].reset(new Tree2StringGrammarNode); + ReadTree2StringGrammar(rf.stream(), &*root[gc++]); + } + } } bool Translate(const string& input, SentenceMetadata* smeta, @@ -94,7 +123,11 @@ struct Tree2StringTranslatorImpl { hg.ReserveNodes(input_tree.nodes.size()); vector tree2hg(input_tree.nodes.size() + 1, -1); queue q; - q.push(ParserState(input_tree.begin(), &root)); + unordered_set unique; // only create items one time + for (auto& g : root) { + q.push(ParserState(input_tree.begin(), g.get())); + unique.insert(q.back()); + } unsigned tree_top = q.front().input_node_idx; while(!q.empty()) { ParserState& s = q.front(); @@ -103,14 +136,14 @@ struct Tree2StringTranslatorImpl { //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { - TailNodeVector tail; int& node_id = tree2hg[s.input_node_idx]; if (node_id < 0) node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_; - for (auto& n : s.future_work) { - int& nix = tree2hg[n.input_node_idx]; + TailNodeVector tail; + for (auto n : s.future_work) { + int& nix = tree2hg[n]; if (nix < 0) - nix = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK))->id_; + nix = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK))->id_; tail.push_back(nix); } for (auto& r : s.node->rules) { @@ -120,8 +153,13 @@ struct Tree2StringTranslatorImpl { // TODO: set i and j hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); } - for (auto& w : s.future_work) - q.push(w); + for (auto n : s.future_work) { + const auto it = input_tree.begin(n); // start tree iterator at node n + for (auto& g : root) { + ParserState s(it, g.get()); + if (unique.insert(s).second) q.push(s); + } + } } else { //cerr << "I can't build anything :(\n"; } @@ -131,7 +169,8 @@ struct Tree2StringTranslatorImpl { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; - q.push(ParserState(++s.in_iter, &nit->second, s)); + ParserState news(++s.in_iter, &nit->second, s); + if (unique.insert(news).second) q.push(news); } } else if (cdec::IsRHS(sym)) { //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; @@ -141,20 +180,23 @@ struct Tree2StringTranslatorImpl { auto nit2 = s.node->next.find(*var); if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; - ParserState new_s(++var, &nit2->second, s); - ParserState new_work(s.in_iter.remainder(), &root); + ++var; + const unsigned new_work = s.in_iter.child_node(); + ParserState new_s(var, &nit2->second, s); new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q - q.push(new_s); + if (unique.insert(new_s).second) q.push(new_s); } if (nit1 != s.node->next.end()) { //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; - q.push(ParserState(++s.in_iter, &nit1->second, s)); + const ParserState new_s(++s.in_iter, &nit1->second, s); + if (unique.insert(new_s).second) q.push(new_s); } } else if (cdec::IsTerminal(sym)) { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; - q.push(ParserState(++s.in_iter, &nit->second, s)); + const ParserState new_s(++s.in_iter, &nit->second, s); + if (unique.insert(new_s).second) q.push(new_s); } } else { cerr << "This can never happen!\n"; abort(); -- cgit v1.2.3 From b6925f42fb7518fa1a201fe20596f8d4036a3b80 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 1 Apr 2014 21:24:11 -0400 Subject: check for empty hg --- decoder/tree2string_translator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 6966ccf8..6f65658e 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -214,7 +214,7 @@ struct Tree2StringTranslatorImpl { // the following takes care of them vector prune(hg.edges_.size(), false); hg.PruneEdges(prune, true); - + if (hg.edges_.size() == 0) return false; //hg.PrintGraphviz(); minus_lm_forest->swap(hg); return true; -- cgit v1.2.3 From 32dcedf28adef39ddd07aaa1ad49e3e73e50b98e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 1 Apr 2014 21:52:15 -0400 Subject: deal with pass through rules --- decoder/tree2string_translator.cc | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 6f65658e..4cd584fb 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -101,6 +101,7 @@ namespace std { struct Tree2StringTranslatorImpl { vector> root; bool add_pass_through_rules; + unsigned remove_grammars; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) : add_pass_through_rules(conf.count("add_pass_through_rules")) { if (conf.count("grammar")) { @@ -114,11 +115,56 @@ struct Tree2StringTranslatorImpl { } } } + + void CreatePassThroughRules(const cdec::TreeFragment& tree) { + static const int kFID = FD::Convert("PassThrough"); + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + for (auto& prod : tree.nodes) { + ostringstream os; + vector rhse, rhsf; + int ntc = 0; + int lhs = -(prod.lhs & cdec::ALL_MASK); + os << '(' << TD::Convert(-lhs); + for (auto& sym : prod.rhs) { + os << ' '; + if (cdec::IsTerminal(sym)) { + os << TD::Convert(sym); + rhse.push_back(sym); + rhsf.push_back(sym); + } else { + unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; + os << '[' << TD::Convert(id) << ']'; + rhsf.push_back(-id); + rhse.push_back(-ntc); + ++ntc; + } + } + os << ')'; + cdec::TreeFragment rule_src(os.str(), true); + Tree2StringGrammarNode* cur = root.back().get(); + for (auto sym : rule_src) + cur = &cur->next[sym]; + TRulePtr rule(new TRule(rhse, rhsf, lhs)); + rule->ComputeArity(); + rule->scores_.set_value(kFID, 1.0); + cur->rules.push_back(rule); + } + } + + void RemoveGrammars() { + assert(remove_grammars < root.size()); + root.resize(root.size() - remove_grammars); + } + bool Translate(const string& input, SentenceMetadata* smeta, const vector& weights, Hypergraph* minus_lm_forest) { + remove_grammars = 0; cdec::TreeFragment input_tree(input, false); + if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; hg.ReserveNodes(input_tree.nodes.size()); vector tree2hg(input_tree.nodes.size() + 1, -1); @@ -235,6 +281,7 @@ void Tree2StringTranslator::ProcessMarkupHintsImpl(const map& kv } void Tree2StringTranslator::SentenceCompleteImpl() { + pimpl_->RemoveGrammars(); } std::string Tree2StringTranslator::GetDecoderType() const { -- cgit v1.2.3 From e32e9fdd48ef6466fbb257d92e250816f5b69114 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 7 Apr 2014 00:54:52 -0400 Subject: clean up dead TRule code --- decoder/bottom_up_parser.cc | 2 +- decoder/grammar_test.cc | 6 +- decoder/hg_test.cc | 60 ++++++---- decoder/hg_test.h | 69 ++++------- decoder/rule_lexer.h | 1 + decoder/rule_lexer.ll | 47 ++++++-- decoder/scfg_translator.cc | 2 +- decoder/test_data/hg_test.hg | 1 + decoder/test_data/hg_test.hg_balanced | 1 + decoder/test_data/hg_test.hg_int | 1 + decoder/test_data/hg_test.lattice | 1 + decoder/test_data/hg_test.tiny | 1 + decoder/test_data/hg_test.tiny_lattice | 1 + decoder/test_data/small.json.gz | Bin 1561 -> 1733 bytes decoder/tree2string_translator.cc | 1 + decoder/trule.cc | 202 +++------------------------------ decoder/trule.h | 22 ++-- training/dpmert/lo_test.cc | 2 +- 18 files changed, 142 insertions(+), 278 deletions(-) create mode 100644 decoder/test_data/hg_test.hg create mode 100644 decoder/test_data/hg_test.hg_balanced create mode 100644 decoder/test_data/hg_test.hg_int create mode 100644 decoder/test_data/hg_test.lattice create mode 100644 decoder/test_data/hg_test.tiny create mode 100644 decoder/test_data/hg_test.tiny_lattice (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index 8738c8f1..ff4c7a90 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -159,7 +159,7 @@ PassiveChart::PassiveChart(const string& goal, chart_(input.size()+1, input.size()+1), nodemap_(input.size()+1, input.size()+1), goal_cat_(TD::Convert(goal) * -1), - goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), + goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), goal_idx_(-1), lc_fid_(FD::Convert("LatticeCost")), unaries_() { diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc index 6d2c6e67..69240139 100644 --- a/decoder/grammar_test.cc +++ b/decoder/grammar_test.cc @@ -33,9 +33,9 @@ BOOST_AUTO_TEST_CASE(TestTextGrammar) { ModelSet models(w, ms); TextGrammar g; - TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3", true)); - TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4", true)); - TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3", true)); + TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3")); + TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4")); + TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3")); cerr << r1->AsString() << endl; g.AddRule(r1); g.AddRule(r2); diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 8519e559..95cfae51 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -18,8 +18,10 @@ using namespace std; BOOST_FIXTURE_TEST_SUITE( s, HGSetup ); BOOST_AUTO_TEST_CASE(Controlled) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); + cerr << "PATH: " << path << "/hg.tiny\n"; Hypergraph hg; - CreateHG_tiny(&hg); + CreateHG_tiny(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -37,10 +39,11 @@ BOOST_AUTO_TEST_CASE(Controlled) { } BOOST_AUTO_TEST_CASE(Union) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg1; Hypergraph hg2; - CreateHG_tiny(&hg1); - CreateHG(&hg2); + CreateHG_tiny(path, &hg1); + CreateHG(path, &hg2); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 1.0); @@ -84,8 +87,9 @@ BOOST_AUTO_TEST_CASE(Union) { } BOOST_AUTO_TEST_CASE(ControlledKBest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); vector w(2); w[0]=0.4; w[1]=0.8; hg.Reweight(w); vector trans; @@ -107,10 +111,11 @@ BOOST_AUTO_TEST_CASE(ControlledKBest) { BOOST_AUTO_TEST_CASE(InsideScore) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); Hypergraph hg; - CreateTinyLatticeHG(&hg); + CreateTinyLatticeHG(path, &hg); hg.Reweight(wts); vector trans; prob_t cost = ViterbiESentence(hg, &trans); @@ -130,10 +135,11 @@ BOOST_AUTO_TEST_CASE(InsideScore) { BOOST_AUTO_TEST_CASE(PruneInsideOutside) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); SparseVector wts; wts.set_value(FD::Convert("Feature_1"), 1.0); Hypergraph hg; - CreateLatticeHG(&hg); + CreateLatticeHG(path, &hg); hg.Reweight(wts); vector trans; prob_t cost = ViterbiESentence(hg, &trans); @@ -152,8 +158,9 @@ BOOST_AUTO_TEST_CASE(PruneInsideOutside) { } BOOST_AUTO_TEST_CASE(TestPruneEdges) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateLatticeHG(&hg); + CreateLatticeHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -166,8 +173,9 @@ BOOST_AUTO_TEST_CASE(TestPruneEdges) { } BOOST_AUTO_TEST_CASE(TestIntersect) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG_int(&hg); + CreateHG_int(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -192,8 +200,9 @@ BOOST_AUTO_TEST_CASE(TestIntersect) { } BOOST_AUTO_TEST_CASE(TestPrune2) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG_int(&hg); + CreateHG_int(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -207,8 +216,9 @@ BOOST_AUTO_TEST_CASE(TestPrune2) { } BOOST_AUTO_TEST_CASE(Sample) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateLatticeHG(&hg); + CreateLatticeHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("Feature_1"), 0.0); hg.Reweight(wts); @@ -220,6 +230,7 @@ BOOST_AUTO_TEST_CASE(Sample) { } BOOST_AUTO_TEST_CASE(PLF) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)"; HypergraphIO::ReadFromPLF(inplf, &hg); @@ -234,8 +245,9 @@ BOOST_AUTO_TEST_CASE(PLF) { } BOOST_AUTO_TEST_CASE(PushWeightsToGoal) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); vector w(2); w[0]=0.4; w[1]=0.8; hg.Reweight(w); vector trans; @@ -248,8 +260,9 @@ BOOST_AUTO_TEST_CASE(PushWeightsToGoal) { } BOOST_AUTO_TEST_CASE(TestSpecialKBest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHGBalanced(&hg); + CreateHGBalanced(path, &hg); vector w(1); w[0]=0; hg.Reweight(w); vector, prob_t> > list; @@ -264,8 +277,9 @@ BOOST_AUTO_TEST_CASE(TestSpecialKBest) { } BOOST_AUTO_TEST_CASE(TestGenericViterbi) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG_tiny(&hg); + CreateHG_tiny(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -279,8 +293,9 @@ BOOST_AUTO_TEST_CASE(TestGenericViterbi) { } BOOST_AUTO_TEST_CASE(TestGenericInside) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateTinyLatticeHG(&hg); + CreateTinyLatticeHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -296,8 +311,9 @@ BOOST_AUTO_TEST_CASE(TestGenericInside) { } BOOST_AUTO_TEST_CASE(TestGenericInside2) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -322,8 +338,9 @@ BOOST_AUTO_TEST_CASE(TestGenericInside2) { } BOOST_AUTO_TEST_CASE(TestAddExpectations) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -338,8 +355,8 @@ BOOST_AUTO_TEST_CASE(TestAddExpectations) { } BOOST_AUTO_TEST_CASE(Small) { - Hypergraph hg; std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); + Hypergraph hg; CreateSmallHG(&hg, path); SparseVector wts; wts.set_value(FD::Convert("Model_0"), -2.0); @@ -361,6 +378,7 @@ BOOST_AUTO_TEST_CASE(Small) { } BOOST_AUTO_TEST_CASE(JSONTest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); ostringstream os; JSONParser::WriteEscapedString("\"I don't know\", she said.", &os); BOOST_CHECK_EQUAL("\"\\\"I don't know\\\", she said.\"", os.str()); @@ -370,9 +388,10 @@ BOOST_AUTO_TEST_CASE(JSONTest) { } BOOST_AUTO_TEST_CASE(TestGenericKBest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); - //CreateHGBalanced(&hg); + CreateHG(path, &hg); + //CreateHGBalanced(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 1.0); @@ -392,8 +411,9 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) { } BOOST_AUTO_TEST_CASE(TestReadWriteHG) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg,hg2; - CreateHG(&hg); + CreateHG(path, &hg); hg.edges_.front().j_ = 23; hg.edges_.back().prev_i_ = 99; ostringstream os; diff --git a/decoder/hg_test.h b/decoder/hg_test.h index e96cb0b1..b7bab3c2 100644 --- a/decoder/hg_test.h +++ b/decoder/hg_test.h @@ -23,25 +23,13 @@ Name perro_wts="SameFirstLetter 1 LongerThanPrev 1 ShorterThanPrev 1 GlueTop 0.0 // you can inherit from this or just use the static methods struct HGSetup { - enum { - HG, - HG_int, - HG_tiny, - HGBalanced, - LatticeHG, - TinyLatticeHG, - }; - static void CreateHG(Hypergraph* hg); - static void CreateHG_int(Hypergraph* hg); - static void CreateHG_tiny(Hypergraph* hg); - static void CreateHGBalanced(Hypergraph* hg); - static void CreateLatticeHG(Hypergraph* hg); - static void CreateTinyLatticeHG(Hypergraph* hg); - - static void Json(Hypergraph *hg,std::string const& json) { - std::istringstream i(json); - HypergraphIO::ReadFromJSON(&i, hg); - } + static void CreateHG(const std::string& path,Hypergraph* hg); + static void CreateHG_int(const std::string& path,Hypergraph* hg); + static void CreateHG_tiny(const std::string& path, Hypergraph* hg); + static void CreateHGBalanced(const std::string& path,Hypergraph* hg); + static void CreateLatticeHG(const std::string& path,Hypergraph* hg); + static void CreateTinyLatticeHG(const std::string& path,Hypergraph* hg); + static void JsonFile(Hypergraph *hg,std::string f) { ReadFile rf(f); HypergraphIO::ReadFromJSON(rf.stream(), hg); @@ -52,18 +40,6 @@ struct HGSetup { static void CreateSmallHG(Hypergraph *hg, std::string path) { JsonTestFile(hg,path,small_json); } }; -namespace { -Name HGjsons[]= { - "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}", -"{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| b\",3,\"[X] ||| a [1]\",4,\"[X] ||| [1] b\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,0.1],\"rule\":1},{\"tail\":[],\"feats\":[0,0.1],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X\"},\"edges\":[{\"tail\":[0],\"feats\":[0,0.3],\"rule\":3},{\"tail\":[0],\"feats\":[0,0.2],\"rule\":4}],\"node\":{\"in_edges\":[2,3],\"cat\":\"Goal\"}}", - "{\"rules\":[1,\"[X] ||| \",2,\"[X] ||| X [1]\",3,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,-2,1,-99],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.5,1,-0.8],\"rule\":2},{\"tail\":[0],\"feats\":[0,-0.7,1,-0.9],\"rule\":3}],\"node\":{\"in_edges\":[1,2]}}", - "{\"rules\":[1,\"[X] ||| i\",2,\"[X] ||| a\",3,\"[X] ||| b\",4,\"[X] ||| [1] [2]\",5,\"[X] ||| [1] [2]\",6,\"[X] ||| c\",7,\"[X] ||| d\",8,\"[X] ||| [1] [2]\",9,\"[X] ||| [1] [2]\",10,\"[X] ||| [1] [2]\",11,\"[X] ||| [1] [2]\",12,\"[X] ||| [1] [2]\",13,\"[X] ||| [1] [2]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[1,2],\"feats\":[],\"rule\":4},{\"tail\":[2,1],\"feats\":[],\"rule\":5}],\"node\":{\"in_edges\":[3,4]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":6}],\"node\":{\"in_edges\":[5]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":7}],\"node\":{\"in_edges\":[6]},\"edges\":[{\"tail\":[4,5],\"feats\":[],\"rule\":8},{\"tail\":[5,4],\"feats\":[],\"rule\":9}],\"node\":{\"in_edges\":[7,8]},\"edges\":[{\"tail\":[3,6],\"feats\":[],\"rule\":10},{\"tail\":[6,3],\"feats\":[],\"rule\":11}],\"node\":{\"in_edges\":[9,10]},\"edges\":[{\"tail\":[7,0],\"feats\":[],\"rule\":12},{\"tail\":[0,7],\"feats\":[],\"rule\":13}],\"node\":{\"in_edges\":[11,12]}}", - "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] A A\",4,\"[X] ||| [1] b\",5,\"[X] ||| [1] c\",6,\"[X] ||| [1] B C\",7,\"[X] ||| [1] A B C\",8,\"[X] ||| [1] CC\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[2,-0.3],\"rule\":1},{\"tail\":[0],\"feats\":[2,-0.6],\"rule\":2},{\"tail\":[0],\"feats\":[2,-1.7],\"rule\":3}],\"node\":{\"in_edges\":[0,1,2]},\"edges\":[{\"tail\":[1],\"feats\":[2,-0.5],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[2],\"feats\":[2,-0.6],\"rule\":5},{\"tail\":[1],\"feats\":[2,-0.8],\"rule\":6},{\"tail\":[0],\"feats\":[2,-0.01],\"rule\":7},{\"tail\":[2],\"feats\":[2,-0.8],\"rule\":8}],\"node\":{\"in_edges\":[4,5,6,7]}}", - "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] b\",4,\"[X] ||| [1] B'\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.2],\"rule\":1},{\"tail\":[0],\"feats\":[0,-0.6],\"rule\":2}],\"node\":{\"in_edges\":[0,1]},\"edges\":[{\"tail\":[1],\"feats\":[0,-0.1],\"rule\":3},{\"tail\":[1],\"feats\":[0,-0.9],\"rule\":4}],\"node\":{\"in_edges\":[2,3]}}", -}; - -} - void AddNullEdge(Hypergraph* hg) { TRule x; x.arity_ = 0; @@ -71,31 +47,36 @@ void AddNullEdge(Hypergraph* hg) { hg->edges_.back().head_node_ = 0; } -void HGSetup::CreateTinyLatticeHG(Hypergraph* hg) { - Json(hg,HGjsons[TinyLatticeHG]); +void HGSetup::CreateTinyLatticeHG(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.tiny_lattice"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); AddNullEdge(hg); } -void HGSetup::CreateLatticeHG(Hypergraph* hg) { - Json(hg,HGjsons[LatticeHG]); +void HGSetup::CreateLatticeHG(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.lattice"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); AddNullEdge(hg); } -void HGSetup::CreateHG_tiny(Hypergraph* hg) { - Json(hg,HGjsons[HG_tiny]); +void HGSetup::CreateHG_tiny(const std::string& path, Hypergraph* hg) { + ReadFile rf(path + "/hg_test.tiny"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } -void HGSetup::CreateHG_int(Hypergraph* hg) { - Json(hg,HGjsons[HG_int]); +void HGSetup::CreateHG_int(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.hg_int"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } -void HGSetup::CreateHG(Hypergraph* hg) { - Json(hg,HGjsons[HG]); +void HGSetup::CreateHG(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.hg"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } -void HGSetup::CreateHGBalanced(Hypergraph* hg) { - Json(hg,HGjsons[HGBalanced]); +void HGSetup::CreateHGBalanced(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.hg_balanced"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } - #endif diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h index f844e5b2..e15c056d 100644 --- a/decoder/rule_lexer.h +++ b/decoder/rule_lexer.h @@ -9,6 +9,7 @@ struct RuleLexer { typedef void (*RuleCallback)(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra); static void ReadRules(std::istream* in, RuleCallback func, const std::string& fname, void* extra); + static void ReadRule(const std::string&, RuleCallback func, bool mono_rule, void* extra); }; #endif diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index cc73c079..d4a8d86b 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -14,6 +14,7 @@ #include "verbose.h" #include "tree_fragment.h" +bool lex_mono_rules = false; int lex_line = 0; std::istream* scfglex_stream = NULL; RuleLexer::RuleCallback rule_callback = NULL; @@ -120,7 +121,7 @@ void check_and_update_ctf_stack(const TRulePtr& rp) { %} REAL [\-+]?[0-9]+(\.[0-9]*)?([eE][-+]*[0-9]+)? -NT [^\t \[\],]+ +NT ([^\t \n\r\[\],]+|Goal) %x LHS_END SRC TRG FEATS FEATVAL ALIGNS TREE %% @@ -132,7 +133,7 @@ NT [^\t \[\],]+ \[{NT}\] { scfglex_tmp_token.assign(yytext + 1, yyleng - 2); scfglex_lhs = -TD::Convert(scfglex_tmp_token); - // std::cerr << scfglex_tmp_token << "\n"; + //std::cerr << "LHS: " << scfglex_tmp_token << "\n"; BEGIN(LHS_END); } @@ -199,9 +200,9 @@ NT [^\t \[\],]+ \|\|\| { memset(scfglex_nt_sanity, 0, scfglex_src_arity * sizeof(int)); - BEGIN(TRG); + if (lex_mono_rules) { BEGIN(FEATS); } else { BEGIN(TRG); } } -[^ \t]+ { +[^ \t\n\r]+ { scfglex_tmp_token.assign(yytext, yyleng); scfglex_src_rhs[scfglex_src_rhs_size] = TD::Convert(scfglex_tmp_token); ++scfglex_src_rhs_size; @@ -217,14 +218,28 @@ NT [^\t \[\],]+ \|\|\| { BEGIN(FEATS); } -[^ \t]+ { +[^ \t\n\r]+ { scfglex_tmp_token.assign(yytext, yyleng); scfglex_trg_rhs[scfglex_trg_rhs_size] = TD::Convert(scfglex_tmp_token); ++scfglex_trg_rhs_size; } [ \t]+ { ; } -\n { +\n { + if (lex_mono_rules) { + if (scfglex_trg_rhs_size != 0) { + std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": expected monolingual rule\n"; + abort(); + } + scfglex_trg_arity = scfglex_src_arity; + scfglex_trg_rhs_size = scfglex_src_rhs_size; + int ntc = 0; + for (int i = 0; i < scfglex_src_rhs_size; ++i) + if (scfglex_trg_rhs[i] <= 0) + scfglex_trg_rhs[i] = ntc--; + else + scfglex_trg_rhs[i] = scfglex_src_rhs[i]; + } if (scfglex_src_arity != scfglex_trg_arity) { std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": LHS and RHS arity mismatch!\n"; abort(); @@ -243,7 +258,7 @@ NT [^\t \[\],]+ TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra); ctf_rule_stack.push(rp); - // std::cerr << rp->AsString() << std::endl; + //std::cerr << "RULE: " << rp->AsString() << std::endl; num_rules++; lex_line++; if (!SILENT) { @@ -317,7 +332,7 @@ NT [^\t \[\],]+ #include "filelib.h" -void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const std::string& fname, void* extra) { +static void init_default_feature_names() { if (scfglex_phrase_fnames.empty()) { scfglex_phrase_fnames.resize(100); for (int i = 0; i < scfglex_phrase_fnames.size(); ++i) { @@ -326,6 +341,11 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const scfglex_phrase_fnames[i] = FD::Convert(os.str()); } } +} + +void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const std::string& fname, void* extra) { + init_default_feature_names(); + lex_mono_rules = false; lex_line = 1; scfglex_fname = fname; scfglex_stream = in; @@ -334,3 +354,14 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const yylex(); } +void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) { + init_default_feature_names(); + lex_mono_rules = mono; + lex_line = 1; + rule_callback_extra = extra; + rule_callback = func; + yy_scan_string(srule.c_str()); + yylex(); + yylex_destroy(); +} + diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 159a1d60..88f62769 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -47,7 +47,7 @@ GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt, const TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [1]")); AddRule(stop_glue); RefineRule(stop_glue, ctf_level); - TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [1] [2] ||| Glue=1")); + TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + "] ["+ default_nt + "] ||| [1] [2] ||| Glue=1")); AddRule(glue); RefineRule(glue, ctf_level); } diff --git a/decoder/test_data/hg_test.hg b/decoder/test_data/hg_test.hg new file mode 100644 index 00000000..ef98e9d4 --- /dev/null +++ b/decoder/test_data/hg_test.hg @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| a ||| a",2,"[X] ||| A [X] ||| A [1]",3,"[X] ||| c ||| c",4,"[X] ||| C [X] ||| C [1]",5,"[X] ||| [X] B [X] ||| [1] B [2]",6,"[X] ||| [X] b [X] ||| [1] b [2]",7,"[X] ||| X [X] ||| X [1]",8,"[X] ||| Z [X] ||| Z [1]"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[24568,32767,24568,32767],"feats":[],"rule":1}],"node":{"in_edges":[0],"cat":"X"},"edges":[{"tail":[0],"spans":[24568,32767,24568,32767],"feats":[0,-0.8,1,-0.1],"rule":2}],"node":{"in_edges":[1],"cat":"X"},"edges":[{"tail":[],"spans":[24568,32767,24568,32767],"feats":[1,-1],"rule":3}],"node":{"in_edges":[2],"cat":"X"},"edges":[{"tail":[2],"spans":[24568,32767,24568,32767],"feats":[0,-0.2,1,-0.1],"rule":4}],"node":{"in_edges":[3],"cat":"X"},"edges":[{"tail":[1,3],"spans":[24568,32767,24568,32767],"feats":[0,-1.2,1,-0.2],"rule":5},{"tail":[1,3],"spans":[24568,32767,24568,32767],"feats":[0,-0.5,1,-1.3],"rule":6}],"node":{"in_edges":[4,5],"cat":"X"},"edges":[{"tail":[4],"spans":[24568,32767,24568,32767],"feats":[0,-0.5,1,-0.8],"rule":7},{"tail":[4],"spans":[24568,32767,24568,32767],"feats":[0,-0.7,1,-0.9],"rule":8}],"node":{"in_edges":[6,7],"cat":"X"}} diff --git a/decoder/test_data/hg_test.hg_balanced b/decoder/test_data/hg_test.hg_balanced new file mode 100644 index 00000000..0f0f499f --- /dev/null +++ b/decoder/test_data/hg_test.hg_balanced @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| i ||| i",2,"[X] ||| a ||| a",3,"[X] ||| b ||| b",4,"[X] ||| [X] [X] ||| [1] [2]",5,"[X] ||| [X] [X] ||| [1] [2]",6,"[X] ||| c ||| c",7,"[X] ||| d ||| d",8,"[X] ||| [X] [X] ||| [1] [2]",9,"[X] ||| [X] [X] ||| [1] [2]",10,"[X] ||| [X] [X] ||| [1] [2]",11,"[X] ||| [X] [X] ||| [1] [2]",12,"[X] ||| [X] [X] ||| [1] [2]",13,"[X] ||| [X] [X] ||| [1] [2]"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":1}],"node":{"in_edges":[0],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":2}],"node":{"in_edges":[1],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":3}],"node":{"in_edges":[2],"cat":"X"},"edges":[{"tail":[1,2],"spans":[32760,32767,32760,32767],"feats":[],"rule":4},{"tail":[2,1],"spans":[32760,32767,32760,32767],"feats":[],"rule":5}],"node":{"in_edges":[3,4],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":6}],"node":{"in_edges":[5],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":7}],"node":{"in_edges":[6],"cat":"X"},"edges":[{"tail":[4,5],"spans":[32760,32767,32760,32767],"feats":[],"rule":8},{"tail":[5,4],"spans":[32760,32767,32760,32767],"feats":[],"rule":9}],"node":{"in_edges":[7,8],"cat":"X"},"edges":[{"tail":[3,6],"spans":[32760,32767,32760,32767],"feats":[],"rule":10},{"tail":[6,3],"spans":[32760,32767,32760,32767],"feats":[],"rule":11}],"node":{"in_edges":[9,10],"cat":"X"},"edges":[{"tail":[7,0],"spans":[32760,32767,32760,32767],"feats":[],"rule":12},{"tail":[0,7],"spans":[32760,32767,32760,32767],"feats":[],"rule":13}],"node":{"in_edges":[11,12],"cat":"X"}} diff --git a/decoder/test_data/hg_test.hg_int b/decoder/test_data/hg_test.hg_int new file mode 100644 index 00000000..9c4603bc --- /dev/null +++ b/decoder/test_data/hg_test.hg_int @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| a ||| a",2,"[X] ||| b ||| b",3,"[X] ||| a [X] ||| a [1]",4,"[X] ||| [X] b ||| [1] b"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[-8200,32767,-8200,32767],"feats":[0,0.1],"rule":1},{"tail":[],"spans":[-8200,32767,-8200,32767],"feats":[0,0.1],"rule":2}],"node":{"in_edges":[0,1],"cat":"X"},"edges":[{"tail":[0],"spans":[-8200,32767,-8200,32767],"feats":[0,0.3],"rule":3},{"tail":[0],"spans":[-8200,32767,-8200,32767],"feats":[0,0.2],"rule":4}],"node":{"in_edges":[2,3],"cat":"Goal"}} diff --git a/decoder/test_data/hg_test.lattice b/decoder/test_data/hg_test.lattice new file mode 100644 index 00000000..29e021c5 --- /dev/null +++ b/decoder/test_data/hg_test.lattice @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| [X] a ||| [1] a",2,"[X] ||| [X] A ||| [1] A",3,"[X] ||| [X] A A ||| [1] A A",4,"[X] ||| [X] b ||| [1] b",5,"[X] ||| [X] c ||| [1] c",6,"[X] ||| [X] B C ||| [1] B C",7,"[X] ||| [X] A B C ||| [1] A B C",8,"[X] ||| [X] CC ||| [1] CC"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7"],"edges":[],"node":{"in_edges":[]},"edges":[{"tail":[0],"feats":[2,-0.3],"rule":1},{"tail":[0],"feats":[2,-0.6],"rule":2},{"tail":[0],"feats":[2,-1.7],"rule":3}],"node":{"in_edges":[0,1,2]},"edges":[{"tail":[1],"feats":[2,-0.5],"rule":4}],"node":{"in_edges":[3]},"edges":[{"tail":[2],"feats":[2,-0.6],"rule":5},{"tail":[1],"feats":[2,-0.8],"rule":6},{"tail":[0],"feats":[2,-0.01],"rule":7},{"tail":[2],"feats":[2,-0.8],"rule":8}],"node":{"in_edges":[4,5,6,7]}}" diff --git a/decoder/test_data/hg_test.tiny b/decoder/test_data/hg_test.tiny new file mode 100644 index 00000000..101b96e9 --- /dev/null +++ b/decoder/test_data/hg_test.tiny @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| ||| ",2,"[X] ||| X [X] ||| X [1]",3,"[X] ||| Z [X] ||| Z [1]"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[25080,32767,25080,32767],"feats":[0,-2,1,-99],"rule":1}],"node":{"in_edges":[0],"cat":"X"},"edges":[{"tail":[0],"spans":[25080,32767,25080,32767],"feats":[0,-0.5,1,-0.8],"rule":2},{"tail":[0],"spans":[25080,32767,25080,32767],"feats":[0,-0.7,1,-0.9],"rule":3}],"node":{"in_edges":[1,2],"cat":"X"}} diff --git a/decoder/test_data/hg_test.tiny_lattice b/decoder/test_data/hg_test.tiny_lattice new file mode 100644 index 00000000..b9adf3cd --- /dev/null +++ b/decoder/test_data/hg_test.tiny_lattice @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| [X] a ||| [1] a",2,"[X] ||| [X] A ||| [1] A",3,"[X] ||| [X] b ||| [1] b",4,"[X] ||| [X] B' ||| [1] B'"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7"],"edges":[],"node":{"in_edges":[]},"edges":[{"tail":[0],"feats":[0,-0.2],"rule":1},{"tail":[0],"feats":[0,-0.6],"rule":2}],"node":{"in_edges":[0,1]},"edges":[{"tail":[1],"feats":[0,-0.1],"rule":3},{"tail":[1],"feats":[0,-0.9],"rule":4}],"node":{"in_edges":[2,3]}} diff --git a/decoder/test_data/small.json.gz b/decoder/test_data/small.json.gz index 892ba360..f6f37293 100644 Binary files a/decoder/test_data/small.json.gz and b/decoder/test_data/small.json.gz differ diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 4cd584fb..f288ab4e 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -174,6 +174,7 @@ struct Tree2StringTranslatorImpl { q.push(ParserState(input_tree.begin(), g.get())); unique.insert(q.back()); } + if (q.size() == 0) return false; unsigned tree_top = q.front().input_node_idx; while(!q.empty()) { ParserState& s = q.front(); diff --git a/decoder/trule.cc b/decoder/trule.cc index c22baae3..1bd5425f 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -17,73 +17,16 @@ bool TRule::IsGoal() const { return GetLHS() == kGOAL; } -static WordID ConvertTrgString(const string& w) { - const unsigned len = w.size(); - WordID id = 0; - // [X,0] or [0] - // for target rules, we ignore the category, just keep the index - if (len > 2 && w[0]=='[' && w[len-1]==']' && w[len-2] > '0' && w[len-2] <= '9' && - (len == 3 || (len > 4 && w[len-3] == ','))) { - id = w[len-2] - '0'; - id = 1 - id; - } else { - id = TD::Convert(w); - } - return id; -} - -static WordID ConvertSrcString(const string& w, bool mono = false) { - const unsigned len = w.size(); - // [X,0] - // for source rules, we keep the category and ignore the index (source rules are - // always numbered 1, 2, 3... - if (mono) { - if (len > 2 && w[0]=='[' && w[len-1]==']') { - if (len > 4 && w[len-3] == ',') { - cerr << "[ERROR] Monolingual rules mut not have non-terminal indices:\n " - << w << endl; - exit(1); - } - // TODO check that source indices go 1,2,3,etc. - return TD::Convert(w.substr(1, len-2)) * -1; - } else { - return TD::Convert(w); - } - } else { - if (len > 4 && w[0]=='[' && w[len-1]==']' && w[len-3] == ',' && w[len-2] > '0' && w[len-2] <= '9') { - return TD::Convert(w.substr(1, len-4)) * -1; - } else { - return TD::Convert(w); - } - } -} - -static WordID ConvertLHS(const string& w) { - if (w[0] == '[') { - const unsigned len = w.size(); - if (len < 3) { cerr << "Format error: " << w << endl; exit(1); } - return TD::Convert(w.substr(1, len-2)) * -1; - } else { - return TD::Convert(w) * -1; - } -} - TRule* TRule::CreateRuleSynchronous(const string& rule) { TRule* res = new TRule; - if (res->ReadFromString(rule, true, false)) return res; + if (res->ReadFromString(rule)) return res; cerr << "[ERROR] Failed to creating rule from: " << rule << endl; delete res; return NULL; } TRule* TRule::CreateRulePhrasetable(const string& rule) { - // TODO make this faster - // TODO add configuration for default NT type - if (rule[0] == '[') { - cerr << "Phrasetable rules shouldn't have a LHS / non-terminals:\n " << rule << endl; - return NULL; - } - TRule* res = new TRule("[X] ||| " + rule, true, false); + TRule* res = new TRule("[X] ||| " + rule); if (res->Arity() != 0) { cerr << "Phrasetable rules should have arity 0:\n " << rule << endl; delete res; @@ -93,138 +36,27 @@ TRule* TRule::CreateRulePhrasetable(const string& rule) { } TRule* TRule::CreateRuleMonolingual(const string& rule) { - return new TRule(rule, false, true); + return new TRule(rule, true); } namespace { -// callback for lexer +// callback for single rule lexer int n_assigned=0; -void assign_trule(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) { - (void) ctf_level; - (void) coarse_rule; - TRule *assignto=(TRule *)extra; - *assignto=*new_rule; - ++n_assigned; -} - -} - -bool TRule::ReadFromString(const string& line, bool strict, bool mono) { - if (!is_single_line_stripped(line)) - cerr<<"\nWARNING: building rule from multi-line string "<1) - cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "<(extra) = *new_rule; + ++n_assigned; } - if (format >= 2 || (mono && format == 1)) { - while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } - while(is>>w && w!="|||") { f_.push_back(ConvertSrcString(w, mono)); } - if (!mono) { - while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } - } - int fv = 0; - if (is) { - string ss; - getline(is, ss); - //cerr << "L: " << ss << endl; - unsigned start = 0; - unsigned len = ss.size(); - const size_t ppos = ss.find(" |||"); - if (ppos != string::npos) { len = ppos; } - while (start < len) { - while(start < len && (ss[start] == ' ' || ss[start] == ';')) - ++start; - if (start == len) break; - unsigned end = start + 1; - while(end < len && (ss[end] != '=' && ss[end] != ' ' && ss[end] != ';')) - ++end; - if (end == len || ss[end] == ' ' || ss[end] == ';') { - //cerr << "PROC: '" << ss.substr(start, end - start) << "'\n"; - // non-named features - if (end != len) { ss[end] = 0; } - string fname = "PhraseModel_X"; - if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); } - fname[12]='0' + fv; - ++fv; - // if the feature set is frozen, this may return zero, indicating an - // undefined feature - const int fid = FD::Convert(fname); - if (fid) - scores_.set_value(fid, atof(&ss[start])); - //cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl; - } else { - const int fid = FD::Convert(ss.substr(start, end - start)); - start = end + 1; - end = start + 1; - while(end < len && (ss[end] != ' ' && ss[end] != ';')) - ++end; - if (end < len) { ss[end] = 0; } - assert(start < len); - if (fid) - scores_.set_value(fid, atof(&ss[start])); - //cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl; - } - start = end + 1; - } - } - } else if (format == 1) { - while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } - while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } - f_ = e_; - int x = ConvertLHS("[X]"); - for (unsigned i = 0; i < f_.size(); ++i) - if (f_[i] <= 0) { f_[i] = x; } - } else { - cerr << "F: " << format << endl; - cerr << "[ERROR] Don't know how to read:\n" << line << endl; - } - if (mono) { - e_ = f_; - int ci = 0; - for (unsigned i = 0; i < e_.size(); ++i) - if (e_[i] < 0) - e_[i] = ci--; - } - ComputeArity(); - return SanityCheck(); } -bool TRule::SanityCheck() const { - vector used(f_.size(), 0); - int ac = 0; - for (unsigned i = 0; i < e_.size(); ++i) { - int ind = e_[i]; - if (ind > 0) continue; - ind = -ind; - if ((++used[ind]) != 1) { - cerr << "[ERROR] e-side variable index " << (ind+1) << " used more than once!\n"; - return false; - } - ac++; - } - if (ac != Arity()) { - cerr << "[ERROR] e-side arity mismatches f-side\n"; - return false; - } - return true; +bool TRule::ReadFromString(const string& line, bool mono) { + n_assigned = 0; + //cerr << "LINE: " << line << " -- mono=" << mono << endl; + RuleLexer::ReadRule(line + '\n', assign_trule, mono, this); + if (n_assigned > 1) + cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "< 9 variables won't work - explicit TRule(const std::string& text, bool strict = false, bool mono = false) : prev_i(-1), prev_j(-1) { - ReadFromString(text, strict, mono); + explicit TRule(const std::string& text, bool mono = false) : prev_i(-1), prev_j(-1) { + ReadFromString(text, mono); } - // deprecated, use lexer // make a rule from a hiero-like rule table, e.g. // [X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1] - // if misformatted, returns NULL static TRule* CreateRuleSynchronous(const std::string& rule); - // deprecated, use lexer // make a rule from a phrasetable entry (i.e., one that has no LHS type), e.g: // el gato ||| the cat ||| Feature_2=0.34 static TRule* CreateRulePhrasetable(const std::string& rule); - // deprecated, use lexer // make a rule from a non-synchrnous CFG representation, e.g.: // [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT] static TRule* CreateRuleMonolingual(const std::string& rule); @@ -80,11 +75,10 @@ class TRule { std::vector* result) const { unsigned vc = 0; result->clear(); - for (std::vector::const_iterator i = e_.begin(); i != e_.end(); ++i) { - const WordID& c = *i; + for (const auto& c : e_) { if (c < 1) { ++vc; - const std::vector& var_value = *var_values[-c]; + const auto& var_value = *var_values[-c]; std::copy(var_value.begin(), var_value.end(), std::back_inserter(*result)); @@ -99,10 +93,9 @@ class TRule { std::vector* result) const { unsigned vc = 0; result->clear(); - for (std::vector::const_iterator i = f_.begin(); i != f_.end(); ++i) { - const WordID& c = *i; + for (const auto& c : f_) { if (c < 1) { - const std::vector& var_value = *var_values[vc++]; + const auto& var_value = *var_values[vc++]; std::copy(var_value.begin(), var_value.end(), std::back_inserter(*result)); @@ -113,7 +106,7 @@ class TRule { assert(vc == var_values.size()); } - bool ReadFromString(const std::string& line, bool strict = false, bool monolingual = false); + bool ReadFromString(const std::string& line, bool monolingual = false); bool Initialized() const { return e_.size(); } @@ -166,7 +159,6 @@ class TRule { private: TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} - bool SanityCheck() const; }; inline size_t hash_value(const TRule& r) { diff --git a/training/dpmert/lo_test.cc b/training/dpmert/lo_test.cc index d89bcd99..b8776169 100644 --- a/training/dpmert/lo_test.cc +++ b/training/dpmert/lo_test.cc @@ -56,7 +56,7 @@ BOOST_AUTO_TEST_CASE(TestConvexHull) { } BOOST_AUTO_TEST_CASE(TestConvexHullInside) { - const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; + const string json = "{\"rules\":[1,\"[X] ||| a ||| a\",2,\"[X] ||| A [X] ||| A [1]\",3,\"[X] ||| c ||| c\",4,\"[X] ||| C [X] ||| C [1]\",5,\"[X] ||| [X] B [X] ||| [1] B [2]\",6,\"[X] ||| [X] b [X] ||| [1] b [2]\",7,\"[X] ||| X [X] ||| X [1]\",8,\"[X] ||| Z [X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; Hypergraph hg; istringstream instr(json); HypergraphIO::ReadFromJSON(&instr, &hg); -- cgit v1.2.3 From b9e6e7e24cc48021090b689e143288e2b7f2b5fc Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 7 Apr 2014 22:56:34 -0400 Subject: track node state for smarter union --- Makefile.am | 2 +- decoder/Makefile.am | 1 + decoder/apply_models.cc | 306 +++++++++++++++++----------------- decoder/bottom_up_parser.cc | 10 ++ decoder/decoder.cc | 13 ++ decoder/fst_translator.cc | 6 + decoder/hg.cc | 47 ++---- decoder/hg.h | 10 +- decoder/lexalign.cc | 5 + decoder/lextrans.cc | 5 + decoder/node_state_hash.h | 36 ++++ decoder/nt_span.h | 2 +- decoder/tagger.cc | 5 + decoder/tree2string_translator.cc | 14 +- mteval/Makefile.am | 8 +- tests/tools/filter-stderr.pl | 1 + utils/Makefile.am | 3 +- utils/hash.h | 21 +-- utils/murmur_hash.h | 186 --------------------- utils/murmur_hash3.cc | 340 ++++++++++++++++++++++++++++++++++++++ utils/murmur_hash3.h | 67 ++++++++ 21 files changed, 699 insertions(+), 389 deletions(-) create mode 100644 decoder/node_state_hash.h delete mode 100644 utils/murmur_hash.h create mode 100644 utils/murmur_hash3.cc create mode 100644 utils/murmur_hash3.h (limited to 'decoder/tree2string_translator.cc') diff --git a/Makefile.am b/Makefile.am index 598293d1..88327477 100644 --- a/Makefile.am +++ b/Makefile.am @@ -3,13 +3,13 @@ # cyclic dependencies between these directories! SUBDIRS = \ utils \ - mteval \ klm/util/double-conversion \ klm/util \ klm/util/stream \ klm/lm \ klm/lm/builder \ klm/search \ + mteval \ decoder \ training \ word-aligner \ diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 5c91fe65..c85f17ed 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -144,6 +144,7 @@ libcdec_a_SOURCES = \ lattice.cc \ lexalign.cc \ lextrans.cc \ + node_state_hash.h \ tree_fragment.cc \ tree_fragment.h \ maxtrans_blunsom.cc \ diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9a8f60be..9f8bbead 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -19,6 +19,7 @@ namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; } #include +#include "node_state_hash.h" #include "verbose.h" #include "hg.h" #include "ff.h" @@ -229,7 +230,7 @@ public: D.clear(); } - void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { + void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) { Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_); new_edge->edge_prob_ = item->out_edge_.edge_prob_; Candidate*& o_item = (*s2n)[item->state_]; @@ -238,6 +239,7 @@ public: int& node_id = o_item->node_index_; if (node_id < 0) { Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_); + new_node->node_hash = cdec::HashNode(head_node_hash, item->state_); // ID is combination of existing state + residual state node_states_.push_back(item->state_); node_id = new_node->id_; } @@ -287,7 +289,7 @@ public: cand.pop_back(); // cerr << "POPPED: " << *item << endl; PushSucc(*item, is_goal, &cand, &unique_cands); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); ++pops; } D_v.resize(state2node.size()); @@ -306,112 +308,112 @@ public: } void KBestFast(const int vert_index, const bool is_goal) { - // cerr << "KBest(" << vert_index << ")\n"; - CandidateList& D_v = D[vert_index]; - assert(D_v.empty()); - const Hypergraph::Node& v = in.nodes_[vert_index]; - // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; - const vector& in_edges = v.in_edges_; - CandidateHeap cand; - CandidateList freelist; - cand.reserve(in_edges.size()); - //init with j<0,0> for all rules-edges that lead to node-(NT-span) - for (int i = 0; i < in_edges.size(); ++i) { - const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; - const JVector j(edge.tail_nodes_.size(), 0); - cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); - } - // cerr << " making heap of " << cand.size() << " candidates\n"; - make_heap(cand.begin(), cand.end(), HeapCandCompare()); - State2Node state2node; // "buf" in Figure 2 - int pops = 0; - while(!cand.empty() && pops < pop_limit_) { - pop_heap(cand.begin(), cand.end(), HeapCandCompare()); - Candidate* item = cand.back(); - cand.pop_back(); - // cerr << "POPPED: " << *item << endl; - - PushSuccFast(*item, is_goal, &cand); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); - ++pops; - } - D_v.resize(state2node.size()); - int c = 0; - for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ - D_v[c++] = i->second; - // cerr << "MERGED: " << *i->second << endl; - } - //cerr <<"Node id: "<< vert_index<< endl; - //#ifdef MEASURE_CA - // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + + PushSuccFast(*item, is_goal, &cand); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (auto& i : state2node) { + D_v[c++] = i.second; + // cerr << "MERGED: " << *i.second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<& in_edges = v.in_edges_; - CandidateHeap cand; - CandidateList freelist; - cand.reserve(in_edges.size()); - UniqueCandidateSet unique_accepted; - //init with j<0,0> for all rules-edges that lead to node-(NT-span) - for (int i = 0; i < in_edges.size(); ++i) { - const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; - const JVector j(edge.tail_nodes_.size(), 0); - cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); - } - // cerr << " making heap of " << cand.size() << " candidates\n"; - make_heap(cand.begin(), cand.end(), HeapCandCompare()); - State2Node state2node; // "buf" in Figure 2 - int pops = 0; - while(!cand.empty() && pops < pop_limit_) { - pop_heap(cand.begin(), cand.end(), HeapCandCompare()); - Candidate* item = cand.back(); - cand.pop_back(); + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_accepted; + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); bool is_new = unique_accepted.insert(item).second; - assert(is_new); // these should all be unique! - // cerr << "POPPED: " << *item << endl; - - PushSuccFast2(*item, is_goal, &cand, &unique_accepted); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); - ++pops; - } - D_v.resize(state2node.size()); - int c = 0; - for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ - D_v[c++] = i->second; - // cerr << "MERGED: " << *i->second << endl; - } - //cerr <<"Node id: "<< vert_index<< endl; - //#ifdef MEASURE_CA - // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<tail_nodes_[i]].size()) { - Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); - cand.push_back(new_cand); - push_heap(cand.begin(), cand.end(), HeapCandCompare()); - } - if(item.j_[i]!=0){ - return; - } - } + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + if(item.j_[i]!=0){ + return; + } + } } //PushSucc only if all ancest Cand are added void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ - CandidateHeap& cand = *pcand; - for (int i = 0; i < item.j_.size(); ++i) { - JVector j = item.j_; - ++j[i]; - if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { - Candidate query_unique(*item.in_edge_, j); - if (HasAllAncestors(&query_unique,ps)) { - Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); - cand.push_back(new_cand); - push_heap(cand.begin(), cand.end(), HeapCandCompare()); - } - } - } + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (HasAllAncestors(&query_unique,ps)) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + } + } } bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ - for (int i = 0; i < item->j_.size(); ++i) { - JVector j = item->j_; - --j[i]; - if (j[i] >=0) { - Candidate query_unique(*item->in_edge_, j); - if (cs->count(&query_unique) == 0) { - return false; - } - } - } - return true; + for (int i = 0; i < item->j_.size(); ++i) { + JVector j = item->j_; + --j[i]; + if (j[i] >=0) { + Candidate query_unique(*item->in_edge_, j); + if (cs->count(&query_unique) == 0) { + return false; + } + } + } + return true; } const ModelSet& models; @@ -491,7 +493,7 @@ public: FFStates node_states_; // for each node in the out-HG what is // its q function value? const int pop_limit_; - const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) + const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) }; struct NoPruningRescorer { @@ -507,7 +509,7 @@ struct NoPruningRescorer { typedef unordered_map > State2NodeIndex; - void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) { + void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, size_t head_node_hash, State2NodeIndex* state2node) { const int arity = in_edge.Arity(); Hypergraph::TailNodeVector ends(arity); for (int i = 0; i < arity; ++i) @@ -531,7 +533,9 @@ struct NoPruningRescorer { } int& head_plus1 = (*state2node)[head_state]; if (!head_plus1) { - head_plus1 = out.AddNode(in_edge.rule_->GetLHS())->id_ + 1; + HG::Node* new_node = out.AddNode(in_edge.rule_->GetLHS()); + new_node->node_hash = cdec::HashNode(head_node_hash, head_state); // ID is combination of existing state + residual state + head_plus1 = new_node->id_ + 1; node_states_.push_back(head_state); nodemap[in_edge.head_node_].push_back(head_plus1 - 1); } @@ -553,7 +557,7 @@ struct NoPruningRescorer { const Hypergraph::Node& node = in.nodes_[node_num]; for (int i = 0; i < node.in_edges_.size(); ++i) { const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]]; - ExpandEdge(edge, is_goal, &state2node); + ExpandEdge(edge, is_goal, node.node_hash, &state2node); } } @@ -605,16 +609,16 @@ void ApplyModelSet(const Hypergraph& in, cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; } if (config.algorithm == IntersectionConfiguration::CUBE) { - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); } else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ - CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); + ma.Apply(); } else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ - CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); + ma.Apply(); } } else { diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index ff4c7a90..b30f1ec6 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -7,6 +7,8 @@ #include #include +#include "node_state_hash.h" +#include "nt_span.h" #include "hg.h" #include "array2d.h" #include "tdict.h" @@ -356,5 +358,13 @@ bool ExhaustiveBottomUpParser::Parse(const Lattice& input, kEPS = TD::Convert("*EPS*"); PassiveChart chart(goal_sym_, grammars_, input, forest); const bool result = chart.Parse(); + + if (result) { + for (auto& node : forest->nodes_) { + Span prev; + const Span s = forest->NodeSpan(node.id_, &prev); + node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r); + } + } return result; } diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 43e2640d..354ea2d9 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -750,6 +750,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { return false; } + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } + const bool show_tree_structure=conf.count("show_tree_structure"); if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation); if (conf.count("show_expected_length")) { @@ -813,6 +818,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { forest.swap(rescored_forest); forest.Reweight(cur_weights); if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation, conf.count("extract_rules"), extract_file); + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } } if (conf.count("show_partition")) { @@ -984,6 +993,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_; } forest.Reweight(last_weights); + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation); if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; if (conf.count("show_partition")) { diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc index 074de4c9..4253b652 100644 --- a/decoder/fst_translator.cc +++ b/decoder/fst_translator.cc @@ -67,6 +67,12 @@ struct FSTTranslatorImpl { Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); forest->ConnectEdgeToHeadNode(hg_edge, goal); forest->Reweight(weights); + + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; } if (add_pass_through_rules) fst->ClearPassThroughTranslations(); diff --git a/decoder/hg.cc b/decoder/hg.cc index 7240a8ab..405169c6 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -1,14 +1,17 @@ -//TODO: lazily generate feature vectors for hyperarcs (because some of them will be pruned). this means 1) storing ref to rule for those features 2) providing ff interface for regenerating its feature vector from hyperedge+states and probably 3) still caching feat. vect on hyperedge once it's been generated. ff would normally just contribute its weighted score and result state, not component features. however, the hypergraph drops the state used by ffs after rescoring is done, so recomputation would have to start at the leaves and work bottom up. question: which takes more space, feature id+value, or state? - #include "hg.h" #include #include #include -#include #include #include #include +#ifndef HAVE_OLD_CPP +# include +#else +# include +namespace std { using std::tr1::unordered_set; } +#endif #include "viterbi.h" #include "inside_outside.h" @@ -17,28 +20,21 @@ using namespace std; -#if 0 -Hypergraph::Edge const* Hypergraph::ViterbiGoalEdge() const -{ - Edge const* r=0; - for (unsigned i=0,e=edges_.size();iIsGoal() && (!r || e.edge_prob_ > r->edge_prob_)) - r=&e; - } - return r; +bool Hypergraph::AreNodesUniquelyIdentified() const { + unordered_set s(nodes_.size() * 3 + 7); + for (const auto& n : nodes_) + if (!s.insert(n.node_hash).second) + return false; + return true; } -#endif -Hypergraph::Edge const* Hypergraph::ViterbiSortInEdges() -{ +Hypergraph::Edge const* Hypergraph::ViterbiSortInEdges() { NodeProbs nv; ComputeNodeViterbi(&nv); return SortInEdgesByNodeViterbi(nv); } -Hypergraph::Edge const* Hypergraph::SortInEdgesByNodeViterbi(NodeProbs const& nv) -{ +Hypergraph::Edge const* Hypergraph::SortInEdgesByNodeViterbi(NodeProbs const& nv) { EdgeProbs ev; ComputeEdgeViterbi(nv,&ev); return ViterbiSortInEdges(ev); @@ -375,9 +371,7 @@ bool Hypergraph::PruneInsideOutside(double alpha,double density,const EdgeMask* void Hypergraph::PrintGraphviz() const { int ei = 0; cerr << "digraph G {\n rankdir=LR;\n nodesep=.05;\n"; - for (vector::const_iterator i = edges_.begin(); - i != edges_.end(); ++i) { - const Edge& edge=*i; + for (const auto& edge : edges_) { ++ei; static const string none = ""; string rule = (edge.rule_ ? edge.rule_->AsString(false) : none); @@ -399,14 +393,9 @@ void Hypergraph::PrintGraphviz() const { } cerr << " A_" << ei << " -> " << edge.head_node_ << ";\n"; } - for (vector::const_iterator ni = nodes_.begin(); - ni != nodes_.end(); ++ni) { - cerr << " " << ni->id_ << "[label=\"" << (ni->cat_ < 0 ? TD::Convert(ni->cat_ * -1) : "") - //cerr << " " << ni->id_ << "[label=\"" << ni->cat_ - << " n=" << ni->id_ -// << ",x=" << &*ni -// << ",in=" << ni->in_edges_.size() -// << ",out=" << ni->out_edges_.size() + for (const auto& node : nodes_) { + cerr << " " << node.id_ << "[label=\"" << (node.cat_ < 0 ? TD::Convert(node.cat_ * -1) : "") + << " n=" << node.id_ << "\"];\n"; } cerr << "}\n"; diff --git a/decoder/hg.h b/decoder/hg.h index 343b99cf..43fb275b 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -142,13 +142,15 @@ namespace HG { // TODO get rid of cat_? // TODO keep cat_ and add span and/or state? :) struct Node { - Node() : id_(), cat_() {} + Node() : node_hash(), id_(), cat_() {} + size_t node_hash; // hash of all the information that makes this node unique int id_; // equal to this object's position in the nodes_ vector WordID cat_; // non-terminal category if <0, 0 if not set WordID NT() const { return -cat_; } EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_ EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_ void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting + node_hash = o.node_hash; cat_=o.cat_; } void copy_reindex(Node const& o,indices_after const& n2,indices_after const& e2) { @@ -192,13 +194,14 @@ public: SetNodeOrigin(nodeid,r); return r; } - Span NodeSpan(int nodeid) const { + Span NodeSpan(int nodeid, Span* prev = nullptr) const { Span s; Node const &n=nodes_[nodeid]; if (!n.in_edges_.empty()) { Edge const& e=edges_[n.in_edges_.front()]; s.l=e.i_; s.r=e.j_; + if (prev) { prev->l = e.prev_i_; prev->r = e.prev_j_; } } return s; } @@ -262,6 +265,9 @@ public: for (int i = 0; i < size; ++i) nodes_[i].id_ = i; } + // if all node states are unique, return true + bool AreNodesUniquelyIdentified() const; + // reserves space in the nodes vector to prevent memory locations // from changing void ReserveNodes(size_t n, size_t e = 0) { diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc index 6adb1892..11f20de7 100644 --- a/decoder/lexalign.cc +++ b/decoder/lexalign.cc @@ -124,6 +124,11 @@ bool LexicalAlign::TranslateImpl(const string& input, pimpl_->BuildTrellis(lattice, *smeta, forest); forest->is_linear_chain_ = true; forest->Reweight(weights); + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; return true; } diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 8c3269bf..74a18c3f 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -280,6 +280,11 @@ bool LexicalTrans::TranslateImpl(const string& input, smeta->SetSourceLength(lattice.size()); if (!pimpl_->BuildTrellis(lattice, *smeta, forest)) return false; forest->Reweight(weights); + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; return true; } diff --git a/decoder/node_state_hash.h b/decoder/node_state_hash.h new file mode 100644 index 00000000..cdc05877 --- /dev/null +++ b/decoder/node_state_hash.h @@ -0,0 +1,36 @@ +#ifndef _NODE_STATE_HASH_ +#define _NODE_STATE_HASH_ + +#include +#include +#include "murmur_hash3.h" +#include "ffset.h" + +namespace cdec { + + struct FirstPassNode { + FirstPassNode(int cat, int i, int j, int pi, int pj) : lhs(cat), s(i), t(j), u(pi), v(pj) {} + int32_t lhs; + short s; + short t; + short u; + short v; + }; + + inline uint64_t HashNode(int cat, int i, int j, int pi, int pj) { + FirstPassNode fpn(cat, i, j, pi, pj); + return MurmurHash3_64(&fpn, sizeof(FirstPassNode), 2654435769U); + } + + inline uint64_t HashNode(uint64_t old_hash, const FFState& state) { + uint8_t buf[1024]; + std::memcpy(buf, &old_hash, sizeof(uint64_t)); + assert(state.size() < (1024u - sizeof(uint64_t))); + std::memcpy(&buf[sizeof(uint64_t)], state.begin(), state.size()); + return MurmurHash3_64(buf, sizeof(uint64_t) + state.size(), 2654435769U); + } + +} + +#endif + diff --git a/decoder/nt_span.h b/decoder/nt_span.h index a918f301..6ff9391f 100644 --- a/decoder/nt_span.h +++ b/decoder/nt_span.h @@ -7,7 +7,7 @@ struct Span { int l,r; - Span() : l(-1) { } + Span() : l(-1), r(-1) { } bool is_null() const { return l<0; } void print(std::ostream &o,char const* for_null="") const { if (is_null()) diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 63e855c8..30fb055f 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -108,6 +108,11 @@ bool Tagger::TranslateImpl(const string& input, pimpl_->BuildTrellis(sequence, forest); forest->Reweight(weights); forest->is_linear_chain_ = true; + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; return true; } diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index f288ab4e..8d12d01d 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -184,13 +184,19 @@ struct Tree2StringTranslatorImpl { // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { int& node_id = tree2hg[s.input_node_idx]; - if (node_id < 0) - node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_; + if (node_id < 0) { + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = s.input_node_idx + 1; + node_id = new_node->id_; + } TailNodeVector tail; for (auto n : s.future_work) { int& nix = tree2hg[n]; - if (nix < 0) - nix = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK))->id_; + if (nix < 0) { + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK)); + new_node->node_hash = n + 1; + nix = new_node->id_; + } tail.push_back(nix); } for (auto& r : s.node->rules) { diff --git a/mteval/Makefile.am b/mteval/Makefile.am index 681e798e..08591c9a 100644 --- a/mteval/Makefile.am +++ b/mteval/Makefile.am @@ -1,6 +1,7 @@ bin_PROGRAMS = \ fast_score \ - mbr_kbest + mbr_kbest\ + marginalize noinst_PROGRAMS = \ scorer_test @@ -46,4 +47,7 @@ mbr_kbest_LDADD = libmteval.a ../utils/libutils.a scorer_test_SOURCES = scorer_test.cc scorer_test_LDADD = libmteval.a ../utils/libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -AM_CPPFLAGS = -DTEST_DATA=\"$(top_srcdir)/mteval/test_data\" -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare -I$(top_srcdir) -I$(top_srcdir)/utils +marginalize_SOURCES = marginalize.cc +marginalize_LDADD = libmteval.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a ../klm/util/double-conversion/libklm_util_double.a ../utils/libutils.a + +AM_CPPFLAGS = -DTEST_DATA=\"$(top_srcdir)/mteval/test_data\" -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/klm diff --git a/tests/tools/filter-stderr.pl b/tests/tools/filter-stderr.pl index 4a762324..54fe9210 100755 --- a/tests/tools/filter-stderr.pl +++ b/tests/tools/filter-stderr.pl @@ -13,6 +13,7 @@ if (/Init.*\s+Viterbi:\s+($REAL)/) { # -LM Viterbi: australia is have diplomatic relations with north korea one of the few countries . print "-lm_trans $1\n"; } +if (/NODES NOT UNIQUELY IDENTIFIED/) { print "NODES_NOT_UNIQUE 1\n"; } #Constr. forest (nodes/edges): 111/305 #Constr. forest (paths): 9899 if (/Constr\. forest\s+\(nodes\/edges\): (\d+)\/(\d+)/) { print "constr_nodes $1\nconstr_edges $2\n"; } diff --git a/utils/Makefile.am b/utils/Makefile.am index c0ce3509..341fd80b 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -39,7 +39,8 @@ libutils_a_SOURCES = \ kernel_string_subseq.h \ logval.h \ m.h \ - murmur_hash.h \ + murmur_hash3.h \ + murmur_hash3.cc \ named_enum.h \ null_deleter.h \ null_traits.h \ diff --git a/utils/hash.h b/utils/hash.h index e1426ffb..24d2b6ad 100644 --- a/utils/hash.h +++ b/utils/hash.h @@ -3,7 +3,7 @@ #include -#include "murmur_hash.h" +#include "murmur_hash3.h" #ifdef HAVE_CONFIG_H #include "config.h" @@ -44,23 +44,21 @@ const unsigned GOLDEN_MEAN_FRACTION=2654435769U; // assumes C is POD template -struct murmur_hash -{ - typedef MurmurInt result_type; +struct murmur_hash { + typedef size_t result_type; typedef C /*const&*/ argument_type; result_type operator()(argument_type const& c) const { - return MurmurHash((void*)&c,sizeof(c)); + return cdec::MurmurHash3_64((void*)&c, sizeof(c), GOLDEN_MEAN_FRACTION); } }; // murmur_hash_array isn't std guaranteed safe (you need to use string::data()) template <> -struct murmur_hash -{ - typedef MurmurInt result_type; +struct murmur_hash { + typedef size_t result_type; typedef std::string /*const&*/ argument_type; result_type operator()(argument_type const& c) const { - return MurmurHash(c.data(),c.size()); + return cdec::MurmurHash3_64(c.data(), c.size(), GOLDEN_MEAN_FRACTION); } }; @@ -68,10 +66,10 @@ struct murmur_hash template struct murmur_hash_array { - typedef MurmurInt result_type; + typedef size_t result_type; typedef C /*const&*/ argument_type; result_type operator()(argument_type const& c) const { - return MurmurHash(&*c.begin(),c.size()*sizeof(*c.begin())); + return cdec::MurmurHash3_64(&*c.begin(), c.size()*sizeof(*c.begin()), GOLDEN_MEAN_FRACTION); } }; @@ -95,7 +93,6 @@ typename H::mapped_type & get_or_construct(H &ht,K const& k,C0 const& c0) { } } - // get_or_call (0 arg) template typename H::mapped_type & get_or_call(H &ht,K const& k,F const& f) { diff --git a/utils/murmur_hash.h b/utils/murmur_hash.h deleted file mode 100644 index 6063d524..00000000 --- a/utils/murmur_hash.h +++ /dev/null @@ -1,186 +0,0 @@ -#ifndef _MURMUR_HASH_H_ -#define _MURMUR_HASH_H_ - -//NOTE: quite fast, nice collision properties, but endian dependent hash values - -#include "have_64_bits.h" -typedef uintptr_t MurmurInt; - -// MurmurHash2, by Austin Appleby - -static const uint32_t DEFAULT_SEED=2654435769U; - -#if HAVE_64_BITS -//MurmurInt MurmurHash(void const *key, int len, uint32_t seed=DEFAULT_SEED); - -inline uint64_t MurmurHash64( const void * key, int len, unsigned int seed=DEFAULT_SEED ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - -inline uint32_t MurmurHash32(void const *key, int len, uint32_t seed=DEFAULT_SEED) -{ - return (uint32_t) MurmurHash64(key,len,seed); -} - -inline MurmurInt MurmurHash(void const *key, int len, uint32_t seed=DEFAULT_SEED) -{ - return MurmurHash64(key,len,seed); -} - -#else -// 32-bit - -// Note - This code makes a few assumptions about how your machine behaves - -// 1. We can read a 4-byte value from any address without crashing -// 2. sizeof(int) == 4 -inline uint32_t MurmurHash32 ( const void * key, int len, uint32_t seed=DEFAULT_SEED) -{ - // 'm' and 'r' are mixing constants generated offline. - // They're not really 'magic', they just happen to work well. - - const uint32_t m = 0x5bd1e995; - const int r = 24; - - // Initialize the hash to a 'random' value - - uint32_t h = seed ^ len; - - // Mix 4 bytes at a time into the hash - - const unsigned char * data = (const unsigned char *)key; - - while(len >= 4) - { - uint32_t k = *(uint32_t *)data; - - k *= m; - k ^= k >> r; - k *= m; - - h *= m; - h ^= k; - - data += 4; - len -= 4; - } - - // Handle the last few bytes of the input array - - switch(len) - { - case 3: h ^= data[2] << 16; - case 2: h ^= data[1] << 8; - case 1: h ^= data[0]; - h *= m; - }; - - // Do a few final mixes of the hash to ensure the last few - // bytes are well-incorporated. - - h ^= h >> 13; - h *= m; - h ^= h >> 15; - - return h; -} - -inline MurmurInt MurmurHash ( const void * key, int len, uint32_t seed=DEFAULT_SEED) { - return MurmurHash32(key,len,seed); -} - -// 64-bit hash for 32-bit platforms - -inline uint64_t MurmurHash64 ( const void * key, int len, uint32_t seed=DEFAULT_SEED) -{ - const uint32_t m = 0x5bd1e995; - const int r = 24; - - uint32_t h1 = seed ^ len; - uint32_t h2 = 0; - - const uint32_t * data = (const uint32_t *)key; - - while(len >= 8) - { - uint32_t k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - uint32_t k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - uint32_t k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -#endif -//32bit - -#endif diff --git a/utils/murmur_hash3.cc b/utils/murmur_hash3.cc new file mode 100644 index 00000000..68a71d02 --- /dev/null +++ b/utils/murmur_hash3.cc @@ -0,0 +1,340 @@ +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. + +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. + +#include "murmur_hash3.h" + +//----------------------------------------------------------------------------- +// Platform-specific functions and macros + +// Microsoft Visual Studio + +#if defined(_MSC_VER) + +#define FORCE_INLINE __forceinline + +#include + +#define ROTL32(x,y) _rotl(x,y) +#define ROTL64(x,y) _rotl64(x,y) + +#define BIG_CONSTANT(x) (x) + +// Other compilers + +#else // defined(_MSC_VER) + +#define FORCE_INLINE inline __attribute__((always_inline)) + +namespace cdec { + +inline uint32_t rotl32 ( uint32_t x, int8_t r ) +{ + return (x << r) | (x >> (32 - r)); +} + +inline uint64_t rotl64 ( uint64_t x, int8_t r ) +{ + return (x << r) | (x >> (64 - r)); +} + +#define ROTL32(x,y) rotl32(x,y) +#define ROTL64(x,y) rotl64(x,y) + +#define BIG_CONSTANT(x) (x##LLU) + +#endif // !defined(_MSC_VER) + +//----------------------------------------------------------------------------- +// Block read - if your platform needs to do endian-swapping or can only +// handle aligned reads, do the conversion here + +FORCE_INLINE uint32_t getblock32 ( const uint32_t * p, int i ) +{ + return p[i]; +} + +FORCE_INLINE uint64_t getblock64 ( const uint64_t * p, int i ) +{ + return p[i]; +} + +//----------------------------------------------------------------------------- +// Finalization mix - force all bits of a hash block to avalanche + +FORCE_INLINE uint32_t fmix32 ( uint32_t h ) +{ + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; +} + +//---------- + +FORCE_INLINE uint64_t fmix64 ( uint64_t k ) +{ + k ^= k >> 33; + k *= BIG_CONSTANT(0xff51afd7ed558ccd); + k ^= k >> 33; + k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); + k ^= k >> 33; + + return k; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3_x86_32 ( const void * key, int len, + uint32_t seed, void * out ) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 4; + + uint32_t h1 = seed; + + const uint32_t c1 = 0xcc9e2d51; + const uint32_t c2 = 0x1b873593; + + //---------- + // body + + const uint32_t * blocks = (const uint32_t *)(data + nblocks*4); + + for(int i = -nblocks; i; i++) + { + uint32_t k1 = getblock32(blocks,i); + + k1 *= c1; + k1 = ROTL32(k1,15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1,13); + h1 = h1*5+0xe6546b64; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*4); + + uint32_t k1 = 0; + + switch(len & 3) + { + case 3: k1 ^= tail[2] << 16; + case 2: k1 ^= tail[1] << 8; + case 1: k1 ^= tail[0]; + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; + + h1 = fmix32(h1); + + *(uint32_t*)out = h1; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3_x86_128 ( const void * key, const int len, + uint32_t seed, void * out ) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 16; + + uint32_t h1 = seed; + uint32_t h2 = seed; + uint32_t h3 = seed; + uint32_t h4 = seed; + + const uint32_t c1 = 0x239b961b; + const uint32_t c2 = 0xab0e9789; + const uint32_t c3 = 0x38b34ae5; + const uint32_t c4 = 0xa1e38b93; + + //---------- + // body + + const uint32_t * blocks = (const uint32_t *)(data + nblocks*16); + + for(int i = -nblocks; i; i++) + { + uint32_t k1 = getblock32(blocks,i*4+0); + uint32_t k2 = getblock32(blocks,i*4+1); + uint32_t k3 = getblock32(blocks,i*4+2); + uint32_t k4 = getblock32(blocks,i*4+3); + + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + + h1 = ROTL32(h1,19); h1 += h2; h1 = h1*5+0x561ccd1b; + + k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; + + h2 = ROTL32(h2,17); h2 += h3; h2 = h2*5+0x0bcaa747; + + k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; + + h3 = ROTL32(h3,15); h3 += h4; h3 = h3*5+0x96cd1c35; + + k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; + + h4 = ROTL32(h4,13); h4 += h1; h4 = h4*5+0x32ac3b17; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*16); + + uint32_t k1 = 0; + uint32_t k2 = 0; + uint32_t k3 = 0; + uint32_t k4 = 0; + + switch(len & 15) + { + case 15: k4 ^= tail[14] << 16; + case 14: k4 ^= tail[13] << 8; + case 13: k4 ^= tail[12] << 0; + k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; + + case 12: k3 ^= tail[11] << 24; + case 11: k3 ^= tail[10] << 16; + case 10: k3 ^= tail[ 9] << 8; + case 9: k3 ^= tail[ 8] << 0; + k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; + + case 8: k2 ^= tail[ 7] << 24; + case 7: k2 ^= tail[ 6] << 16; + case 6: k2 ^= tail[ 5] << 8; + case 5: k2 ^= tail[ 4] << 0; + k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; + + case 4: k1 ^= tail[ 3] << 24; + case 3: k1 ^= tail[ 2] << 16; + case 2: k1 ^= tail[ 1] << 8; + case 1: k1 ^= tail[ 0] << 0; + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; h2 ^= len; h3 ^= len; h4 ^= len; + + h1 += h2; h1 += h3; h1 += h4; + h2 += h1; h3 += h1; h4 += h1; + + h1 = fmix32(h1); + h2 = fmix32(h2); + h3 = fmix32(h3); + h4 = fmix32(h4); + + h1 += h2; h1 += h3; h1 += h4; + h2 += h1; h3 += h1; h4 += h1; + + ((uint32_t*)out)[0] = h1; + ((uint32_t*)out)[1] = h2; + ((uint32_t*)out)[2] = h3; + ((uint32_t*)out)[3] = h4; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3_x64_128 ( const void * key, const int len, + const uint32_t seed, void * out ) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 16; + + uint64_t h1 = seed; + uint64_t h2 = seed; + + const uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); + const uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); + + //---------- + // body + + const uint64_t * blocks = (const uint64_t *)(data); + + for(int i = 0; i < nblocks; i++) + { + uint64_t k1 = getblock64(blocks,i*2+0); + uint64_t k2 = getblock64(blocks,i*2+1); + + k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; + + h1 = ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; + + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + h2 = ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*16); + + uint64_t k1 = 0; + uint64_t k2 = 0; + + switch(len & 15) + { + case 15: k2 ^= ((uint64_t)tail[14]) << 48; + case 14: k2 ^= ((uint64_t)tail[13]) << 40; + case 13: k2 ^= ((uint64_t)tail[12]) << 32; + case 12: k2 ^= ((uint64_t)tail[11]) << 24; + case 11: k2 ^= ((uint64_t)tail[10]) << 16; + case 10: k2 ^= ((uint64_t)tail[ 9]) << 8; + case 9: k2 ^= ((uint64_t)tail[ 8]) << 0; + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + case 8: k1 ^= ((uint64_t)tail[ 7]) << 56; + case 7: k1 ^= ((uint64_t)tail[ 6]) << 48; + case 6: k1 ^= ((uint64_t)tail[ 5]) << 40; + case 5: k1 ^= ((uint64_t)tail[ 4]) << 32; + case 4: k1 ^= ((uint64_t)tail[ 3]) << 24; + case 3: k1 ^= ((uint64_t)tail[ 2]) << 16; + case 2: k1 ^= ((uint64_t)tail[ 1]) << 8; + case 1: k1 ^= ((uint64_t)tail[ 0]) << 0; + k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; h2 ^= len; + + h1 += h2; + h2 += h1; + + h1 = fmix64(h1); + h2 = fmix64(h2); + + h1 += h2; + h2 += h1; + + ((uint64_t*)out)[0] = h1; + ((uint64_t*)out)[1] = h2; +} + +//----------------------------------------------------------------------------- + +} // namespace cdec + + diff --git a/utils/murmur_hash3.h b/utils/murmur_hash3.h new file mode 100644 index 00000000..a125d775 --- /dev/null +++ b/utils/murmur_hash3.h @@ -0,0 +1,67 @@ +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. + +#ifndef _MURMURHASH3_H_ +#define _MURMURHASH3_H_ + +//----------------------------------------------------------------------------- +// Platform-specific functions and macros + +// Microsoft Visual Studio + +#if defined(_MSC_VER) && (_MSC_VER < 1600) + +typedef unsigned char uint8_t; +typedef unsigned int uint32_t; +typedef unsigned __int64 uint64_t; + +// Other compilers + +#else // defined(_MSC_VER) + +#include + +#endif // !defined(_MSC_VER) + +//----------------------------------------------------------------------------- + +namespace cdec { + +void MurmurHash3_x86_32 ( const void * key, int len, uint32_t seed, void * out ); + +void MurmurHash3_x86_128 ( const void * key, int len, uint32_t seed, void * out ); + +void MurmurHash3_x64_128 ( const void * key, int len, uint32_t seed, void * out ); + +namespace { + #ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wunused-function" + #endif + template inline void cdecMurmurHashNativeBackend(const void * key, int len, uint32_t seed, void * out) { + MurmurHash3_x86_128(key, len, seed, out); + } + template <> inline void cdecMurmurHashNativeBackend<4>(const void * key, int len, uint32_t seed, void * out) { + MurmurHash3_x64_128(key, len, seed, out); + } + #ifdef __clang__ + #pragma clang diagnostic pop + #endif +} // namespace + +inline uint64_t MurmurHash3_64(const void * key, int len, uint32_t seed) { + uint64_t out[2]; + cdecMurmurHashNativeBackend(key, len, seed, &out); + return out[0]; +} + +inline void MurmurHash3_128(const void * key, int len, uint32_t seed, void * out) { + cdecMurmurHashNativeBackend(key, len, seed, out); +} + +} + +//----------------------------------------------------------------------------- + +#endif // _MURMURHASH3_H_ -- cgit v1.2.3 From 649b5ffc7c81182ba39d338b11bfe2e9a05544b5 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 16 Apr 2014 00:36:30 -0400 Subject: fix for bug due to using wrong tree traversal --- decoder/t2s_test.cc | 8 ++- decoder/tree2string_translator.cc | 18 +++--- decoder/tree_fragment.cc | 12 ---- decoder/tree_fragment.h | 112 +++++++++++++++++++++++++++++++++++++- 4 files changed, 125 insertions(+), 25 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc index 3c46ea89..5ebb2662 100644 --- a/decoder/t2s_test.cc +++ b/decoder/t2s_test.cc @@ -15,8 +15,11 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) { vector aw, bw; cerr << "TREE1: " << tree << endl; cerr << "TREE2: " << tree2 << endl; - for (auto& sym : tree) + for (auto& sym : tree) { + if (cdec::IsLHS(sym)) cerr << "("; + cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym); + } for (auto& sym : tree2) if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym); BOOST_CHECK_EQUAL(a.size(), b.size()); @@ -38,11 +41,12 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) { if (cdec::IsFrontier(*it)) nts += "*"; } } + cerr << "Truncated: " << nts << endl; BOOST_CHECK_EQUAL(nts, "(S NP* VP*"); nts.clear(); int ntc = 0; - for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + for (auto it = tree.bfs_begin(); it != tree.bfs_end(); ++it) { if (cdec::IsNT(*it)) { if (cdec::IsRHS(*it)) { ++ntc; diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 8d12d01d..3fbf1ee5 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -38,14 +38,13 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { // but it will not generate source strings correctly vector frhs; for (auto sym : rule_src) { + //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; cur = &cur->next[sym]; - if (sym) { - if (cdec::IsFrontier(sym)) { // frontier symbols -> variables - int nt = (sym & cdec::ALL_MASK); - frhs.push_back(-nt); - } else if (cdec::IsTerminal(sym)) { - frhs.push_back(sym); - } + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); } } os << '[' << TD::Convert(-lhs) << "] |||"; @@ -61,6 +60,7 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { os << " ||| " << line.substr(pos); TRulePtr rule(new TRule(os.str())); cur->rules.push_back(rule); + //cerr << "RULE: " << rule->AsString() << "\n\n"; } } @@ -82,7 +82,7 @@ struct ParserState { } vector future_work; int input_node_idx; // lhs of top level NT - Tree2StringGrammarNode* node; + Tree2StringGrammarNode* node; // pointer into grammar }; namespace std { @@ -239,11 +239,13 @@ struct Tree2StringTranslatorImpl { new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q if (unique.insert(new_s).second) q.push(new_s); } + //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } if (nit1 != s.node->next.end()) { //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; const ParserState new_s(++s.in_iter, &nit1->second, s); if (unique.insert(new_s).second) q.push(new_s); } + //else { cerr << "did not match " << TD::Convert(sym & cdec::ALL_MASK) << "\n"; } } else if (cdec::IsTerminal(sym)) { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 78a993b8..4d429f42 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -112,16 +112,4 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned *psymp = symp; } -BreadthFirstIterator TreeFragment::begin() const { - return BreadthFirstIterator(this, nodes.size() - 1); -} - -BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const { - return BreadthFirstIterator(this, node_idx); -} - -BreadthFirstIterator TreeFragment::end() const { - return BreadthFirstIterator(this); -} - } diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index f1c4c106..4a704bc4 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -11,6 +11,7 @@ namespace cdec { class BreadthFirstIterator; +class DepthFirstIterator; static const unsigned LHS_BIT = 0x10000000u; static const unsigned RHS_BIT = 0x20000000u; @@ -53,16 +54,21 @@ class TreeFragment { // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar)))) explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false); void DebugRec(unsigned cur, std::ostream* out) const; - typedef BreadthFirstIterator iterator; + typedef DepthFirstIterator iterator; typedef ptrdiff_t difference_type; typedef unsigned value_type; typedef const unsigned * pointer; typedef const unsigned & reference; + // default iterator is DFS iterator begin() const; iterator begin(unsigned node_idx) const; iterator end() const; + BreadthFirstIterator bfs_begin() const; + BreadthFirstIterator bfs_begin(unsigned node_idx) const; + BreadthFirstIterator bfs_end() const; + private: // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built @@ -78,14 +84,106 @@ class TreeFragment { struct TFIState { TFIState() : node(), rhspos(), state() {} - TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {} + TFIState(unsigned n, int p, unsigned s) : node(n), rhspos(p), state(s) {} bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; } bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; } unsigned short node; - unsigned short rhspos; + short rhspos; unsigned char state; }; +class DepthFirstIterator : public std::iterator { + const TreeFragment* tf_; + std::deque q_; + unsigned sym; + public: + DepthFirstIterator() : tf_(), sym() {} + // used for begin + explicit DepthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { + q_.push_back(TFIState(node_idx, -1, 0)); + Stage(); + q_.back().state++; + } + // used for end + explicit DepthFirstIterator(const TreeFragment* tf) : tf_(tf) {} + const unsigned& operator*() const { return sym; } + const unsigned* operator->() const { return &sym; } + bool operator==(const DepthFirstIterator& other) const { + return (tf_ == other.tf_) && (q_ == other.q_); + } + bool operator!=(const DepthFirstIterator& other) const { + return (tf_ != other.tf_) || (q_ != other.q_); + } + unsigned node_idx() const { return q_.front().node; } + const DepthFirstIterator& operator++() { + TFIState& s = q_.back(); + if (s.state == 0) { + Stage(); + s.state++; + } else if (s.state == 1) { + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos >= len) { + q_.pop_back(); + while (!q_.empty()) { + TFIState& s = q_.back(); + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos < len) break; + q_.pop_back(); + } + } + Stage(); + } + return *this; + } + DepthFirstIterator operator++(int) { + DepthFirstIterator res = *this; + ++(*this); + return res; + } + // tell iterator not to explore the subtree rooted at sym + // should only be called once per NT symbol encountered + const DepthFirstIterator& truncate() { + assert(IsRHS(sym)); + sym &= ALL_MASK; + sym |= FRONTIER_BIT; + q_.pop_back(); + return *this; + } + unsigned child_node() const { + assert(IsRHS(sym)); + return q_.back().node; + } + DepthFirstIterator remainder() const { + assert(IsRHS(sym)); + return DepthFirstIterator(tf_, q_.back()); + } + bool at_end() const { + return q_.empty(); + } + private: + void Stage() { + if (q_.empty()) return; + const TFIState& s = q_.back(); + if (s.state == 0) { + sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT; + } else if (s.state == 1) { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsRHS(sym)) { + q_.push_back(TFIState(sym & ALL_MASK, -1, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; + } + } + } + + // used by remainder + DepthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { + q_.push_back(s); + Stage(); + } +}; + class BreadthFirstIterator : public std::iterator { const TreeFragment* tf_; std::deque q_; @@ -172,6 +270,14 @@ class BreadthFirstIterator : public std::iterator Date: Fri, 25 Apr 2014 02:01:59 -0400 Subject: support for multiple xRs states in parser (not yet in rules) --- decoder/tree2string_translator.cc | 121 ++++++++++++++++++++++++++------------ decoder/trule.h | 3 + training/utils/grammar_convert.cc | 5 +- 3 files changed, 89 insertions(+), 40 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 3fbf1ee5..29caaf8f 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -30,12 +30,15 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { unsigned xc = 0; while (line[pos - 1] == ' ') { --pos; xc++; } cdec::TreeFragment rule_src(line.substr(0, pos), true); - Tree2StringGrammarNode* cur = root; + // TODO transducer_state should (optionally?) be read from input + const unsigned transducer_state = 0; + Tree2StringGrammarNode* cur = &root->next[transducer_state]; ostringstream os; int lhs = -(rule_src.root & cdec::ALL_MASK); // build source RHS for SCFG projection // TODO - this is buggy - it will generate a well-formed SCFG rule - // but it will not generate source strings correctly + // so it will not generate source strings correctly + // it will, however, generate target translations appropriately vector frhs; for (auto sym : rule_src) { //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; @@ -59,40 +62,65 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { while(line[pos] == ' ') { ++pos; } os << " ||| " << line.substr(pos); TRulePtr rule(new TRule(os.str())); + // TODO the transducer_state you end up in after using this rule (for each NT) + // needs to be read and encoded somehow in the rule (for use XXX) cur->rules.push_back(rule); //cerr << "RULE: " << rule->AsString() << "\n\n"; } } +// represents where in an input parse tree the transducer must continue +// and what state it is in +struct TransducerState { + TransducerState() : input_node_idx(), transducer_state() {} + TransducerState(unsigned n, unsigned q) : input_node_idx(n), transducer_state(q) {} + bool operator==(const TransducerState& o) const { + return input_node_idx == o.input_node_idx && + transducer_state == o.transducer_state; + } + unsigned input_node_idx; + unsigned transducer_state; +}; + +// represents the state of the composition algorithm struct ParserState { ParserState() : in_iter(), node() {} cdec::TreeFragment::iterator in_iter; - ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n) : + ParserState(const cdec::TreeFragment::iterator& it, unsigned q, Tree2StringGrammarNode* n) : in_iter(it), - input_node_idx(it.node_idx()), + task(it.node_idx(), q), node(n) {} ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : in_iter(it), future_work(p.future_work), - input_node_idx(p.input_node_idx), + task(p.task), node(n) {} bool operator==(const ParserState& o) const { - return node == o.node && input_node_idx == o.input_node_idx && + return node == o.node && task == o.task && future_work == o.future_work && in_iter == o.in_iter; } - vector future_work; - int input_node_idx; // lhs of top level NT - Tree2StringGrammarNode* node; // pointer into grammar + vector future_work; + TransducerState task; // subtree root where and in what state did the transducer start? + Tree2StringGrammarNode* node; // pointer into grammar trie }; namespace std { + template<> + struct hash { + size_t operator()(const TransducerState& q) const { + size_t h = boost::hash_value(q.transducer_state); + boost::hash_combine(h, boost::hash_value(q.input_node_idx)); + return h; + } + }; template<> struct hash { size_t operator()(const ParserState& s) const { - size_t h = boost::hash_range(s.future_work.begin(), s.future_work.end()); - boost::hash_combine(h, boost::hash_value(s.node)); - boost::hash_combine(h, boost::hash_value(s.input_node_idx)); - //boost::hash_combine(h, ); + size_t h = boost::hash_value(s.node); + for (auto& w : s.future_work) + boost::hash_combine(h, hash()(w)); + boost::hash_combine(h, hash()(s.task)); + // TODO hash with iterator return h; } }; @@ -144,6 +172,9 @@ struct Tree2StringTranslatorImpl { os << ')'; cdec::TreeFragment rule_src(os.str(), true); Tree2StringGrammarNode* cur = root.back().get(); + // do we need all transducer states here??? a list??? no pass through rules??? + unsigned transducer_state = 0; + cur = &cur->next[transducer_state]; for (auto sym : rule_src) cur = &cur->next[sym]; TRulePtr rule(new TRule(rhse, rhsf, lhs)); @@ -167,15 +198,19 @@ struct Tree2StringTranslatorImpl { if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; hg.ReserveNodes(input_tree.nodes.size()); - vector tree2hg(input_tree.nodes.size() + 1, -1); + unordered_map x2hg(input_tree.nodes.size() * 5); queue q; unordered_set unique; // only create items one time for (auto& g : root) { - q.push(ParserState(input_tree.begin(), g.get())); - unique.insert(q.back()); + unsigned q_0 = 0; // TODO initialize q_0 properly once multi-state transducers are supported + auto rit = g->next.find(q_0); + if (rit != g->next.end()) { // does this g have this transducer state? + q.push(ParserState(input_tree.begin(), q_0, &rit->second)); + unique.insert(q.back()); + } } if (q.size() == 0) return false; - unsigned tree_top = q.front().input_node_idx; + const TransducerState tree_top = q.front().task; while(!q.empty()) { ParserState& s = q.front(); @@ -183,21 +218,24 @@ struct Tree2StringTranslatorImpl { //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { - int& node_id = tree2hg[s.input_node_idx]; - if (node_id < 0) { - HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK)); - new_node->node_hash = s.input_node_idx + 1; - node_id = new_node->id_; + auto it = x2hg.find(s.task); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.task.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash()(s.task); + it = x2hg.insert(make_pair(s.task, new_node->id_)).first; } + const unsigned node_id = it->second; TailNodeVector tail; - for (auto n : s.future_work) { - int& nix = tree2hg[n]; - if (nix < 0) { - HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK)); - new_node->node_hash = n + 1; - nix = new_node->id_; + for (const auto& n : s.future_work) { + auto it = x2hg.find(n); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash()(n); + it = x2hg.insert(make_pair(n, new_node->id_)).first; } - tail.push_back(nix); + tail.push_back(it->second); } for (auto& r : s.node->rules) { assert(tail.size() == r->Arity()); @@ -206,11 +244,14 @@ struct Tree2StringTranslatorImpl { // TODO: set i and j hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); } - for (auto n : s.future_work) { - const auto it = input_tree.begin(n); // start tree iterator at node n + for (const auto& n : s.future_work) { + const auto it = input_tree.begin(n.input_node_idx); // start tree iterator at node n for (auto& g : root) { - ParserState s(it, g.get()); - if (unique.insert(s).second) q.push(s); + auto rit = g->next.find(n.transducer_state); + if (rit != g->next.end()) { // does this g have this transducer state? + const ParserState s(it, n.transducer_state, &rit->second); + if (unique.insert(s).second) q.push(s); + } } } } else { @@ -234,9 +275,13 @@ struct Tree2StringTranslatorImpl { if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; ++var; - const unsigned new_work = s.in_iter.child_node(); + // TODO: find out from rule what the new target state is (the 0 in the next line) + // if it is associated with the rule, we won't know until we match the whole input + // so the 0 may be okay (if this is the case, which is probably the easiest thing, + // then the state must be dealt with when the future work becomes real work) + const TransducerState new_task(s.in_iter.child_node(), 0); ParserState new_s(var, &nit2->second, s); - new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q + new_s.future_work.push_back(new_task); // if this traversal of the input succeeds, future_work goes on the q if (unique.insert(new_s).second) q.push(new_s); } //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } @@ -259,10 +304,10 @@ struct Tree2StringTranslatorImpl { } q.pop(); } - int goal = tree2hg[tree_top]; - if (goal < 0) return false; + const auto goal_it = x2hg.find(tree_top); + if (goal_it == x2hg.end()) return false; //cerr << "Goal node: " << goal << endl; - hg.TopologicallySortNodesAndEdges(goal); + hg.TopologicallySortNodesAndEdges(goal_it->second); hg.Reweight(weights); // there might be nodes that cannot be derived diff --git a/decoder/trule.h b/decoder/trule.h index 7dced5a1..cc370757 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -42,6 +42,9 @@ class TRule { scores_.set_value(feat_ids[i], feat_vals[i]); } + TRule(WordID lhs, const WordID* src, int src_size, const WordID* trg, int trg_size, int arity, int pi, int pj) : + e_(trg, trg + trg_size), f_(src, src + src_size), lhs_(lhs), arity_(arity), prev_i(pi), prev_j(pj) {} + bool IsGoal() const; explicit TRule(const std::vector& e) : e_(e), lhs_(0), prev_i(-1), prev_j(-1) {} diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc index 607a7cb9..58d1957c 100644 --- a/training/utils/grammar_convert.cc +++ b/training/utils/grammar_convert.cc @@ -292,10 +292,10 @@ int main(int argc, char **argv) { int lc = 0; Hypergraph hg; map lhs2node; + string line; while(*in) { - string line; + getline(*in,line); ++lc; - getline(*in, line); if (is_json_input) { if (line.empty() || line[0] == '#') continue; string ref; @@ -342,6 +342,7 @@ int main(int argc, char **argv) { edge->feature_values_ = tr->scores_; Hypergraph::Node* node = &hg.nodes_[head]; hg.ConnectEdgeToHeadNode(edge, node); + node->node_hash = lc; } } } -- cgit v1.2.3 From 18a1d98f5bd60ea195a6c3aaf8feb740da752f7e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 25 Apr 2014 13:20:06 -0400 Subject: fix tree-to-string forest so it works with cube pruning assumptions --- decoder/tree2string_translator.cc | 17 ++++++++++++++++- tests/system_tests/t2s/gold.statistics | 4 ++-- 2 files changed, 18 insertions(+), 3 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 29caaf8f..c353f7ca 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -126,6 +126,17 @@ namespace std { }; }; +void AddDummyGoalNode(Hypergraph* hg) { + static const int kGOAL = -TD::Convert("Goal"); + static TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X] ||| [1]")); + unsigned old_goal_node_idx = hg->nodes_.size() - 1; + HG::Node* goal_node = hg->AddNode(kGOAL); + goal_node->node_hash = 1; + TailNodeVector tail(1, old_goal_node_idx); + HG::Edge* new_edge = hg->AddEdge(kGOAL_RULE, tail); + hg->ConnectEdgeToHeadNode(new_edge, goal_node); +} + struct Tree2StringTranslatorImpl { vector> root; bool add_pass_through_rules; @@ -308,14 +319,18 @@ struct Tree2StringTranslatorImpl { if (goal_it == x2hg.end()) return false; //cerr << "Goal node: " << goal << endl; hg.TopologicallySortNodesAndEdges(goal_it->second); - hg.Reweight(weights); // there might be nodes that cannot be derived // the following takes care of them vector prune(hg.edges_.size(), false); hg.PruneEdges(prune, true); if (hg.edges_.size() == 0) return false; + // rescoring assumes the goal edge is arity 1 (code laziness), add that here + AddDummyGoalNode(&hg); + + hg.Reweight(weights); //hg.PrintGraphviz(); + minus_lm_forest->swap(hg); return true; } diff --git a/tests/system_tests/t2s/gold.statistics b/tests/system_tests/t2s/gold.statistics index 452cc93e..5778a24c 100644 --- a/tests/system_tests/t2s/gold.statistics +++ b/tests/system_tests/t2s/gold.statistics @@ -1,3 +1,3 @@ --lm_nodes 6 --lm_edges 8 +-lm_nodes 7 +-lm_edges 9 -lm_paths 4 -- cgit v1.2.3 From 77938f83c0b450a6a9229414dc415608fde5bfb9 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 May 2014 12:11:08 -0400 Subject: turn of span filtering --- decoder/hg_intersect.cc | 2 +- decoder/tree2string_translator.cc | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index 31a9a1ce..02f5a401 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -92,7 +92,7 @@ bool Intersect(const Lattice& target, Hypergraph* hg) { return FastLinearIntersect(target, hg); vector rem(hg->edges_.size(), false); - const RuleFilter filter(target, 15); // TODO make configurable + const RuleFilter filter(target, 9999); // TODO make configurable for (unsigned i = 0; i < rem.size(); ++i) rem[i] = filter(*hg->edges_[i].rule_); hg->PruneEdges(rem, true); diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index c353f7ca..fafb0d97 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -128,12 +128,14 @@ namespace std { void AddDummyGoalNode(Hypergraph* hg) { static const int kGOAL = -TD::Convert("Goal"); - static TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X] ||| [1]")); unsigned old_goal_node_idx = hg->nodes_.size() - 1; + int old_goal_cat = hg->nodes_[old_goal_node_idx].cat_; + TRulePtr goal_rule(new TRule("[Goal] ||| [X] ||| [1]")); + goal_rule->f_[0] = old_goal_cat; HG::Node* goal_node = hg->AddNode(kGOAL); goal_node->node_hash = 1; TailNodeVector tail(1, old_goal_node_idx); - HG::Edge* new_edge = hg->AddEdge(kGOAL_RULE, tail); + HG::Edge* new_edge = hg->AddEdge(goal_rule, tail); hg->ConnectEdgeToHeadNode(new_edge, goal_node); } -- cgit v1.2.3 From 80778dd022150ea4d654cd1952a3f09684a5cfbb Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 May 2014 21:20:46 -0400 Subject: better features --- decoder/tree2string_translator.cc | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index fafb0d97..38daeeb5 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -158,7 +158,11 @@ struct Tree2StringTranslatorImpl { } void CreatePassThroughRules(const cdec::TreeFragment& tree) { + static const int kFIDlex = FD::Convert("PassThrough_Lexical"); + static const int kFIDabs = FD::Convert("PassThrough_Abstract"); + static const int kFIDmix = FD::Convert("PassThrough_Mix"); static const int kFID = FD::Convert("PassThrough"); + static unordered_map pntfid; root.resize(root.size() + 1); root.back().reset(new Tree2StringGrammarNode); ++remove_grammars; @@ -167,14 +171,24 @@ struct Tree2StringTranslatorImpl { vector rhse, rhsf; int ntc = 0; int lhs = -(prod.lhs & cdec::ALL_MASK); + int &ntfid = pntfid[lhs]; + if (!ntfid) { + ostringstream fos; + fos << "PassThrough:" << TD::Convert(-lhs); + ntfid = FD::Convert(fos.str()); + } + bool has_lex = false; + bool has_nt = false; os << '(' << TD::Convert(-lhs); for (auto& sym : prod.rhs) { os << ' '; if (cdec::IsTerminal(sym)) { + has_lex = true; os << TD::Convert(sym); rhse.push_back(sym); rhsf.push_back(sym); } else { + has_nt = true; unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; os << '[' << TD::Convert(id) << ']'; rhsf.push_back(-id); @@ -192,7 +206,12 @@ struct Tree2StringTranslatorImpl { cur = &cur->next[sym]; TRulePtr rule(new TRule(rhse, rhsf, lhs)); rule->ComputeArity(); + rule->scores_.set_value(ntfid, 1.0); rule->scores_.set_value(kFID, 1.0); + if (has_lex && has_nt) + rule->scores_.set_value(kFIDmix, 1.0); + else if (has_lex) rule->scores_.set_value(kFIDlex, 1.0); + else if (has_nt) rule->scores_.set_value(kFIDabs, 1.0); cur->rules.push_back(rule); } } -- cgit v1.2.3 From c11c7af0746edbceec5ad49e2e8efeb34bceaa6b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 9 May 2014 00:01:40 -0400 Subject: remove fixed bug warning --- decoder/tree2string_translator.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 38daeeb5..5d7aa5e2 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -21,6 +21,8 @@ struct Tree2StringGrammarNode { vector rules; }; +// this needs to be rewritten so it is fast and checks errors well +// use a lexer probably void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { string line; while(getline(*in, line)) { @@ -36,10 +38,8 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { ostringstream os; int lhs = -(rule_src.root & cdec::ALL_MASK); // build source RHS for SCFG projection - // TODO - this is buggy - it will generate a well-formed SCFG rule - // so it will not generate source strings correctly - // it will, however, generate target translations appropriately vector frhs; + // we traverse the rule_src in left to right, DFS order for (auto sym : rule_src) { //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; cur = &cur->next[sym]; @@ -48,7 +48,7 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { frhs.push_back(-nt); } else if (cdec::IsTerminal(sym)) { frhs.push_back(sym); - } + } // else internal NT, nothing to do } os << '[' << TD::Convert(-lhs) << "] |||"; for (auto x : frhs) { -- cgit v1.2.3 From bb3f703d572e9f4a4b971bfa2483e0caf060587d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 May 2014 17:46:20 -0400 Subject: stub for t2t translator --- decoder/decoder.cc | 8 +++++--- decoder/translator.h | 3 ++- decoder/tree2string_translator.cc | 12 +++++++----- 3 files changed, 14 insertions(+), 9 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 41f36822..6783cad0 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -490,8 +490,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } formalism = LowercaseString(str("formalism",conf)); - if (formalism != "t2s" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; + if (formalism != "t2s" && formalism != "t2t" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 't2t', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -627,7 +627,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); else if (formalism == "t2s") - translator.reset(new Tree2StringTranslator(conf)); + translator.reset(new Tree2StringTranslator(conf, false)); + else if (formalism == "t2t") + translator.reset(new Tree2StringTranslator(conf, true)); else if (formalism == "fst") translator.reset(new FSTTranslator(conf)); else if (formalism == "pb") diff --git a/decoder/translator.h b/decoder/translator.h index 72b2f0b0..ba218a0b 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -101,7 +101,8 @@ class RescoreTranslator : public Translator { class Tree2StringTranslatorImpl; class Tree2StringTranslator : public Translator { public: - Tree2StringTranslator(const boost::program_options::variables_map& conf); + Tree2StringTranslator(const boost::program_options::variables_map& conf, + bool has_multiple_states); virtual std::string GetDecoderType() const; protected: bool TranslateImpl(const std::string& src, diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 5d7aa5e2..101ed21c 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -23,7 +23,7 @@ struct Tree2StringGrammarNode { // this needs to be rewritten so it is fast and checks errors well // use a lexer probably -void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { +void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); @@ -143,7 +143,8 @@ struct Tree2StringTranslatorImpl { vector> root; bool add_pass_through_rules; unsigned remove_grammars; - Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) : + Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf, + bool has_multiple_states) : add_pass_through_rules(conf.count("add_pass_through_rules")) { if (conf.count("grammar")) { const vector gf = conf["grammar"].as>(); @@ -152,7 +153,7 @@ struct Tree2StringTranslatorImpl { for (auto& f : gf) { ReadFile rf(f); root[gc].reset(new Tree2StringGrammarNode); - ReadTree2StringGrammar(rf.stream(), &*root[gc++]); + ReadTree2StringGrammar(rf.stream(), &*root[gc++], has_multiple_states); } } } @@ -357,8 +358,9 @@ struct Tree2StringTranslatorImpl { } }; -Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf) : - pimpl_(new Tree2StringTranslatorImpl(conf)) {} +Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf, + bool has_multiple_states) : + pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {} bool Tree2StringTranslator::TranslateImpl(const string& input, SentenceMetadata* smeta, -- cgit v1.2.3 From d1de1bb3fe271ce2f90a5885d70eddb859dd2354 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 May 2014 17:53:53 -0400 Subject: check for duplicates when creating pass through rules --- decoder/tree2string_translator.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 101ed21c..8d624820 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -167,9 +167,8 @@ struct Tree2StringTranslatorImpl { root.resize(root.size() + 1); root.back().reset(new Tree2StringGrammarNode); ++remove_grammars; + unordered_set,boost::hash>> unique_rule_check; for (auto& prod : tree.nodes) { - ostringstream os; - vector rhse, rhsf; int ntc = 0; int lhs = -(prod.lhs & cdec::ALL_MASK); int &ntfid = pntfid[lhs]; @@ -178,8 +177,16 @@ struct Tree2StringTranslatorImpl { fos << "PassThrough:" << TD::Convert(-lhs); ntfid = FD::Convert(fos.str()); } + + // check for duplicate rule in tree + vector key = prod.rhs; + key.push_back(prod.lhs); + if (!unique_rule_check.insert(key).second) continue; + bool has_lex = false; bool has_nt = false; + vector rhse, rhsf; + ostringstream os; os << '(' << TD::Convert(-lhs); for (auto& sym : prod.rhs) { os << ' '; -- cgit v1.2.3 From ddd99766d069c5ce5d41da304d7ba613657ee564 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 May 2014 19:04:31 -0400 Subject: fix unique check --- decoder/tree2string_translator.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 8d624820..7b37887e 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -167,7 +167,7 @@ struct Tree2StringTranslatorImpl { root.resize(root.size() + 1); root.back().reset(new Tree2StringGrammarNode); ++remove_grammars; - unordered_set,boost::hash>> unique_rule_check; + unordered_set,boost::hash>> unique_rule_check; for (auto& prod : tree.nodes) { int ntc = 0; int lhs = -(prod.lhs & cdec::ALL_MASK); @@ -179,9 +179,8 @@ struct Tree2StringTranslatorImpl { } // check for duplicate rule in tree - vector key = prod.rhs; + vector key; key.push_back(prod.lhs); - if (!unique_rule_check.insert(key).second) continue; bool has_lex = false; bool has_nt = false; @@ -195,16 +194,19 @@ struct Tree2StringTranslatorImpl { os << TD::Convert(sym); rhse.push_back(sym); rhsf.push_back(sym); + key.push_back(sym); } else { has_nt = true; unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; os << '[' << TD::Convert(id) << ']'; rhsf.push_back(-id); rhse.push_back(-ntc); + key.push_back(-id); ++ntc; } } os << ')'; + if (!unique_rule_check.insert(key).second) continue; cdec::TreeFragment rule_src(os.str(), true); Tree2StringGrammarNode* cur = root.back().get(); // do we need all transducer states here??? a list??? no pass through rules??? -- cgit v1.2.3 From 2edf6020d71b4f728a473780a8f109bbb98efe2c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 24 May 2014 02:09:49 -0400 Subject: support per sentence tree-to-string grammars --- decoder/tree2string_translator.cc | 46 +++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 7 deletions(-) (limited to 'decoder/tree2string_translator.cc') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 7b37887e..b5b47d5d 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -5,6 +5,7 @@ #include #include #include +#include "fast_lexical_cast.hpp" #include "tree_fragment.h" #include "translator.h" #include "hg.h" @@ -23,7 +24,7 @@ struct Tree2StringGrammarNode { // this needs to be rewritten so it is fast and checks errors well // use a lexer probably -void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { +static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); @@ -142,10 +143,12 @@ void AddDummyGoalNode(Hypergraph* hg) { struct Tree2StringTranslatorImpl { vector> root; bool add_pass_through_rules; + bool has_multiple_states; unsigned remove_grammars; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf, bool has_multiple_states) : - add_pass_through_rules(conf.count("add_pass_through_rules")) { + add_pass_through_rules(conf.count("add_pass_through_rules")), + has_multiple_states(has_multiple_states) { if (conf.count("grammar")) { const vector gf = conf["grammar"].as>(); root.resize(gf.size()); @@ -158,6 +161,15 @@ struct Tree2StringTranslatorImpl { } } + // loads a per-sentence grammar + void LoadSupplementalGrammar(const string& gfile) { + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + ReadFile rf(gfile); + ReadTree2StringGrammar(rf.stream(), root.back().get(), has_multiple_states); + } + void CreatePassThroughRules(const cdec::TreeFragment& tree) { static const int kFIDlex = FD::Convert("PassThrough_Lexical"); static const int kFIDabs = FD::Convert("PassThrough_Abstract"); @@ -227,7 +239,7 @@ struct Tree2StringTranslatorImpl { } void RemoveGrammars() { - assert(remove_grammars < root.size()); + assert(remove_grammars <= root.size()); root.resize(root.size() - remove_grammars); } @@ -235,7 +247,6 @@ struct Tree2StringTranslatorImpl { SentenceMetadata* smeta, const vector& weights, Hypergraph* minus_lm_forest) { - remove_grammars = 0; cdec::TreeFragment input_tree(input, false); if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; @@ -371,6 +382,30 @@ Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::varia bool has_multiple_states) : pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {} +void Tree2StringTranslator::ProcessMarkupHintsImpl(const map& kv) { + pimpl_->remove_grammars = 0; + if (kv.find("grammar0") != kv.end()) { + cerr << "SGML tag grammar0 is not expected (order is: grammar, grammar1, grammar2, ...)\n"; + abort(); + } + unsigned gc = 0; + set loaded; + while(true) { + string gkey = "grammar"; + if (gc > 0) gkey += boost::lexical_cast(gc); + ++gc; + map::const_iterator it = kv.find(gkey); + if (it == kv.end()) break; + const string& gfile = it->second; + if (loaded.count(gfile) == 1) { + cerr << "Attempting to load " << gfile << " twice!\n"; + abort(); + } + loaded.insert(gfile); + pimpl_->LoadSupplementalGrammar(gfile); + } +} + bool Tree2StringTranslator::TranslateImpl(const string& input, SentenceMetadata* smeta, const vector& weights, @@ -378,9 +413,6 @@ bool Tree2StringTranslator::TranslateImpl(const string& input, return pimpl_->Translate(input, smeta, weights, minus_lm_forest); } -void Tree2StringTranslator::ProcessMarkupHintsImpl(const map& kv) { -} - void Tree2StringTranslator::SentenceCompleteImpl() { pimpl_->RemoveGrammars(); } -- cgit v1.2.3