From 34785db78a0ad12f0fe74d98924acc20a8cab79a Mon Sep 17 00:00:00 2001
From: Chris Dyer <redpony@gmail.com>
Date: Thu, 27 Mar 2014 00:07:41 -0400
Subject: breadth first iterator for tree fragment

---
 decoder/tree_fragment.cc | 8 ++++++++
 1 file changed, 8 insertions(+)

(limited to 'decoder/tree_fragment.cc')

diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
index d5c30f58..93aad64e 100644
--- a/decoder/tree_fragment.cc
+++ b/decoder/tree_fragment.cc
@@ -112,4 +112,12 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
   *psymp = symp;
 }
 
+BreadthFirstIterator TreeFragment::begin() const {
+  return BreadthFirstIterator(this);
+}
+
+BreadthFirstIterator TreeFragment::end() const {
+  return BreadthFirstIterator(this, 0);
+}
+
 }
-- 
cgit v1.2.3


From 8372086f2fc4bd765fdd05e8cf95faeb147a6587 Mon Sep 17 00:00:00 2001
From: Chris Dyer <redpony@gmail.com>
Date: Sun, 30 Mar 2014 23:50:17 -0400
Subject: almost complete tree to string translator

---
 decoder/Makefile.am               |   3 +
 decoder/decoder.cc                |   6 +-
 decoder/t2s_test.cc               | 110 ++++++++++++++++++++++++++++++++++
 decoder/tree2string_translator.cc | 120 +++++++++++++++++++++++++++++++++-----
 decoder/tree_fragment.cc          |  14 +++--
 decoder/tree_fragment.h           | 109 ++++++++++++++++++++++++----------
 6 files changed, 311 insertions(+), 51 deletions(-)
 create mode 100644 decoder/t2s_test.cc

(limited to 'decoder/tree_fragment.cc')

diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 7481192b..5c91fe65 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -4,9 +4,12 @@ noinst_PROGRAMS = \
   trule_test \
   hg_test \
   parser_test \
+  t2s_test \
   grammar_test
 
 TESTS = trule_test parser_test grammar_test hg_test
+t2s_test_SOURCES = t2s_test.cc
+t2s_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a
 parser_test_SOURCES = parser_test.cc
 parser_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a
 grammar_test_SOURCES = grammar_test.cc
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 31049216..43e2640d 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -490,8 +490,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
   }
 
   formalism = LowercaseString(str("formalism",conf));
-  if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") {
-    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n";
+  if (formalism != "t2s" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") {
+    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n";
     cerr << dcmdline_options << endl;
     exit(1);
   }
@@ -626,6 +626,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
   // set up translation back end
   if (formalism == "scfg")
     translator.reset(new SCFGTranslator(conf));
+  else if (formalism == "t2s")
+    translator.reset(new Tree2StringTranslator(conf));
   else if (formalism == "fst")
     translator.reset(new FSTTranslator(conf));
   else if (formalism == "pb")
diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc
new file mode 100644
index 00000000..3c46ea89
--- /dev/null
+++ b/decoder/t2s_test.cc
@@ -0,0 +1,110 @@
+#include "tree_fragment.h"
+
+#define BOOST_TEST_MODULE T2STest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+#include <iostream>
+#include "tdict.h"
+
+using namespace std;
+
+BOOST_AUTO_TEST_CASE(TestTreeFragments) {
+  cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))");
+  cdec::TreeFragment tree2("(S (NP (DT a) (NN cat)) (VP (V ate) (NP (DT the) (NN cake pie))))");
+  vector<unsigned> a, b;
+  vector<WordID> aw, bw;
+  cerr << "TREE1: " << tree << endl;
+  cerr << "TREE2: " << tree2 << endl;
+  for (auto& sym : tree)
+    if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym);
+  for (auto& sym : tree2)
+    if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym);
+  BOOST_CHECK_EQUAL(a.size(), b.size());
+  BOOST_CHECK_EQUAL(aw.size() + 1, bw.size());
+  BOOST_CHECK_EQUAL(aw.size(), 5);
+  BOOST_CHECK_EQUAL(TD::GetString(aw), "the boy saw a cat");
+  BOOST_CHECK_EQUAL(TD::GetString(bw), "a cat ate the cake pie");
+  if (a != b) {
+    BOOST_CHECK_EQUAL(1,2);
+  }
+
+  string nts;
+  for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) {
+    if (cdec::IsNT(*it)) {
+      if (cdec::IsRHS(*it)) it.truncate();
+      if (nts.size()) nts += " ";
+      if (cdec::IsLHS(*it)) nts += "(";
+      nts += TD::Convert(*it & cdec::ALL_MASK);
+      if (cdec::IsFrontier(*it)) nts += "*";
+    }
+  }
+  BOOST_CHECK_EQUAL(nts, "(S NP* VP*");
+
+  nts.clear();
+  int ntc = 0;
+  for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) {
+    if (cdec::IsNT(*it)) {
+      if (cdec::IsRHS(*it)) {
+        ++ntc;
+        if (ntc > 1) it.truncate();
+      }
+      if (nts.size()) nts += " ";
+      if (cdec::IsLHS(*it)) nts += "(";
+      nts += TD::Convert(*it & cdec::ALL_MASK);
+      if (cdec::IsFrontier(*it)) nts += "*";
+    }
+  }
+  BOOST_CHECK_EQUAL(nts, "(S NP VP* (NP DT* NN*");
+}
+
+BOOST_AUTO_TEST_CASE(TestSharing) {
+  cdec::TreeFragment rule1("(S [NP] [VP])", true);
+  cdec::TreeFragment rule2("(S [NP] (VP [V] [NP]))", true);
+  string r1,r2;
+  for (auto sym : rule1) {
+    if (r1.size()) r1 += " ";
+    if (cdec::IsLHS(sym)) r1 += "(";
+    r1 += TD::Convert(sym & cdec::ALL_MASK);
+    if (cdec::IsFrontier(sym)) r1 += "*";
+  }
+  for (auto sym : rule2) {
+    if (r2.size()) r2 += " ";
+    if (cdec::IsLHS(sym)) r2 += "(";
+    r2 += TD::Convert(sym & cdec::ALL_MASK);
+    if (cdec::IsFrontier(sym)) r2 += "*";
+  }
+  cerr << rule1 << endl;
+  cerr << r1 << endl;
+  cerr << rule2 << endl;
+  cerr << r2 << endl;
+  BOOST_CHECK_EQUAL(r1, "(S NP* VP*");
+  BOOST_CHECK_EQUAL(r2, "(S NP* VP (VP V* NP*");
+}
+
+BOOST_AUTO_TEST_CASE(TestEndInvariants) {
+  cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))");
+  BOOST_CHECK(tree.end().at_end());
+  BOOST_CHECK(!tree.begin().at_end());
+}
+
+BOOST_AUTO_TEST_CASE(TestBegins) {
+  cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))");
+  for (auto it = tree.begin(1); it != tree.end(); ++it) {
+    cerr << TD::Convert(*it & cdec::ALL_MASK) << endl;
+  }
+}
+
+BOOST_AUTO_TEST_CASE(TestRemainder) {
+  cdec::TreeFragment tree("(S (A a) (B b))");
+  auto it = tree.begin();
+  ++it;
+  BOOST_CHECK(cdec::IsRHS(*it));
+  cerr << tree << endl;
+  auto itr = it.remainder();
+  while(itr != tree.end()) {
+    cerr << TD::Convert(*itr & cdec::ALL_MASK) << endl;
+    ++itr;
+  }
+}
+
+
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index 1c249836..cd6ee550 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -1,5 +1,6 @@
 #include <algorithm>
 #include <vector>
+#include <queue>
 #include <boost/functional/hash.hpp>
 #include <unordered_map>
 #include "tree_fragment.h"
@@ -15,11 +16,10 @@ using namespace std;
 
 struct Tree2StringGrammarNode {
   map<unsigned, Tree2StringGrammarNode> next;
-  string rules;
+  vector<TRulePtr> rules;
 };
 
