From 8372086f2fc4bd765fdd05e8cf95faeb147a6587 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 30 Mar 2014 23:50:17 -0400 Subject: almost complete tree to string translator --- decoder/Makefile.am | 3 + decoder/decoder.cc | 6 +- decoder/t2s_test.cc | 110 ++++++++++++++++++++++++++++++++++ decoder/tree2string_translator.cc | 120 +++++++++++++++++++++++++++++++++----- decoder/tree_fragment.cc | 14 +++-- decoder/tree_fragment.h | 109 ++++++++++++++++++++++++---------- 6 files changed, 311 insertions(+), 51 deletions(-) create mode 100644 decoder/t2s_test.cc (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 7481192b..5c91fe65 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -4,9 +4,12 @@ noinst_PROGRAMS = \ trule_test \ hg_test \ parser_test \ + t2s_test \ grammar_test TESTS = trule_test parser_test grammar_test hg_test +t2s_test_SOURCES = t2s_test.cc +t2s_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a parser_test_SOURCES = parser_test.cc parser_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a grammar_test_SOURCES = grammar_test.cc diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 31049216..43e2640d 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -490,8 +490,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } formalism = LowercaseString(str("formalism",conf)); - if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; + if (formalism != "t2s" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -626,6 +626,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream // set up translation back end if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); + else if (formalism == "t2s") + translator.reset(new Tree2StringTranslator(conf)); else if (formalism == "fst") translator.reset(new FSTTranslator(conf)); else if (formalism == "pb") diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc new file mode 100644 index 00000000..3c46ea89 --- /dev/null +++ b/decoder/t2s_test.cc @@ -0,0 +1,110 @@ +#include "tree_fragment.h" + +#define BOOST_TEST_MODULE T2STest +#include +#include +#include +#include "tdict.h" + +using namespace std; + +BOOST_AUTO_TEST_CASE(TestTreeFragments) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + cdec::TreeFragment tree2("(S (NP (DT a) (NN cat)) (VP (V ate) (NP (DT the) (NN cake pie))))"); + vector a, b; + vector aw, bw; + cerr << "TREE1: " << tree << endl; + cerr << "TREE2: " << tree2 << endl; + for (auto& sym : tree) + if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym); + for (auto& sym : tree2) + if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym); + BOOST_CHECK_EQUAL(a.size(), b.size()); + BOOST_CHECK_EQUAL(aw.size() + 1, bw.size()); + BOOST_CHECK_EQUAL(aw.size(), 5); + BOOST_CHECK_EQUAL(TD::GetString(aw), "the boy saw a cat"); + BOOST_CHECK_EQUAL(TD::GetString(bw), "a cat ate the cake pie"); + if (a != b) { + BOOST_CHECK_EQUAL(1,2); + } + + string nts; + for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + if (cdec::IsNT(*it)) { + if (cdec::IsRHS(*it)) it.truncate(); + if (nts.size()) nts += " "; + if (cdec::IsLHS(*it)) nts += "("; + nts += TD::Convert(*it & cdec::ALL_MASK); + if (cdec::IsFrontier(*it)) nts += "*"; + } + } + BOOST_CHECK_EQUAL(nts, "(S NP* VP*"); + + nts.clear(); + int ntc = 0; + for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + if (cdec::IsNT(*it)) { + if (cdec::IsRHS(*it)) { + ++ntc; + if (ntc > 1) it.truncate(); + } + if (nts.size()) nts += " "; + if (cdec::IsLHS(*it)) nts += "("; + nts += TD::Convert(*it & cdec::ALL_MASK); + if (cdec::IsFrontier(*it)) nts += "*"; + } + } + BOOST_CHECK_EQUAL(nts, "(S NP VP* (NP DT* NN*"); +} + +BOOST_AUTO_TEST_CASE(TestSharing) { + cdec::TreeFragment rule1("(S [NP] [VP])", true); + cdec::TreeFragment rule2("(S [NP] (VP [V] [NP]))", true); + string r1,r2; + for (auto sym : rule1) { + if (r1.size()) r1 += " "; + if (cdec::IsLHS(sym)) r1 += "("; + r1 += TD::Convert(sym & cdec::ALL_MASK); + if (cdec::IsFrontier(sym)) r1 += "*"; + } + for (auto sym : rule2) { + if (r2.size()) r2 += " "; + if (cdec::IsLHS(sym)) r2 += "("; + r2 += TD::Convert(sym & cdec::ALL_MASK); + if (cdec::IsFrontier(sym)) r2 += "*"; + } + cerr << rule1 << endl; + cerr << r1 << endl; + cerr << rule2 << endl; + cerr << r2 << endl; + BOOST_CHECK_EQUAL(r1, "(S NP* VP*"); + BOOST_CHECK_EQUAL(r2, "(S NP* VP (VP V* NP*"); +} + +BOOST_AUTO_TEST_CASE(TestEndInvariants) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + BOOST_CHECK(tree.end().at_end()); + BOOST_CHECK(!tree.begin().at_end()); +} + +BOOST_AUTO_TEST_CASE(TestBegins) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + for (auto it = tree.begin(1); it != tree.end(); ++it) { + cerr << TD::Convert(*it & cdec::ALL_MASK) << endl; + } +} + +BOOST_AUTO_TEST_CASE(TestRemainder) { + cdec::TreeFragment tree("(S (A a) (B b))"); + auto it = tree.begin(); + ++it; + BOOST_CHECK(cdec::IsRHS(*it)); + cerr << tree << endl; + auto itr = it.remainder(); + while(itr != tree.end()) { + cerr << TD::Convert(*itr & cdec::ALL_MASK) << endl; + ++itr; + } +} + + diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 1c249836..cd6ee550 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include #include "tree_fragment.h" @@ -15,11 +16,10 @@ using namespace std; struct Tree2StringGrammarNode { map next; - string rules; + vector rules; }; -void ReadTree2StringGrammar(istream* in, unordered_map* proots) { - unordered_map& roots = *proots; +void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); @@ -28,32 +28,124 @@ void ReadTree2StringGrammar(istream* in, unordered_map frhs; + for (auto sym : rule_src) { cur = &cur->next[sym]; + if (sym) { + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); + } + } + } + os << '[' << TD::Convert(-lhs) << "] |||"; + for (auto x : frhs) { + os << ' '; + if (x < 0) + os << '[' << TD::Convert(-x) << ']'; + else + os << TD::Convert(x); + } pos += 3 + xc; while(line[pos] == ' ') { ++pos; } - size_t pos2 = line.find("|||", pos); - assert(pos2 != string::npos); - while (line[pos2 - 1] == ' ') { --pos2; } - cur->rules = line.substr(pos, pos2 - pos); - cerr << "OUTPUT = '" << cur->rules << "'\n"; + os << " ||| " << line.substr(pos); + TRulePtr rule(new TRule(os.str())); + cur->rules.push_back(rule); } } +struct ParserState { + ParserState() : in_iter(), node() {} + cdec::TreeFragment::iterator in_iter; + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) : + in_iter(it), + root_type(rt), + node(n) {} + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : + in_iter(it), + future_work(p.future_work), + root_type(p.root_type), + node(n) {} + vector future_work; + int root_type; // lhs of top level NT + Tree2StringGrammarNode* node; +}; + struct Tree2StringTranslatorImpl { - unordered_map roots; // root['S'] gives rule network for S rules + Tree2StringGrammarNode root; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { ReadFile rf(conf["grammar"].as>()[0]); - ReadTree2StringGrammar(rf.stream(), &roots); + ReadTree2StringGrammar(rf.stream(), &root); } bool Translate(const string& input, SentenceMetadata* smeta, const vector& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); - cerr << "Tree2StringTranslatorImpl: please implement this!\n"; - return false; + const int kS = -TD::Convert("S"); + Hypergraph hg; + queue q; + q.push(ParserState(input_tree.begin(), &root, kS)); + while(!q.empty()) { + ParserState& s = q.front(); + + if (s.in_iter.at_end()) { // completed a traversal of a subtree + cerr << "I traversed a subtree of the input...\n"; + if (s.node->rules.size()) { + // TODO: build hypergraph + for (auto& r : s.node->rules) + cerr << "I can build: " << r->AsString() << endl; + for (auto& w : s.future_work) + q.push(w); + } else { + cerr << "I can't build anything :(\n"; + } + } else { // more input tree to match + unsigned sym = *s.in_iter; + if (cdec::IsLHS(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + q.push(ParserState(++s.in_iter, &nit->second, s)); + } + } else if (cdec::IsRHS(sym)) { + //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + cdec::TreeFragment::iterator var = s.in_iter; + var.truncate(); + auto nit1 = s.node->next.find(sym); + auto nit2 = s.node->next.find(*var); + if (nit2 != s.node->next.end()) { + //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + ParserState new_s(++var, &nit2->second, s); + ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK)); + new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q + q.push(new_s); + } + if (nit1 != s.node->next.end()) { + //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + q.push(ParserState(++s.in_iter, &nit1->second, s)); + } + } else if (cdec::IsTerminal(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; + q.push(ParserState(++s.in_iter, &nit->second, s)); + } + } else { + cerr << "This can never happen!\n"; abort(); + } + } + q.pop(); + } + minus_lm_forest->swap(hg); } }; diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 93aad64e..78a993b8 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -36,7 +36,7 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const { *out << ' '; if (IsFrontier(x)) { *out << '[' << TD::Convert(x & ALL_MASK) << ']'; - } else if (IsInternalNT(x)) { + } else if (IsRHS(x)) { DebugRec(x & ALL_MASK, out); } else { // must be terminal *out << TD::Convert(x); @@ -66,7 +66,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned // recursively call parser to deal with constituent ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np); unsigned ind = np - 1; - rhs.push_back(ind | NT_BIT); + rhs.push_back(ind | RHS_BIT); } else { // deal with terminal / nonterminal substitution ++symp; assert(tree[cp] != ' '); @@ -95,7 +95,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned } // continuent has completed, cp is at ), build node const unsigned j = symp; // span from (i,j) // add an internal non-terminal symbol - const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT; + const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT; nodes[np] = TreeFragmentProduction(nt, rhs); //cerr << np << " production(" << i << "," << j << ")= " << TD::Convert(nt & ALL_MASK) << " -->"; //for (auto& x : rhs) { @@ -113,11 +113,15 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned } BreadthFirstIterator TreeFragment::begin() const { - return BreadthFirstIterator(this); + return BreadthFirstIterator(this, nodes.size() - 1); +} + +BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const { + return BreadthFirstIterator(this, node_idx); } BreadthFirstIterator TreeFragment::end() const { - return BreadthFirstIterator(this, 0); + return BreadthFirstIterator(this); } } diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index a38dbdfa..b83afc27 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -1,7 +1,7 @@ #ifndef TREE_FRAGMENT #define TREE_FRAGMENT -#include +#include #include #include #include @@ -12,18 +12,32 @@ namespace cdec { class BreadthFirstIterator; -static const unsigned NT_BIT = 0x40000000u; -static const unsigned FRONTIER_BIT = 0x80000000u; -static const unsigned ALL_MASK = 0x0FFFFFFFu; +static const unsigned LHS_BIT = 0x10000000u; +static const unsigned RHS_BIT = 0x20000000u; +static const unsigned FRONTIER_BIT = 0x40000000u; +static const unsigned RESERVED_BIT = 0x80000000u; +static const unsigned ALL_MASK = 0x0FFFFFFFu; -inline bool IsInternalNT(unsigned x) { - return (x & NT_BIT); +inline bool IsNT(unsigned x) { + return (x & (LHS_BIT | RHS_BIT | FRONTIER_BIT)); +} + +inline bool IsLHS(unsigned x) { + return (x & LHS_BIT); +} + +inline bool IsRHS(unsigned x) { + return (x & RHS_BIT); } inline bool IsFrontier(unsigned x) { return (x & FRONTIER_BIT); } +inline bool IsTerminal(unsigned x) { + return (x & ALL_MASK) == x; +} + struct TreeFragmentProduction { TreeFragmentProduction() {} TreeFragmentProduction(int nttype, const std::vector& r) : lhs(nttype), rhs(r) {} @@ -46,6 +60,7 @@ class TreeFragment { typedef const unsigned & reference; iterator begin() const; + iterator begin(unsigned node_idx) const; iterator end() const; private: @@ -62,24 +77,28 @@ class TreeFragment { }; struct TFIState { - TFIState() : node(), rhspos() {} - TFIState(unsigned n, unsigned p) : node(n), rhspos(p) {} - bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos; } - bool operator!=(const TFIState& o) const { return node != o.node && rhspos != o.rhspos; } + TFIState() : node(), rhspos(), state() {} + TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {} + bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; } + bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; } unsigned short node; unsigned short rhspos; + unsigned char state; }; class BreadthFirstIterator : public std::iterator { const TreeFragment* tf_; - std::queue q_; + std::deque q_; unsigned sym; public: - explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) { - q_.push(TFIState(tf->nodes.size() - 1, 0)); + BreadthFirstIterator() : tf_(), sym() {} + // used for begin + explicit BreadthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { + q_.push_back(TFIState(node_idx, 0, 0)); Stage(); } - BreadthFirstIterator(const TreeFragment* tf, int) : tf_(tf) {} + // used for end + explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {} const unsigned& operator*() const { return sym; } const unsigned* operator->() const { return &sym; } bool operator==(const BreadthFirstIterator& other) const { @@ -88,26 +107,20 @@ class BreadthFirstIterator : public std::iteratornodes[s.node].rhs[s.rhspos]; - if (IsInternalNT(sym)) { - q_.push(TFIState(sym & ALL_MASK, 0)); - sym = tf_->nodes[sym & ALL_MASK].lhs; - } - } const BreadthFirstIterator& operator++() { TFIState& s = q_.front(); - const unsigned len = tf_->nodes[s.node].rhs.size(); - s.rhspos++; - if (s.rhspos > len) { - q_.pop(); + if (s.state == 0) { + s.state++; Stage(); - } else if (s.rhspos == len) { - sym = 0; } else { - Stage(); + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos >= len) { + q_.pop_front(); + Stage(); + } else { + Stage(); + } } return *this; } @@ -116,6 +129,42 @@ class BreadthFirstIterator : public std::iteratornodes[s.node].lhs & ALL_MASK) | LHS_BIT; + } else { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsRHS(sym)) { + q_.push_back(TFIState(sym & ALL_MASK, 0, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; + } + } + } + + // used by remainder + BreadthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { + q_.push_back(s); + Stage(); + } }; inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) { -- cgit v1.2.3