diff options
| author | Chris Dyer <redpony@gmail.com> | 2014-03-30 23:50:17 -0400 | 
|---|---|---|
| committer | Chris Dyer <redpony@gmail.com> | 2014-03-30 23:50:17 -0400 | 
| commit | 8372086f2fc4bd765fdd05e8cf95faeb147a6587 (patch) | |
| tree | fa4ac0342bc1259ce96c61fa9fffb5f8252d0333 /decoder | |
| parent | ca29417acd47dbbd2aa68cd31fcd3129e6482bf7 (diff) | |
almost complete tree to string translator
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/Makefile.am | 3 | ||||
| -rw-r--r-- | decoder/decoder.cc | 6 | ||||
| -rw-r--r-- | decoder/t2s_test.cc | 110 | ||||
| -rw-r--r-- | decoder/tree2string_translator.cc | 120 | ||||
| -rw-r--r-- | decoder/tree_fragment.cc | 14 | ||||
| -rw-r--r-- | decoder/tree_fragment.h | 109 | 
6 files changed, 311 insertions, 51 deletions
| diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 7481192b..5c91fe65 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -4,9 +4,12 @@ noinst_PROGRAMS = \    trule_test \    hg_test \    parser_test \ +  t2s_test \    grammar_test  TESTS = trule_test parser_test grammar_test hg_test +t2s_test_SOURCES = t2s_test.cc +t2s_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a  parser_test_SOURCES = parser_test.cc  parser_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a  grammar_test_SOURCES = grammar_test.cc diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 31049216..43e2640d 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -490,8 +490,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    }    formalism = LowercaseString(str("formalism",conf)); -  if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { -    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; +  if (formalism != "t2s" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { +    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n";      cerr << dcmdline_options << endl;      exit(1);    } @@ -626,6 +626,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    // set up translation back end    if (formalism == "scfg")      translator.reset(new SCFGTranslator(conf)); +  else if (formalism == "t2s") +    translator.reset(new Tree2StringTranslator(conf));    else if (formalism == "fst")      translator.reset(new FSTTranslator(conf));    else if (formalism == "pb") diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc new file mode 100644 index 00000000..3c46ea89 --- /dev/null +++ b/decoder/t2s_test.cc @@ -0,0 +1,110 @@ +#include "tree_fragment.h" + +#define BOOST_TEST_MODULE T2STest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> +#include <iostream> +#include "tdict.h" + +using namespace std; + +BOOST_AUTO_TEST_CASE(TestTreeFragments) { +  cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); +  cdec::TreeFragment tree2("(S (NP (DT a) (NN cat)) (VP (V ate) (NP (DT the) (NN cake pie))))"); +  vector<unsigned> a, b; +  vector<WordID> aw, bw; +  cerr << "TREE1: " << tree << endl; +  cerr << "TREE2: " << tree2 << endl; +  for (auto& sym : tree) +    if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym); +  for (auto& sym : tree2) +    if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym); +  BOOST_CHECK_EQUAL(a.size(), b.size()); +  BOOST_CHECK_EQUAL(aw.size() + 1, bw.size()); +  BOOST_CHECK_EQUAL(aw.size(), 5); +  BOOST_CHECK_EQUAL(TD::GetString(aw), "the boy saw a cat"); +  BOOST_CHECK_EQUAL(TD::GetString(bw), "a cat ate the cake pie"); +  if (a != b) { +    BOOST_CHECK_EQUAL(1,2); +  } + +  string nts; +  for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { +    if (cdec::IsNT(*it)) { +      if (cdec::IsRHS(*it)) it.truncate(); +      if (nts.size()) nts += " "; +      if (cdec::IsLHS(*it)) nts += "("; +      nts += TD::Convert(*it & cdec::ALL_MASK); +      if (cdec::IsFrontier(*it)) nts += "*"; +    } +  } +  BOOST_CHECK_EQUAL(nts, "(S NP* VP*"); + +  nts.clear(); +  int ntc = 0; +  for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { +    if (cdec::IsNT(*it)) { +      if (cdec::IsRHS(*it)) { +        ++ntc; +        if (ntc > 1) it.truncate(); +      } +      if (nts.size()) nts += " "; +      if (cdec::IsLHS(*it)) nts += "("; +      nts += TD::Convert(*it & cdec::ALL_MASK); +      if (cdec::IsFrontier(*it)) nts += "*"; +    } +  } +  BOOST_CHECK_EQUAL(nts, "(S NP VP* (NP DT* NN*"); +} + +BOOST_AUTO_TEST_CASE(TestSharing) { +  cdec::TreeFragment rule1("(S [NP] [VP])", true); +  cdec::TreeFragment rule2("(S [NP] (VP [V] [NP]))", true); +  string r1,r2; +  for (auto sym : rule1) { +    if (r1.size()) r1 += " "; +    if (cdec::IsLHS(sym)) r1 += "("; +    r1 += TD::Convert(sym & cdec::ALL_MASK); +    if (cdec::IsFrontier(sym)) r1 += "*"; +  } +  for (auto sym : rule2) { +    if (r2.size()) r2 += " "; +    if (cdec::IsLHS(sym)) r2 += "("; +    r2 += TD::Convert(sym & cdec::ALL_MASK); +    if (cdec::IsFrontier(sym)) r2 += "*"; +  } +  cerr << rule1 << endl; +  cerr << r1 << endl; +  cerr << rule2 << endl; +  cerr << r2 << endl; +  BOOST_CHECK_EQUAL(r1, "(S NP* VP*"); +  BOOST_CHECK_EQUAL(r2, "(S NP* VP (VP V* NP*"); +} + +BOOST_AUTO_TEST_CASE(TestEndInvariants) { +  cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); +  BOOST_CHECK(tree.end().at_end()); +  BOOST_CHECK(!tree.begin().at_end()); +} + +BOOST_AUTO_TEST_CASE(TestBegins) { +  cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); +  for (auto it = tree.begin(1); it != tree.end(); ++it) { +    cerr << TD::Convert(*it & cdec::ALL_MASK) << endl; +  } +} + +BOOST_AUTO_TEST_CASE(TestRemainder) { +  cdec::TreeFragment tree("(S (A a) (B b))"); +  auto it = tree.begin(); +  ++it; +  BOOST_CHECK(cdec::IsRHS(*it)); +  cerr << tree << endl; +  auto itr = it.remainder(); +  while(itr != tree.end()) { +    cerr << TD::Convert(*itr & cdec::ALL_MASK) << endl; +    ++itr; +  } +} + + diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 1c249836..cd6ee550 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,5 +1,6 @@  #include <algorithm>  #include <vector> +#include <queue>  #include <boost/functional/hash.hpp>  #include <unordered_map>  #include "tree_fragment.h" @@ -15,11 +16,10 @@ using namespace std;  struct Tree2StringGrammarNode {    map<unsigned, Tree2StringGrammarNode> next; -  string rules; +  vector<TRulePtr> rules;  }; -void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGrammarNode>* proots) { -  unordered_map<unsigned, Tree2StringGrammarNode>& roots = *proots; +void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) {    string line;    while(getline(*in, line)) {      size_t pos = line.find("|||"); @@ -28,32 +28,124 @@ void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGram      unsigned xc = 0;      while (line[pos - 1] == ' ') { --pos; xc++; }      cdec::TreeFragment rule_src(line.substr(0, pos), true); -    Tree2StringGrammarNode* cur = &roots[rule_src.root]; -    for (auto sym : rule_src) +    Tree2StringGrammarNode* cur = root; +    ostringstream os; +    int lhs = -(rule_src.root & cdec::ALL_MASK); +    // build source RHS for SCFG projection +    // TODO - this is buggy - it will generate a well-formed SCFG rule +    // but it will not generate source strings correctly +    vector<int> frhs; +    for (auto sym : rule_src) {        cur = &cur->next[sym]; +      if (sym) { +        if (cdec::IsFrontier(sym)) {  // frontier symbols -> variables +          int nt = (sym & cdec::ALL_MASK); +          frhs.push_back(-nt); +        } else if (cdec::IsTerminal(sym)) { +          frhs.push_back(sym); +        } +      } +    } +    os << '[' << TD::Convert(-lhs) << "] |||"; +    for (auto x : frhs) { +      os << ' '; +      if (x < 0) +        os << '[' << TD::Convert(-x) << ']'; +      else +        os << TD::Convert(x); +    }      pos += 3 + xc;      while(line[pos] == ' ') { ++pos; } -    size_t pos2 = line.find("|||", pos); -    assert(pos2 != string::npos); -    while (line[pos2 - 1] == ' ') { --pos2; } -    cur->rules = line.substr(pos, pos2 - pos); -    cerr << "OUTPUT = '" << cur->rules << "'\n"; +    os << " ||| " << line.substr(pos); +    TRulePtr rule(new TRule(os.str())); +    cur->rules.push_back(rule);    }  } +struct ParserState { +  ParserState() : in_iter(), node() {} +  cdec::TreeFragment::iterator in_iter; +  ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) : +      in_iter(it), +      root_type(rt), +      node(n) {} +  ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : +      in_iter(it), +      future_work(p.future_work), +      root_type(p.root_type), +      node(n) {} +  vector<ParserState> future_work; +  int root_type; // lhs of top level NT +  Tree2StringGrammarNode* node; +}; +  struct Tree2StringTranslatorImpl { -  unordered_map<unsigned, Tree2StringGrammarNode> roots; // root['S'] gives rule network for S rules +  Tree2StringGrammarNode root;    Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) {      ReadFile rf(conf["grammar"].as<vector<string>>()[0]); -    ReadTree2StringGrammar(rf.stream(), &roots); +    ReadTree2StringGrammar(rf.stream(), &root);    }    bool Translate(const string& input,                   SentenceMetadata* smeta,                   const vector<double>& weights,                   Hypergraph* minus_lm_forest) {      cdec::TreeFragment input_tree(input, false); -    cerr << "Tree2StringTranslatorImpl: please implement this!\n"; -    return false; +    const int kS = -TD::Convert("S"); +    Hypergraph hg; +    queue<ParserState> q; +    q.push(ParserState(input_tree.begin(), &root, kS)); +    while(!q.empty()) { +      ParserState& s = q.front(); + +      if (s.in_iter.at_end()) { // completed a traversal of a subtree +        cerr << "I traversed a subtree of the input...\n"; +        if (s.node->rules.size()) { +          // TODO: build hypergraph +          for (auto& r : s.node->rules) +            cerr << "I can build: " << r->AsString() << endl; +          for (auto& w : s.future_work) +            q.push(w); +        } else { +          cerr << "I can't build anything :(\n"; +        } +      } else { // more input tree to match +        unsigned sym = *s.in_iter; +        if (cdec::IsLHS(sym)) { +          auto nit = s.node->next.find(sym); +          if (nit != s.node->next.end()) { +            //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; +            q.push(ParserState(++s.in_iter, &nit->second, s)); +          } +        } else if (cdec::IsRHS(sym)) { +          //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; +          cdec::TreeFragment::iterator var = s.in_iter; +          var.truncate(); +          auto nit1 = s.node->next.find(sym); +          auto nit2 = s.node->next.find(*var); +          if (nit2 != s.node->next.end()) { +            //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; +            ParserState new_s(++var, &nit2->second, s); +            ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK)); +            new_s.future_work.push_back(new_work);  // if this traversal of the input succeeds, future_work goes on the q +            q.push(new_s); +          } +          if (nit1 != s.node->next.end()) { +            //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; +            q.push(ParserState(++s.in_iter, &nit1->second, s)); +          } +        } else if (cdec::IsTerminal(sym)) { +          auto nit = s.node->next.find(sym); +          if (nit != s.node->next.end()) { +            //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; +            q.push(ParserState(++s.in_iter, &nit->second, s)); +          } +        } else { +          cerr << "This can never happen!\n"; abort(); +        } +      } +      q.pop(); +    } +    minus_lm_forest->swap(hg);    }  }; diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 93aad64e..78a993b8 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -36,7 +36,7 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const {      *out << ' ';      if (IsFrontier(x)) {        *out << '[' << TD::Convert(x & ALL_MASK) << ']'; -    } else if (IsInternalNT(x)) { +    } else if (IsRHS(x)) {        DebugRec(x & ALL_MASK, out);      } else { // must be terminal        *out << TD::Convert(x); @@ -66,7 +66,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned        // recursively call parser to deal with constituent        ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np);        unsigned ind = np - 1; -      rhs.push_back(ind | NT_BIT); +      rhs.push_back(ind | RHS_BIT);      } else { // deal with terminal / nonterminal substitution        ++symp;        assert(tree[cp] != ' '); @@ -95,7 +95,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned    } // continuent has completed, cp is at ), build node    const unsigned j = symp; // span from (i,j)    // add an internal non-terminal symbol -  const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT; +  const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT;    nodes[np] = TreeFragmentProduction(nt, rhs);    //cerr << np << " production(" << i << "," << j << ")=  " << TD::Convert(nt & ALL_MASK) << " -->";    //for (auto& x : rhs) { @@ -113,11 +113,15 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned  }  BreadthFirstIterator TreeFragment::begin() const { -  return BreadthFirstIterator(this); +  return BreadthFirstIterator(this, nodes.size() - 1); +} + +BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const { +  return BreadthFirstIterator(this, node_idx);  }  BreadthFirstIterator TreeFragment::end() const { -  return BreadthFirstIterator(this, 0); +  return BreadthFirstIterator(this);  }  } diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index a38dbdfa..b83afc27 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -1,7 +1,7 @@  #ifndef TREE_FRAGMENT  #define TREE_FRAGMENT -#include <queue> +#include <deque>  #include <iostream>  #include <vector>  #include <string> @@ -12,18 +12,32 @@ namespace cdec {  class BreadthFirstIterator; -static const unsigned NT_BIT       = 0x40000000u; -static const unsigned FRONTIER_BIT = 0x80000000u; -static const unsigned ALL_MASK     = 0x0FFFFFFFu; +static const unsigned LHS_BIT         = 0x10000000u; +static const unsigned RHS_BIT         = 0x20000000u; +static const unsigned FRONTIER_BIT    = 0x40000000u; +static const unsigned RESERVED_BIT    = 0x80000000u; +static const unsigned ALL_MASK        = 0x0FFFFFFFu; -inline bool IsInternalNT(unsigned x) { -  return (x & NT_BIT); +inline bool IsNT(unsigned x) { +  return (x & (LHS_BIT | RHS_BIT | FRONTIER_BIT)); +} + +inline bool IsLHS(unsigned x) { +  return (x & LHS_BIT); +} + +inline bool IsRHS(unsigned x) { +  return (x & RHS_BIT);  }  inline bool IsFrontier(unsigned x) {    return (x & FRONTIER_BIT);  } +inline bool IsTerminal(unsigned x) { +  return (x & ALL_MASK) == x; +} +  struct TreeFragmentProduction {    TreeFragmentProduction() {}    TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r) {} @@ -46,6 +60,7 @@ class TreeFragment {    typedef const unsigned & reference;    iterator begin() const; +  iterator begin(unsigned node_idx) const;    iterator end() const;   private: @@ -62,24 +77,28 @@ class TreeFragment {  };  struct TFIState { -  TFIState() : node(), rhspos() {} -  TFIState(unsigned n, unsigned p) : node(n), rhspos(p) {} -  bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos; } -  bool operator!=(const TFIState& o) const { return node != o.node && rhspos != o.rhspos; } +  TFIState() : node(), rhspos(), state() {} +  TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {} +  bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; } +  bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; }    unsigned short node;    unsigned short rhspos; +  unsigned char state;  };  class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, unsigned> {    const TreeFragment* tf_; -  std::queue<TFIState> q_; +  std::deque<TFIState> q_;    unsigned sym;   public: -  explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) { -    q_.push(TFIState(tf->nodes.size() - 1, 0)); +  BreadthFirstIterator() : tf_(), sym() {} +  // used for begin +  explicit BreadthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { +    q_.push_back(TFIState(node_idx, 0, 0));      Stage();    } -  BreadthFirstIterator(const TreeFragment* tf, int) : tf_(tf) {} +  // used for end +  explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {}    const unsigned& operator*() const { return sym; }    const unsigned* operator->() const { return &sym; }    bool operator==(const BreadthFirstIterator& other) const { @@ -88,26 +107,20 @@ class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, uns    bool operator!=(const BreadthFirstIterator& other) const {      return (tf_ != other.tf_) || (q_ != other.q_);    } -  void Stage() { -    if (q_.empty()) return; -    const TFIState& s = q_.front(); -    sym = tf_->nodes[s.node].rhs[s.rhspos]; -    if (IsInternalNT(sym)) { -      q_.push(TFIState(sym & ALL_MASK, 0)); -      sym = tf_->nodes[sym & ALL_MASK].lhs; -    } -  }    const BreadthFirstIterator& operator++() {      TFIState& s = q_.front(); -    const unsigned len = tf_->nodes[s.node].rhs.size(); -    s.rhspos++; -    if (s.rhspos > len) { -      q_.pop(); +    if (s.state == 0) { +      s.state++;        Stage(); -    } else if (s.rhspos == len) { -      sym = 0;      } else { -      Stage(); +      const unsigned len = tf_->nodes[s.node].rhs.size(); +      s.rhspos++; +      if (s.rhspos >= len) { +        q_.pop_front(); +        Stage(); +      } else { +        Stage(); +      }      }      return *this;    } @@ -116,6 +129,42 @@ class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, uns      ++(*this);      return res;    } +  // tell iterator not to explore the subtree rooted at sym +  // should only be called once per NT symbol encountered +  const BreadthFirstIterator& truncate() { +    assert(IsRHS(sym)); +    sym &= ALL_MASK; +    sym |= FRONTIER_BIT; +    q_.pop_back(); +    return *this; +  } +  BreadthFirstIterator remainder() const { +    assert(IsRHS(sym)); +    return BreadthFirstIterator(tf_, q_.back()); +  } +  bool at_end() const { +    return q_.empty(); +  } + private: +  void Stage() { +    if (q_.empty()) return; +    const TFIState& s = q_.front(); +    if (s.state == 0) { +      sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT; +    } else { +      sym = tf_->nodes[s.node].rhs[s.rhspos]; +      if (IsRHS(sym)) { +        q_.push_back(TFIState(sym & ALL_MASK, 0, 0)); +        sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; +      } +    } +  } + +  // used by remainder +  BreadthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { +    q_.push_back(s); +    Stage(); +  }  };  inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) { | 