-void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGrammarNode>* proots) {
-  unordered_map<unsigned, Tree2StringGrammarNode>& roots = *proots;
+void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) {
   string line;
   while(getline(*in, line)) {
     size_t pos = line.find("|||");
@@ -28,32 +28,124 @@ void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGram
     unsigned xc = 0;
     while (line[pos - 1] == ' ') { --pos; xc++; }
     cdec::TreeFragment rule_src(line.substr(0, pos), true);
-    Tree2StringGrammarNode* cur = &roots[rule_src.root];
-    for (auto sym : rule_src)
+    Tree2StringGrammarNode* cur = root;
+    ostringstream os;
+    int lhs = -(rule_src.root & cdec::ALL_MASK);
+    // build source RHS for SCFG projection
+    // TODO - this is buggy - it will generate a well-formed SCFG rule
+    // but it will not generate source strings correctly
+    vector<int> frhs;
+    for (auto sym : rule_src) {
       cur = &cur->next[sym];
+      if (sym) {
+        if (cdec::IsFrontier(sym)) {  // frontier symbols -> variables
+          int nt = (sym & cdec::ALL_MASK);
+          frhs.push_back(-nt);
+        } else if (cdec::IsTerminal(sym)) {
+          frhs.push_back(sym);
+        }
+      }
+    }
+    os << '[' << TD::Convert(-lhs) << "] |||";
+    for (auto x : frhs) {
+      os << ' ';
+      if (x < 0)
+        os << '[' << TD::Convert(-x) << ']';
+      else
+        os << TD::Convert(x);
+    }
     pos += 3 + xc;
     while(line[pos] == ' ') { ++pos; }
-    size_t pos2 = line.find("|||", pos);
-    assert(pos2 != string::npos);
-    while (line[pos2 - 1] == ' ') { --pos2; }
-    cur->rules = line.substr(pos, pos2 - pos);
-    cerr << "OUTPUT = '" << cur->rules << "'\n";
+    os << " ||| " << line.substr(pos);
+    TRulePtr rule(new TRule(os.str()));
+    cur->rules.push_back(rule);
   }
 }
 
+struct ParserState {
+  ParserState() : in_iter(), node() {}
+  cdec::TreeFragment::iterator in_iter;
+  ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) :
+      in_iter(it),
+      root_type(rt),
+      node(n) {}
+  ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) :
+      in_iter(it),
+      future_work(p.future_work),
+      root_type(p.root_type),
+      node(n) {}
+  vector<ParserState> future_work;
+  int root_type; // lhs of top level NT
+  Tree2StringGrammarNode* node;
+};
+
 struct Tree2StringTranslatorImpl {
-  unordered_map<unsigned, Tree2StringGrammarNode> roots; // root['S'] gives rule network for S rules
+  Tree2StringGrammarNode root;
   Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) {
     ReadFile rf(conf["grammar"].as<vector<string>>()[0]);
-    ReadTree2StringGrammar(rf.stream(), &roots);
+    ReadTree2StringGrammar(rf.stream(), &root);
   }
   bool Translate(const string& input,
                  SentenceMetadata* smeta,
                  const vector<double>& weights,
                  Hypergraph* minus_lm_forest) {
     cdec::TreeFragment input_tree(input, false);
-    cerr << "Tree2StringTranslatorImpl: please implement this!\n";
-    return false;
+    const int kS = -TD::Convert("S");
+    Hypergraph hg;
+    queue<ParserState> q;
+    q.push(ParserState(input_tree.begin(), &root, kS));
+    while(!q.empty()) {
+      ParserState& s = q.front();
+
+      if (s.in_iter.at_end()) { // completed a traversal of a subtree
+        cerr << "I traversed a subtree of the input...\n";
+        if (s.node->rules.size()) {
+          // TODO: build hypergraph
+          for (auto& r : s.node->rules)
+            cerr << "I can build: " << r->AsString() << endl;
+          for (auto& w : s.future_work)
+            q.push(w);
+        } else {
+          cerr << "I can't build anything :(\n";
+        }
+      } else { // more input tree to match
+        unsigned sym = *s.in_iter;
+        if (cdec::IsLHS(sym)) {
+          auto nit = s.node->next.find(sym);
+          if (nit != s.node->next.end()) {
+            //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
+            q.push(ParserState(++s.in_iter, &nit->second, s));
+          }
+        } else if (cdec::IsRHS(sym)) {
+          //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
+          cdec::TreeFragment::iterator var = s.in_iter;
+          var.truncate();
+          auto nit1 = s.node->next.find(sym);
+          auto nit2 = s.node->next.find(*var);
+          if (nit2 != s.node->next.end()) {
+            //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
+            ParserState new_s(++var, &nit2->second, s);
+            ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK));
+            new_s.future_work.push_back(new_work);  // if this traversal of the input succeeds, future_work goes on the q
+            q.push(new_s);
+          }
+          if (nit1 != s.node->next.end()) {
+            //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
+            q.push(ParserState(++s.in_iter, &nit1->second, s));
+          }
+        } else if (cdec::IsTerminal(sym)) {
+          auto nit = s.node->next.find(sym);
+          if (nit != s.node->next.end()) {
+            //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl;
+            q.push(ParserState(++s.in_iter, &nit->second, s));
+          }
+        } else {
+          cerr << "This can never happen!\n"; abort();
+        }
+      }
+      q.pop();
+    }
+    minus_lm_forest->swap(hg);
   }
 };
 
diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
index 93aad64e..78a993b8 100644
--- a/decoder/tree_fragment.cc
+++ b/decoder/tree_fragment.cc
@@ -36,7 +36,7 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const {
     *out << ' ';
     if (IsFrontier(x)) {
       *out << '[' << TD::Convert(x & ALL_MASK) << ']';
-    } else if (IsInternalNT(x)) {
+    } else if (IsRHS(x)) {
       DebugRec(x & ALL_MASK, out);
     } else { // must be terminal
       *out << TD::Convert(x);
@@ -66,7 +66,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
       // recursively call parser to deal with constituent
       ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np);
       unsigned ind = np - 1;
-      rhs.push_back(ind | NT_BIT);
+      rhs.push_back(ind | RHS_BIT);
     } else { // deal with terminal / nonterminal substitution
       ++symp;
       assert(tree[cp] != ' ');
@@ -95,7 +95,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
   } // continuent has completed, cp is at ), build node
   const unsigned j = symp; // span from (i,j)
   // add an internal non-terminal symbol
-  const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT;
+  const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT;
   nodes[np] = TreeFragmentProduction(nt, rhs);
   //cerr << np << " production(" << i << "," << j << ")=  " << TD::Convert(nt & ALL_MASK) << " -->";
   //for (auto& x : rhs) {
@@ -113,11 +113,15 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
 }
 
 BreadthFirstIterator TreeFragment::begin() const {
-  return BreadthFirstIterator(this);
+  return BreadthFirstIterator(this, nodes.size() - 1);
+}
+
+BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const {
+  return BreadthFirstIterator(this, node_idx);
 }
 
 BreadthFirstIterator TreeFragment::end() const {
-  return BreadthFirstIterator(this, 0);
+  return BreadthFirstIterator(this);
 }
 
 }
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
index a38dbdfa..b83afc27 100644
--- a/decoder/tree_fragment.h
+++ b/decoder/tree_fragment.h
@@ -1,7 +1,7 @@
 #ifndef TREE_FRAGMENT
 #define TREE_FRAGMENT
 
-#include <queue>
+#include <deque>
 #include <iostream>
 #include <vector>
 #include <string>
