diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/tree2string_translator.cc | 80 | 
1 files changed, 61 insertions, 19 deletions
| diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 7bc49132..6966ccf8 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,8 +1,10 @@  #include <algorithm>  #include <vector>  #include <queue> +#include <map> +#include <unordered_set> +#include <boost/shared_ptr.hpp>  #include <boost/functional/hash.hpp> -#include <unordered_map>  #include "tree_fragment.h"  #include "translator.h"  #include "hg.h" @@ -74,16 +76,43 @@ struct ParserState {        future_work(p.future_work),        input_node_idx(p.input_node_idx),        node(n) {} -  vector<ParserState> future_work; +  bool operator==(const ParserState& o) const { +    return node == o.node && input_node_idx == o.input_node_idx && +           future_work == o.future_work && in_iter == o.in_iter; +  } +  vector<unsigned> future_work;    int input_node_idx; // lhs of top level NT    Tree2StringGrammarNode* node;  }; +namespace std { +  template<> +  struct hash<ParserState> { +    size_t operator()(const ParserState& s) const { +      size_t h = boost::hash_range(s.future_work.begin(), s.future_work.end()); +      boost::hash_combine(h, boost::hash_value(s.node)); +      boost::hash_combine(h, boost::hash_value(s.input_node_idx)); +      //boost::hash_combine(h, ); +      return h; +    } +  };  +}; +  struct Tree2StringTranslatorImpl { -  Tree2StringGrammarNode root; -  Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { -    ReadFile rf(conf["grammar"].as<vector<string>>()[0]); -    ReadTree2StringGrammar(rf.stream(), &root); +  vector<boost::shared_ptr<Tree2StringGrammarNode>> root; +  bool add_pass_through_rules; +  Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) : +      add_pass_through_rules(conf.count("add_pass_through_rules")) { +    if (conf.count("grammar")) { +      const vector<string> gf = conf["grammar"].as<vector<string>>(); +      root.resize(gf.size()); +      unsigned gc = 0; +      for (auto& f : gf) { +        ReadFile rf(f); +        root[gc].reset(new Tree2StringGrammarNode); +        ReadTree2StringGrammar(rf.stream(), &*root[gc++]); +      } +    }    }    bool Translate(const string& input,                   SentenceMetadata* smeta, @@ -94,7 +123,11 @@ struct Tree2StringTranslatorImpl {      hg.ReserveNodes(input_tree.nodes.size());      vector<int> tree2hg(input_tree.nodes.size() + 1, -1);      queue<ParserState> q; -    q.push(ParserState(input_tree.begin(), &root)); +    unordered_set<ParserState> unique;  // only create items one time +    for (auto& g : root) { +      q.push(ParserState(input_tree.begin(), g.get())); +      unique.insert(q.back()); +    }      unsigned tree_top = q.front().input_node_idx;      while(!q.empty()) {        ParserState& s = q.front(); @@ -103,14 +136,14 @@ struct Tree2StringTranslatorImpl {          //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" <<           //   TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl;          if (s.node->rules.size()) { -          TailNodeVector tail;            int& node_id = tree2hg[s.input_node_idx];            if (node_id < 0)              node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_; -          for (auto& n : s.future_work) { -            int& nix = tree2hg[n.input_node_idx]; +          TailNodeVector tail; +          for (auto n : s.future_work) { +            int& nix = tree2hg[n];              if (nix < 0) -              nix = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK))->id_; +              nix = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK))->id_;              tail.push_back(nix);            }            for (auto& r : s.node->rules) { @@ -120,8 +153,13 @@ struct Tree2StringTranslatorImpl {              // TODO: set i and j              hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]);            } -          for (auto& w : s.future_work) -            q.push(w); +          for (auto n : s.future_work) { +            const auto it = input_tree.begin(n); // start tree iterator at node n +            for (auto& g : root) { +              ParserState s(it, g.get()); +              if (unique.insert(s).second) q.push(s); +            } +          }          } else {            //cerr << "I can't build anything :(\n";          } @@ -131,7 +169,8 @@ struct Tree2StringTranslatorImpl {            auto nit = s.node->next.find(sym);            if (nit != s.node->next.end()) {              //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; -            q.push(ParserState(++s.in_iter, &nit->second, s)); +            ParserState news(++s.in_iter, &nit->second, s); +            if (unique.insert(news).second) q.push(news);            }          } else if (cdec::IsRHS(sym)) {            //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; @@ -141,20 +180,23 @@ struct Tree2StringTranslatorImpl {            auto nit2 = s.node->next.find(*var);            if (nit2 != s.node->next.end()) {              //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; -            ParserState new_s(++var, &nit2->second, s); -            ParserState new_work(s.in_iter.remainder(), &root); +            ++var; +            const unsigned new_work = s.in_iter.child_node(); +            ParserState new_s(var, &nit2->second, s);              new_s.future_work.push_back(new_work);  // if this traversal of the input succeeds, future_work goes on the q -            q.push(new_s); +            if (unique.insert(new_s).second) q.push(new_s);            }            if (nit1 != s.node->next.end()) {              //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; -            q.push(ParserState(++s.in_iter, &nit1->second, s)); +            const ParserState new_s(++s.in_iter, &nit1->second, s); +            if (unique.insert(new_s).second) q.push(new_s);            }          } else if (cdec::IsTerminal(sym)) {            auto nit = s.node->next.find(sym);            if (nit != s.node->next.end()) {              //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; -            q.push(ParserState(++s.in_iter, &nit->second, s)); +            const ParserState new_s(++s.in_iter, &nit->second, s); +            if (unique.insert(new_s).second) q.push(new_s);            }          } else {            cerr << "This can never happen!\n"; abort(); | 
