diff options
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/tree2string_translator.cc | 80 |
1 files changed, 61 insertions, 19 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 7bc49132..6966ccf8 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,8 +1,10 @@ #include <algorithm> #include <vector> #include <queue> +#include <map> +#include <unordered_set> +#include <boost/shared_ptr.hpp> #include <boost/functional/hash.hpp> -#include <unordered_map> #include "tree_fragment.h" #include "translator.h" #include "hg.h" @@ -74,16 +76,43 @@ struct ParserState { future_work(p.future_work), input_node_idx(p.input_node_idx), node(n) {} - vector<ParserState> future_work; + bool operator==(const ParserState& o) const { + return node == o.node && input_node_idx == o.input_node_idx && + future_work == o.future_work && in_iter == o.in_iter; + } + vector<unsigned> future_work; int input_node_idx; // lhs of top level NT Tree2StringGrammarNode* node; }; +namespace std { + template<> + struct hash<ParserState> { + size_t operator()(const ParserState& s) const { + size_t h = boost::hash_range(s.future_work.begin(), s.future_work.end()); + boost::hash_combine(h, boost::hash_value(s.node)); + boost::hash_combine(h, boost::hash_value(s.input_node_idx)); + //boost::hash_combine(h, ); + return h; + } + }; +}; + struct Tree2StringTranslatorImpl { - Tree2StringGrammarNode root; - Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { - ReadFile rf(conf["grammar"].as<vector<string>>()[0]); - ReadTree2StringGrammar(rf.stream(), &root); + vector<boost::shared_ptr<Tree2StringGrammarNode>> root; + bool add_pass_through_rules; + Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) : + add_pass_through_rules(conf.count("add_pass_through_rules")) { + if (conf.count("grammar")) { + const vector<string> gf = conf["grammar"].as<vector<string>>(); + root.resize(gf.size()); + unsigned gc = 0; + for (auto& f : gf) { + ReadFile rf(f); + root[gc].reset(new Tree2StringGrammarNode); + ReadTree2StringGrammar(rf.stream(), &*root[gc++]); + } + } } bool Translate(const string& input, SentenceMetadata* smeta, @@ -94,7 +123,11 @@ struct Tree2StringTranslatorImpl { hg.ReserveNodes(input_tree.nodes.size()); vector<int> tree2hg(input_tree.nodes.size() + 1, -1); queue<ParserState> q; - q.push(ParserState(input_tree.begin(), &root)); + unordered_set<ParserState> unique; // only create items one time + for (auto& g : root) { + q.push(ParserState(input_tree.begin(), g.get())); + unique.insert(q.back()); + } unsigned tree_top = q.front().input_node_idx; while(!q.empty()) { ParserState& s = q.front(); @@ -103,14 +136,14 @@ struct Tree2StringTranslatorImpl { //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { - TailNodeVector tail; int& node_id = tree2hg[s.input_node_idx]; if (node_id < 0) node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_; - for (auto& n : s.future_work) { - int& nix = tree2hg[n.input_node_idx]; + TailNodeVector tail; + for (auto n : s.future_work) { + int& nix = tree2hg[n]; if (nix < 0) - nix = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK))->id_; + nix = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK))->id_; tail.push_back(nix); } for (auto& r : s.node->rules) { @@ -120,8 +153,13 @@ struct Tree2StringTranslatorImpl { // TODO: set i and j hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); } - for (auto& w : s.future_work) - q.push(w); + for (auto n : s.future_work) { + const auto it = input_tree.begin(n); // start tree iterator at node n + for (auto& g : root) { + ParserState s(it, g.get()); + if (unique.insert(s).second) q.push(s); + } + } } else { //cerr << "I can't build anything :(\n"; } @@ -131,7 +169,8 @@ struct Tree2StringTranslatorImpl { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; - q.push(ParserState(++s.in_iter, &nit->second, s)); + ParserState news(++s.in_iter, &nit->second, s); + if (unique.insert(news).second) q.push(news); } } else if (cdec::IsRHS(sym)) { //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; @@ -141,20 +180,23 @@ struct Tree2StringTranslatorImpl { auto nit2 = s.node->next.find(*var); if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; - ParserState new_s(++var, &nit2->second, s); - ParserState new_work(s.in_iter.remainder(), &root); + ++var; + const unsigned new_work = s.in_iter.child_node(); + ParserState new_s(var, &nit2->second, s); new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q - q.push(new_s); + if (unique.insert(new_s).second) q.push(new_s); } if (nit1 != s.node->next.end()) { //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; - q.push(ParserState(++s.in_iter, &nit1->second, s)); + const ParserState new_s(++s.in_iter, &nit1->second, s); + if (unique.insert(new_s).second) q.push(new_s); } } else if (cdec::IsTerminal(sym)) { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; - q.push(ParserState(++s.in_iter, &nit->second, s)); + const ParserState new_s(++s.in_iter, &nit->second, s); + if (unique.insert(new_s).second) q.push(new_s); } } else { cerr << "This can never happen!\n"; abort(); |