@@ -12,18 +12,32 @@ namespace cdec {
 
 class BreadthFirstIterator;
 
-static const unsigned NT_BIT       = 0x40000000u;
-static const unsigned FRONTIER_BIT = 0x80000000u;
-static const unsigned ALL_MASK     = 0x0FFFFFFFu;
+static const unsigned LHS_BIT         = 0x10000000u;
+static const unsigned RHS_BIT         = 0x20000000u;
+static const unsigned FRONTIER_BIT    = 0x40000000u;
+static const unsigned RESERVED_BIT    = 0x80000000u;
+static const unsigned ALL_MASK        = 0x0FFFFFFFu;
 
-inline bool IsInternalNT(unsigned x) {
-  return (x & NT_BIT);
+inline bool IsNT(unsigned x) {
+  return (x & (LHS_BIT | RHS_BIT | FRONTIER_BIT));
+}
+
+inline bool IsLHS(unsigned x) {
+  return (x & LHS_BIT);
+}
+
+inline bool IsRHS(unsigned x) {
+  return (x & RHS_BIT);
 }
 
 inline bool IsFrontier(unsigned x) {
   return (x & FRONTIER_BIT);
 }
 
+inline bool IsTerminal(unsigned x) {
+  return (x & ALL_MASK) == x;
+}
+
 struct TreeFragmentProduction {
   TreeFragmentProduction() {}
   TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r) {}
@@ -46,6 +60,7 @@ class TreeFragment {
   typedef const unsigned & reference;
 
   iterator begin() const;
+  iterator begin(unsigned node_idx) const;
   iterator end() const;
 
  private:
@@ -62,24 +77,28 @@ class TreeFragment {
 };
 
 struct TFIState {
-  TFIState() : node(), rhspos() {}
-  TFIState(unsigned n, unsigned p) : node(n), rhspos(p) {}
-  bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos; }
-  bool operator!=(const TFIState& o) const { return node != o.node && rhspos != o.rhspos; }
+  TFIState() : node(), rhspos(), state() {}
+  TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {}
+  bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; }
+  bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; }
   unsigned short node;
   unsigned short rhspos;
+  unsigned char state;
 };
 
 class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, unsigned> {
   const TreeFragment* tf_;
-  std::queue<TFIState> q_;
+  std::deque<TFIState> q_;
   unsigned sym;
  public:
-  explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {
-    q_.push(TFIState(tf->nodes.size() - 1, 0));
+  BreadthFirstIterator() : tf_(), sym() {}
+  // used for begin
+  explicit BreadthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) {
+    q_.push_back(TFIState(node_idx, 0, 0));
     Stage();
   }
-  BreadthFirstIterator(const TreeFragment* tf, int) : tf_(tf) {}
+  // used for end
+  explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {}
   const unsigned& operator*() const { return sym; }
   const unsigned* operator->() const { return &sym; }
   bool operator==(const BreadthFirstIterator& other) const {
@@ -88,26 +107,20 @@ class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, uns
   bool operator!=(const BreadthFirstIterator& other) const {
     return (tf_ != other.tf_) || (q_ != other.q_);
   }
-  void Stage() {
-    if (q_.empty()) return;
-    const TFIState& s = q_.front();
-    sym = tf_->nodes[s.node].rhs[s.rhspos];
-    if (IsInternalNT(sym)) {
-      q_.push(TFIState(sym & ALL_MASK, 0));
-      sym = tf_->nodes[sym & ALL_MASK].lhs;
-    }
-  }
   const BreadthFirstIterator& operator++() {
     TFIState& s = q_.front();
-    const unsigned len = tf_->nodes[s.node].rhs.size();
-    s.rhspos++;
-    if (s.rhspos > len) {
-      q_.pop();
+    if (s.state == 0) {
+      s.state++;
       Stage();
-    } else if (s.rhspos == len) {
-      sym = 0;
     } else {
-      Stage();
+      const unsigned len = tf_->nodes[s.node].rhs.size();
+      s.rhspos++;
+      if (s.rhspos >= len) {
+        q_.pop_front();
+        Stage();
+      } else {
+        Stage();
+      }
     }
     return *this;
   }
@@ -116,6 +129,42 @@ class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, uns
     ++(*this);
     return res;
   }
+  // tell iterator not to explore the subtree rooted at sym
+  // should only be called once per NT symbol encountered
+  const BreadthFirstIterator& truncate() {
+    assert(IsRHS(sym));
+    sym &= ALL_MASK;
+    sym |= FRONTIER_BIT;
+    q_.pop_back();
+    return *this;
+  }
+  BreadthFirstIterator remainder() const {
+    assert(IsRHS(sym));
+    return BreadthFirstIterator(tf_, q_.back());
+  }
+  bool at_end() const {
+    return q_.empty();
+  }
+ private:
+  void Stage() {
+    if (q_.empty()) return;
+    const TFIState& s = q_.front();
+    if (s.state == 0) {
+      sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT;
+    } else {
+      sym = tf_->nodes[s.node].rhs[s.rhspos];
+      if (IsRHS(sym)) {
+        q_.push_back(TFIState(sym & ALL_MASK, 0, 0));
+        sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT;
+      }
+    }
+  }
+
+  // used by remainder
+  BreadthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) {
+    q_.push_back(s);
+    Stage();
+  }
 };
 
 inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) {
-- 
cgit v1.2.3


From 649b5ffc7c81182ba39d338b11bfe2e9a05544b5 Mon Sep 17 00:00:00 2001
From: Chris Dyer <redpony@gmail.com>
Date: Wed, 16 Apr 2014 00:36:30 -0400
Subject: fix for bug due to using wrong tree traversal

---
 decoder/t2s_test.cc               |   8 ++-
 decoder/tree2string_translator.cc |  18 +++---
 decoder/tree_fragment.cc          |  12 ----
 decoder/tree_fragment.h           | 112 +++++++++++++++++++++++++++++++++++++-
 4 files changed, 125 insertions(+), 25 deletions(-)

(limited to 'decoder/tree_fragment.cc')

diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc
index 3c46ea89..5ebb2662 100644
--- a/decoder/t2s_test.cc
+++ b/decoder/t2s_test.cc
@@ -15,8 +15,11 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) {
   vector<WordID> aw, bw;
   cerr << "TREE1: " << tree << endl;
   cerr << "TREE2: " << tree2 << endl;
-  for (auto& sym : tree)
+  for (auto& sym : tree) {
+    if (cdec::IsLHS(sym)) cerr << "(";
+    cerr << TD::Convert(sym & cdec::ALL_MASK) << endl;
     if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym);
+  }
   for (auto& sym : tree2)
     if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym);
   BOOST_CHECK_EQUAL(a.size(), b.size());
