summaryrefslogtreecommitdiff
path: root/decoder/tree2string_translator.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-04-01 18:47:20 -0400
committerChris Dyer <redpony@gmail.com>2014-04-01 18:47:20 -0400
commit241a9932588563f7952f7d758e3f77d8c499443c (patch)
treedd349cf16eb184dd8f657a292556527c31a28c8c /decoder/tree2string_translator.cc
parent6a3e80bb4b2a6bb10899183299d14fc070db1654 (diff)
deal with multiple grammars in t2s
Diffstat (limited to 'decoder/tree2string_translator.cc')
-rw-r--r--decoder/tree2string_translator.cc80
1 files changed, 61 insertions, 19 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index 7bc49132..6966ccf8 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -1,8 +1,10 @@
#include <algorithm>
#include <vector>
#include <queue>
+#include <map>
+#include <unordered_set>
+#include <boost/shared_ptr.hpp>
#include <boost/functional/hash.hpp>
-#include <unordered_map>
#include "tree_fragment.h"
#include "translator.h"
#include "hg.h"
@@ -74,16 +76,43 @@ struct ParserState {
future_work(p.future_work),
input_node_idx(p.input_node_idx),
node(n) {}
- vector<ParserState> future_work;
+ bool operator==(const ParserState& o) const {
+ return node == o.node && input_node_idx == o.input_node_idx &&
+ future_work == o.future_work && in_iter == o.in_iter;
+ }
+ vector<unsigned> future_work;
int input_node_idx; // lhs of top level NT
Tree2StringGrammarNode* node;
};
+namespace std {
+ template<>
+ struct hash<ParserState> {
+ size_t operator()(const ParserState& s) const {
+ size_t h = boost::hash_range(s.future_work.begin(), s.future_work.end());
+ boost::hash_combine(h, boost::hash_value(s.node));
+ boost::hash_combine(h, boost::hash_value(s.input_node_idx));
+ //boost::hash_combine(h, );
+ return h;
+ }
+ };
+};
+
struct Tree2StringTranslatorImpl {
- Tree2StringGrammarNode root;
- Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) {
- ReadFile rf(conf["grammar"].as<vector<string>>()[0]);
- ReadTree2StringGrammar(rf.stream(), &root);
+ vector<boost::shared_ptr<Tree2StringGrammarNode>> root;
+ bool add_pass_through_rules;
+ Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) :
+ add_pass_through_rules(conf.count("add_pass_through_rules")) {
+ if (conf.count("grammar")) {
+ const vector<string> gf = conf["grammar"].as<vector<string>>();
+ root.resize(gf.size());
+ unsigned gc = 0;
+ for (auto& f : gf) {
+ ReadFile rf(f);
+ root[gc].reset(new Tree2StringGrammarNode);
+ ReadTree2StringGrammar(rf.stream(), &*root[gc++]);
+ }
+ }
}
bool Translate(const string& input,
SentenceMetadata* smeta,
@@ -94,7 +123,11 @@ struct Tree2StringTranslatorImpl {
hg.ReserveNodes(input_tree.nodes.size());
vector<int> tree2hg(input_tree.nodes.size() + 1, -1);
queue<ParserState> q;
- q.push(ParserState(input_tree.begin(), &root));
+ unordered_set<ParserState> unique; // only create items one time
+ for (auto& g : root) {
+ q.push(ParserState(input_tree.begin(), g.get()));
+ unique.insert(q.back());
+ }
unsigned tree_top = q.front().input_node_idx;
while(!q.empty()) {
ParserState& s = q.front();
@@ -103,14 +136,14 @@ struct Tree2StringTranslatorImpl {
//cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" <<
// TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl;
if (s.node->rules.size()) {
- TailNodeVector tail;
int& node_id = tree2hg[s.input_node_idx];
if (node_id < 0)
node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_;
- for (auto& n : s.future_work) {
- int& nix = tree2hg[n.input_node_idx];
+ TailNodeVector tail;
+ for (auto n : s.future_work) {
+ int& nix = tree2hg[n];
if (nix < 0)
- nix = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK))->id_;
+ nix = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK))->id_;
tail.push_back(nix);
}
for (auto& r : s.node->rules) {
@@ -120,8 +153,13 @@ struct Tree2StringTranslatorImpl {
// TODO: set i and j
hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]);
}
- for (auto& w : s.future_work)
- q.push(w);
+ for (auto n : s.future_work) {
+ const auto it = input_tree.begin(n); // start tree iterator at node n
+ for (auto& g : root) {
+ ParserState s(it, g.get());
+ if (unique.insert(s).second) q.push(s);
+ }
+ }
} else {
//cerr << "I can't build anything :(\n";
}
@@ -131,7 +169,8 @@ struct Tree2StringTranslatorImpl {
auto nit = s.node->next.find(sym);
if (nit != s.node->next.end()) {
//cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
- q.push(ParserState(++s.in_iter, &nit->second, s));
+ ParserState news(++s.in_iter, &nit->second, s);
+ if (unique.insert(news).second) q.push(news);
}
} else if (cdec::IsRHS(sym)) {
//cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
@@ -141,20 +180,23 @@ struct Tree2StringTranslatorImpl {
auto nit2 = s.node->next.find(*var);
if (nit2 != s.node->next.end()) {
//cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
- ParserState new_s(++var, &nit2->second, s);
- ParserState new_work(s.in_iter.remainder(), &root);
+ ++var;
+ const unsigned new_work = s.in_iter.child_node();
+ ParserState new_s(var, &nit2->second, s);
new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q
- q.push(new_s);
+ if (unique.insert(new_s).second) q.push(new_s);
}
if (nit1 != s.node->next.end()) {
//cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
- q.push(ParserState(++s.in_iter, &nit1->second, s));
+ const ParserState new_s(++s.in_iter, &nit1->second, s);
+ if (unique.insert(new_s).second) q.push(new_s);
}
} else if (cdec::IsTerminal(sym)) {
auto nit = s.node->next.find(sym);
if (nit != s.node->next.end()) {
//cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl;
- q.push(ParserState(++s.in_iter, &nit->second, s));
+ const ParserState new_s(++s.in_iter, &nit->second, s);
+ if (unique.insert(new_s).second) q.push(new_s);
}
} else {
cerr << "This can never happen!\n"; abort();