From 34a148b9882cc983c8292978503849b114fc2983 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 16 Apr 2014 00:36:30 -0400 Subject: fix for bug due to using wrong tree traversal --- decoder/t2s_test.cc | 8 ++- decoder/tree2string_translator.cc | 18 +++--- decoder/tree_fragment.cc | 12 ---- decoder/tree_fragment.h | 112 +++++++++++++++++++++++++++++++++++++- 4 files changed, 125 insertions(+), 25 deletions(-) diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc index 3c46ea89..5ebb2662 100644 --- a/decoder/t2s_test.cc +++ b/decoder/t2s_test.cc @@ -15,8 +15,11 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) { vector aw, bw; cerr << "TREE1: " << tree << endl; cerr << "TREE2: " << tree2 << endl; - for (auto& sym : tree) + for (auto& sym : tree) { + if (cdec::IsLHS(sym)) cerr << "("; + cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym); + } for (auto& sym : tree2) if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym); BOOST_CHECK_EQUAL(a.size(), b.size()); @@ -38,11 +41,12 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) { if (cdec::IsFrontier(*it)) nts += "*"; } } + cerr << "Truncated: " << nts << endl; BOOST_CHECK_EQUAL(nts, "(S NP* VP*"); nts.clear(); int ntc = 0; - for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + for (auto it = tree.bfs_begin(); it != tree.bfs_end(); ++it) { if (cdec::IsNT(*it)) { if (cdec::IsRHS(*it)) { ++ntc; diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 8d12d01d..3fbf1ee5 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -38,14 +38,13 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { // but it will not generate source strings correctly vector frhs; for (auto sym : rule_src) { + //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; cur = &cur->next[sym]; - if (sym) { - if (cdec::IsFrontier(sym)) { // frontier symbols -> variables - int nt = (sym & cdec::ALL_MASK); - frhs.push_back(-nt); - } else if (cdec::IsTerminal(sym)) { - frhs.push_back(sym); - } + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); } } os << '[' << TD::Convert(-lhs) << "] |||"; @@ -61,6 +60,7 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { os << " ||| " << line.substr(pos); TRulePtr rule(new TRule(os.str())); cur->rules.push_back(rule); + //cerr << "RULE: " << rule->AsString() << "\n\n"; } } @@ -82,7 +82,7 @@ struct ParserState { } vector future_work; int input_node_idx; // lhs of top level NT - Tree2StringGrammarNode* node; + Tree2StringGrammarNode* node; // pointer into grammar }; namespace std { @@ -239,11 +239,13 @@ struct Tree2StringTranslatorImpl { new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q if (unique.insert(new_s).second) q.push(new_s); } + //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } if (nit1 != s.node->next.end()) { //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; const ParserState new_s(++s.in_iter, &nit1->second, s); if (unique.insert(new_s).second) q.push(new_s); } + //else { cerr << "did not match " << TD::Convert(sym & cdec::ALL_MASK) << "\n"; } } else if (cdec::IsTerminal(sym)) { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 78a993b8..4d429f42 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -112,16 +112,4 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned *psymp = symp; } -BreadthFirstIterator TreeFragment::begin() const { - return BreadthFirstIterator(this, nodes.size() - 1); -} - -BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const { - return BreadthFirstIterator(this, node_idx); -} - -BreadthFirstIterator TreeFragment::end() const { - return BreadthFirstIterator(this); -} - } diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index f1c4c106..4a704bc4 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -11,6 +11,7 @@ namespace cdec { class BreadthFirstIterator; +class DepthFirstIterator; static const unsigned LHS_BIT = 0x10000000u; static const unsigned RHS_BIT = 0x20000000u; @@ -53,16 +54,21 @@ class TreeFragment { // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar)))) explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false); void DebugRec(unsigned cur, std::ostream* out) const; - typedef BreadthFirstIterator iterator; + typedef DepthFirstIterator iterator; typedef ptrdiff_t difference_type; typedef unsigned value_type; typedef const unsigned * pointer; typedef const unsigned & reference; + // default iterator is DFS iterator begin() const; iterator begin(unsigned node_idx) const; iterator end() const; + BreadthFirstIterator bfs_begin() const; + BreadthFirstIterator bfs_begin(unsigned node_idx) const; + BreadthFirstIterator bfs_end() const; + private: // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built @@ -78,14 +84,106 @@ class TreeFragment { struct TFIState { TFIState() : node(), rhspos(), state() {} - TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {} + TFIState(unsigned n, int p, unsigned s) : node(n), rhspos(p), state(s) {} bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; } bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; } unsigned short node; - unsigned short rhspos; + short rhspos; unsigned char state; }; +class DepthFirstIterator : public std::iterator { + const TreeFragment* tf_; + std::deque q_; + unsigned sym; + public: + DepthFirstIterator() : tf_(), sym() {} + // used for begin + explicit DepthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { + q_.push_back(TFIState(node_idx, -1, 0)); + Stage(); + q_.back().state++; + } + // used for end + explicit DepthFirstIterator(const TreeFragment* tf) : tf_(tf) {} + const unsigned& operator*() const { return sym; } + const unsigned* operator->() const { return &sym; } + bool operator==(const DepthFirstIterator& other) const { + return (tf_ == other.tf_) && (q_ == other.q_); + } + bool operator!=(const DepthFirstIterator& other) const { + return (tf_ != other.tf_) || (q_ != other.q_); + } + unsigned node_idx() const { return q_.front().node; } + const DepthFirstIterator& operator++() { + TFIState& s = q_.back(); + if (s.state == 0) { + Stage(); + s.state++; + } else if (s.state == 1) { + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos >= len) { + q_.pop_back(); + while (!q_.empty()) { + TFIState& s = q_.back(); + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos < len) break; + q_.pop_back(); + } + } + Stage(); + } + return *this; + } + DepthFirstIterator operator++(int) { + DepthFirstIterator res = *this; + ++(*this); + return res; + } + // tell iterator not to explore the subtree rooted at sym + // should only be called once per NT symbol encountered + const DepthFirstIterator& truncate() { + assert(IsRHS(sym)); + sym &= ALL_MASK; + sym |= FRONTIER_BIT; + q_.pop_back(); + return *this; + } + unsigned child_node() const { + assert(IsRHS(sym)); + return q_.back().node; + } + DepthFirstIterator remainder() const { + assert(IsRHS(sym)); + return DepthFirstIterator(tf_, q_.back()); + } + bool at_end() const { + return q_.empty(); + } + private: + void Stage() { + if (q_.empty()) return; + const TFIState& s = q_.back(); + if (s.state == 0) { + sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT; + } else if (s.state == 1) { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsRHS(sym)) { + q_.push_back(TFIState(sym & ALL_MASK, -1, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; + } + } + } + + // used by remainder + DepthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { + q_.push_back(s); + Stage(); + } +}; + class BreadthFirstIterator : public std::iterator { const TreeFragment* tf_; std::deque q_; @@ -172,6 +270,14 @@ class BreadthFirstIterator : public std::iterator