@@ -38,11 +41,12 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) {
       if (cdec::IsFrontier(*it)) nts += "*";
     }
   }
+  cerr << "Truncated: " << nts << endl;
   BOOST_CHECK_EQUAL(nts, "(S NP* VP*");
 
   nts.clear();
   int ntc = 0;
-  for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) {
+  for (auto it = tree.bfs_begin(); it != tree.bfs_end(); ++it) {
     if (cdec::IsNT(*it)) {
       if (cdec::IsRHS(*it)) {
         ++ntc;
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index 8d12d01d..3fbf1ee5 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -38,14 +38,13 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) {
     // but it will not generate source strings correctly
     vector<int> frhs;
     for (auto sym : rule_src) {
+      //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl;
       cur = &cur->next[sym];
-      if (sym) {
-        if (cdec::IsFrontier(sym)) {  // frontier symbols -> variables
-          int nt = (sym & cdec::ALL_MASK);
-          frhs.push_back(-nt);
-        } else if (cdec::IsTerminal(sym)) {
-          frhs.push_back(sym);
-        }
+      if (cdec::IsFrontier(sym)) {  // frontier symbols -> variables
+        int nt = (sym & cdec::ALL_MASK);
+        frhs.push_back(-nt);
+      } else if (cdec::IsTerminal(sym)) {
+        frhs.push_back(sym);
       }
     }
     os << '[' << TD::Convert(-lhs) << "] |||";
@@ -61,6 +60,7 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) {
     os << " ||| " << line.substr(pos);
     TRulePtr rule(new TRule(os.str()));
     cur->rules.push_back(rule);
+    //cerr << "RULE: " << rule->AsString() << "\n\n";
   }
 }
 
@@ -82,7 +82,7 @@ struct ParserState {
   }
   vector<unsigned> future_work;
   int input_node_idx; // lhs of top level NT
-  Tree2StringGrammarNode* node;
+  Tree2StringGrammarNode* node; // pointer into grammar
 };
 
 namespace std {
@@ -239,11 +239,13 @@ struct Tree2StringTranslatorImpl {
             new_s.future_work.push_back(new_work);  // if this traversal of the input succeeds, future_work goes on the q
             if (unique.insert(new_s).second) q.push(new_s);
           }
+          //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; }
           if (nit1 != s.node->next.end()) {
             //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
             const ParserState new_s(++s.in_iter, &nit1->second, s);
             if (unique.insert(new_s).second) q.push(new_s);
           }
+          //else { cerr << "did not match " << TD::Convert(sym & cdec::ALL_MASK) << "\n"; }
         } else if (cdec::IsTerminal(sym)) {
           auto nit = s.node->next.find(sym);
           if (nit != s.node->next.end()) {
diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
index 78a993b8..4d429f42 100644
--- a/decoder/tree_fragment.cc
+++ b/decoder/tree_fragment.cc
@@ -112,16 +112,4 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
   *psymp = symp;
 }
 
-BreadthFirstIterator TreeFragment::begin() const {
-  return BreadthFirstIterator(this, nodes.size() - 1);
-}
-
-BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const {
-  return BreadthFirstIterator(this, node_idx);
-}
-
-BreadthFirstIterator TreeFragment::end() const {
-  return BreadthFirstIterator(this);
-}
-
 }
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
index f1c4c106..4a704bc4 100644
--- a/decoder/tree_fragment.h
+++ b/decoder/tree_fragment.h
@@ -11,6 +11,7 @@
 namespace cdec {
 
 class BreadthFirstIterator;
+class DepthFirstIterator;
 
 static const unsigned LHS_BIT         = 0x10000000u;
 static const unsigned RHS_BIT         = 0x20000000u;
@@ -53,16 +54,21 @@ class TreeFragment {
   // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar))))
   explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false);
   void DebugRec(unsigned cur, std::ostream* out) const;
-  typedef BreadthFirstIterator iterator;
+  typedef DepthFirstIterator iterator;
   typedef ptrdiff_t difference_type;
   typedef unsigned value_type;
   typedef const unsigned * pointer;
   typedef const unsigned & reference;
 
+  // default iterator is DFS
   iterator begin() const;
   iterator begin(unsigned node_idx) const;
   iterator end() const;
 
+  BreadthFirstIterator bfs_begin() const;
+  BreadthFirstIterator bfs_begin(unsigned node_idx) const;
+  BreadthFirstIterator bfs_end() const;
+
  private:
   // cp is the character index in the tree
   // np keeps track of the nodes (nonterminals) that have been built
@@ -78,14 +84,106 @@ class TreeFragment {
 
 struct TFIState {
   TFIState() : node(), rhspos(), state() {}
-  TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {}
+  TFIState(unsigned n, int p, unsigned s) : node(n), rhspos(p), state(s) {}
   bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; }
   bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; }
   unsigned short node;
-  unsigned short rhspos;
+  short rhspos;
   unsigned char state;
 };
 
+class DepthFirstIterator : public std::iterator<std::forward_iterator_tag, unsigned> {
+  const TreeFragment* tf_;
+  std::deque<TFIState> q_;
+  unsigned sym;
+ public:
+  DepthFirstIterator() : tf_(), sym() {}
+  // used for begin
+  explicit DepthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) {
+    q_.push_back(TFIState(node_idx, -1, 0));
+    Stage();
+    q_.back().state++;
+  }
+  // used for end
+  explicit DepthFirstIterator(const TreeFragment* tf) : tf_(tf) {}
+  const unsigned& operator*() const { return sym; }
+  const unsigned* operator->() const { return &sym; }
+  bool operator==(const DepthFirstIterator& other) const {
+    return (tf_ == other.tf_) && (q_ == other.q_);
+  }
+  bool operator!=(const DepthFirstIterator& other) const {
+    return (tf_ != other.tf_) || (q_ != other.q_);
+  }
+  unsigned node_idx() const { return q_.front().node; }
+  const DepthFirstIterator& operator++() {
+    TFIState& s = q_.back();
+    if (s.state == 0) {
+      Stage();
+      s.state++;
+    } else if (s.state == 1) {
+      const unsigned len = tf_->nodes[s.node].rhs.size();
+      s.rhspos++;
+      if (s.rhspos >= len) {
+        q_.pop_back();
+        while (!q_.empty()) {
+          TFIState& s = q_.back();
+          const unsigned len = tf_->nodes[s.node].rhs.size();
+          s.rhspos++;
+          if (s.rhspos < len) break;
+          q_.pop_back();
+        }
+      }
+      Stage();
+    }
+    return *this;
+  }
+  DepthFirstIterator operator++(int) {
+    DepthFirstIterator res = *this;
+    ++(*this);
+    return res;
+  }
+  // tell iterator not to explore the subtree rooted at sym
+  // should only be called once per NT symbol encountered
+  const DepthFirstIterator& truncate() {
+    assert(IsRHS(sym));
+    sym &= ALL_MASK;
+    sym |= FRONTIER_BIT;
+    q_.pop_back();
+    return *this;
+  }
+  unsigned child_node() const {
+    assert(IsRHS(sym));
+    return q_.back().node;
+  }
+  DepthFirstIterator remainder() const {
+    assert(IsRHS(sym));
+    return DepthFirstIterator(tf_, q_.back());
+  }
+  bool at_end() const {
+    return q_.empty();
+  }
+ private:
+  void Stage() {
+    if (q_.empty()) return;
+    const TFIState& s = q_.back();
+    if (s.state == 0) {
+      sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT;
+    } else if (s.state == 1) {
+      sym = tf_->nodes[s.node].rhs[s.rhspos];
+      if (IsRHS(sym)) {
+        q_.push_back(TFIState(sym & ALL_MASK, -1, 0));
+        sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT;
+      }
+    }
+  }
+
+  // used by remainder
+  DepthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) {
+    q_.push_back(s);
+    Stage();
+  }
+};
+
 class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, unsigned> {
   const TreeFragment* tf_;
   std::deque<TFIState> q_;
@@ -172,6 +270,14 @@ class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, uns
   }
 };
 
+inline TreeFragment::iterator TreeFragment::begin() const { return iterator(this, nodes.size() - 1); }
+inline TreeFragment::iterator TreeFragment::begin(unsigned node_idx) const { return iterator(this, node_idx); }
+inline TreeFragment::iterator TreeFragment::end() const { return iterator(this); }
+
+inline BreadthFirstIterator TreeFragment::bfs_begin() const { return BreadthFirstIterator(this, nodes.size() - 1); }
+inline BreadthFirstIterator TreeFragment::bfs_begin(unsigned node_idx) const { return BreadthFirstIterator(this, node_idx); }
+inline BreadthFirstIterator TreeFragment::bfs_end() const { return BreadthFirstIterator(this); }
+
 inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) {
   x.DebugRec(x.nodes.size() - 1, &os);
   return os;
-- 
cgit v1.2.3


From aa9d5d402c01e45835878c02777442950a0f6c0a Mon Sep 17 00:00:00 2001
From: Chris Dyer <redpony@gmail.com>
Date: Sun, 27 Apr 2014 21:05:33 +0200
Subject: clean up headers

---
 decoder/tree_fragment.cc | 2 ++
 decoder/tree_fragment.h  | 3 +--
 2 files changed, 3 insertions(+), 2 deletions(-)

(limited to 'decoder/tree_fragment.cc')

diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
index 4d429f42..696c8601 100644
--- a/decoder/tree_fragment.cc
+++ b/decoder/tree_fragment.cc
@@ -2,6 +2,8 @@
 
 #include <cassert>
 
+#include "tdict.h"
+
 using namespace std;
 
 namespace cdec {
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
index 4a704bc4..f9dfa8cc 100644
--- a/decoder/tree_fragment.h
+++ b/decoder/tree_fragment.h
@@ -5,8 +5,7 @@
 #include <iostream>
 #include <vector>
 #include <string>
-
-#include "tdict.h"
+#include <cassert>
 
 namespace cdec {
 
-- 
cgit v1.2.3