summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/tree2string_translator.cc121
-rw-r--r--decoder/trule.h3
-rw-r--r--training/utils/grammar_convert.cc5
3 files changed, 89 insertions, 40 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index 3fbf1ee5..29caaf8f 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -30,12 +30,15 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) {
unsigned xc = 0;
while (line[pos - 1] == ' ') { --pos; xc++; }
cdec::TreeFragment rule_src(line.substr(0, pos), true);
- Tree2StringGrammarNode* cur = root;
+ // TODO transducer_state should (optionally?) be read from input
+ const unsigned transducer_state = 0;
+ Tree2StringGrammarNode* cur = &root->next[transducer_state];
ostringstream os;
int lhs = -(rule_src.root & cdec::ALL_MASK);
// build source RHS for SCFG projection
// TODO - this is buggy - it will generate a well-formed SCFG rule
- // but it will not generate source strings correctly
+ // so it will not generate source strings correctly
+ // it will, however, generate target translations appropriately
vector<int> frhs;
for (auto sym : rule_src) {
//cerr << TD::Convert(sym & cdec::ALL_MASK) << endl;
@@ -59,40 +62,65 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) {
while(line[pos] == ' ') { ++pos; }
os << " ||| " << line.substr(pos);
TRulePtr rule(new TRule(os.str()));
+ // TODO the transducer_state you end up in after using this rule (for each NT)
+ // needs to be read and encoded somehow in the rule (for use XXX)
cur->rules.push_back(rule);
//cerr << "RULE: " << rule->AsString() << "\n\n";
}
}
+// represents where in an input parse tree the transducer must continue
+// and what state it is in
+struct TransducerState {
+ TransducerState() : input_node_idx(), transducer_state() {}
+ TransducerState(unsigned n, unsigned q) : input_node_idx(n), transducer_state(q) {}
+ bool operator==(const TransducerState& o) const {
+ return input_node_idx == o.input_node_idx &&
+ transducer_state == o.transducer_state;
+ }
+ unsigned input_node_idx;
+ unsigned transducer_state;
+};
+
+// represents the state of the composition algorithm
struct ParserState {
ParserState() : in_iter(), node() {}
cdec::TreeFragment::iterator in_iter;
- ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n) :
+ ParserState(const cdec::TreeFragment::iterator& it, unsigned q, Tree2StringGrammarNode* n) :
in_iter(it),
- input_node_idx(it.node_idx()),
+ task(it.node_idx(), q),
node(n) {}
ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) :
in_iter(it),
future_work(p.future_work),
- input_node_idx(p.input_node_idx),
+ task(p.task),
node(n) {}
bool operator==(const ParserState& o) const {
- return node == o.node && input_node_idx == o.input_node_idx &&
+ return node == o.node && task == o.task &&
future_work == o.future_work && in_iter == o.in_iter;
}
- vector<unsigned> future_work;
- int input_node_idx; // lhs of top level NT
- Tree2StringGrammarNode* node; // pointer into grammar
+ vector<TransducerState> future_work;
+ TransducerState task; // subtree root where and in what state did the transducer start?
+ Tree2StringGrammarNode* node; // pointer into grammar trie
};
namespace std {
template<>
+ struct hash<TransducerState> {
+ size_t operator()(const TransducerState& q) const {
+ size_t h = boost::hash_value(q.transducer_state);
+ boost::hash_combine(h, boost::hash_value(q.input_node_idx));
+ return h;
+ }
+ };
+ template<>
struct hash<ParserState> {
size_t operator()(const ParserState& s) const {
- size_t h = boost::hash_range(s.future_work.begin(), s.future_work.end());
- boost::hash_combine(h, boost::hash_value(s.node));
- boost::hash_combine(h, boost::hash_value(s.input_node_idx));
- //boost::hash_combine(h, );
+ size_t h = boost::hash_value(s.node);
+ for (auto& w : s.future_work)
+ boost::hash_combine(h, hash<TransducerState>()(w));
+ boost::hash_combine(h, hash<TransducerState>()(s.task));
+ // TODO hash with iterator
return h;
}
};
@@ -144,6 +172,9 @@ struct Tree2StringTranslatorImpl {
os << ')';
cdec::TreeFragment rule_src(os.str(), true);
Tree2StringGrammarNode* cur = root.back().get();
+ // do we need all transducer states here??? a list??? no pass through rules???
+ unsigned transducer_state = 0;
+ cur = &cur->next[transducer_state];
for (auto sym : rule_src)
cur = &cur->next[sym];
TRulePtr rule(new TRule(rhse, rhsf, lhs));
@@ -167,15 +198,19 @@ struct Tree2StringTranslatorImpl {
if (add_pass_through_rules) CreatePassThroughRules(input_tree);
Hypergraph hg;
hg.ReserveNodes(input_tree.nodes.size());
- vector<int> tree2hg(input_tree.nodes.size() + 1, -1);
+ unordered_map<TransducerState, unsigned> x2hg(input_tree.nodes.size() * 5);
queue<ParserState> q;
unordered_set<ParserState> unique; // only create items one time
for (auto& g : root) {
- q.push(ParserState(input_tree.begin(), g.get()));
- unique.insert(q.back());
+ unsigned q_0 = 0; // TODO initialize q_0 properly once multi-state transducers are supported
+ auto rit = g->next.find(q_0);
+ if (rit != g->next.end()) { // does this g have this transducer state?
+ q.push(ParserState(input_tree.begin(), q_0, &rit->second));
+ unique.insert(q.back());
+ }
}
if (q.size() == 0) return false;
- unsigned tree_top = q.front().input_node_idx;
+ const TransducerState tree_top = q.front().task;
while(!q.empty()) {
ParserState& s = q.front();
@@ -183,21 +218,24 @@ struct Tree2StringTranslatorImpl {
//cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" <<
// TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl;
if (s.node->rules.size()) {
- int& node_id = tree2hg[s.input_node_idx];
- if (node_id < 0) {
- HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK));
- new_node->node_hash = s.input_node_idx + 1;
- node_id = new_node->id_;
+ auto it = x2hg.find(s.task);
+ if (it == x2hg.end()) {
+ // TODO create composite state symbol that encodes transducer state type?
+ HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.task.input_node_idx].lhs & cdec::ALL_MASK));
+ new_node->node_hash = std::hash<TransducerState>()(s.task);
+ it = x2hg.insert(make_pair(s.task, new_node->id_)).first;
}
+ const unsigned node_id = it->second;
TailNodeVector tail;
- for (auto n : s.future_work) {
- int& nix = tree2hg[n];
- if (nix < 0) {
- HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK));
- new_node->node_hash = n + 1;
- nix = new_node->id_;
+ for (const auto& n : s.future_work) {
+ auto it = x2hg.find(n);
+ if (it == x2hg.end()) {
+ // TODO create composite state symbol that encodes transducer state type?
+ HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK));
+ new_node->node_hash = std::hash<TransducerState>()(n);
+ it = x2hg.insert(make_pair(n, new_node->id_)).first;
}
- tail.push_back(nix);
+ tail.push_back(it->second);
}
for (auto& r : s.node->rules) {
assert(tail.size() == r->Arity());
@@ -206,11 +244,14 @@ struct Tree2StringTranslatorImpl {
// TODO: set i and j
hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]);
}
- for (auto n : s.future_work) {
- const auto it = input_tree.begin(n); // start tree iterator at node n
+ for (const auto& n : s.future_work) {
+ const auto it = input_tree.begin(n.input_node_idx); // start tree iterator at node n
for (auto& g : root) {
- ParserState s(it, g.get());
- if (unique.insert(s).second) q.push(s);
+ auto rit = g->next.find(n.transducer_state);
+ if (rit != g->next.end()) { // does this g have this transducer state?
+ const ParserState s(it, n.transducer_state, &rit->second);
+ if (unique.insert(s).second) q.push(s);
+ }
}
}
} else {
@@ -234,9 +275,13 @@ struct Tree2StringTranslatorImpl {
if (nit2 != s.node->next.end()) {
//cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
++var;
- const unsigned new_work = s.in_iter.child_node();
+ // TODO: find out from rule what the new target state is (the 0 in the next line)
+ // if it is associated with the rule, we won't know until we match the whole input
+ // so the 0 may be okay (if this is the case, which is probably the easiest thing,
+ // then the state must be dealt with when the future work becomes real work)
+ const TransducerState new_task(s.in_iter.child_node(), 0);
ParserState new_s(var, &nit2->second, s);
- new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q
+ new_s.future_work.push_back(new_task); // if this traversal of the input succeeds, future_work goes on the q
if (unique.insert(new_s).second) q.push(new_s);
}
//else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; }
@@ -259,10 +304,10 @@ struct Tree2StringTranslatorImpl {
}
q.pop();
}
- int goal = tree2hg[tree_top];
- if (goal < 0) return false;
+ const auto goal_it = x2hg.find(tree_top);
+ if (goal_it == x2hg.end()) return false;
//cerr << "Goal node: " << goal << endl;
- hg.TopologicallySortNodesAndEdges(goal);
+ hg.TopologicallySortNodesAndEdges(goal_it->second);
hg.Reweight(weights);
// there might be nodes that cannot be derived
diff --git a/decoder/trule.h b/decoder/trule.h
index 7dced5a1..cc370757 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -42,6 +42,9 @@ class TRule {
scores_.set_value(feat_ids[i], feat_vals[i]);
}
+ TRule(WordID lhs, const WordID* src, int src_size, const WordID* trg, int trg_size, int arity, int pi, int pj) :
+ e_(trg, trg + trg_size), f_(src, src + src_size), lhs_(lhs), arity_(arity), prev_i(pi), prev_j(pj) {}
+
bool IsGoal() const;
explicit TRule(const std::vector<WordID>& e) : e_(e), lhs_(0), prev_i(-1), prev_j(-1) {}
diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc
index 607a7cb9..58d1957c 100644
--- a/training/utils/grammar_convert.cc
+++ b/training/utils/grammar_convert.cc
@@ -292,10 +292,10 @@ int main(int argc, char **argv) {
int lc = 0;
Hypergraph hg;
map<WordID, int> lhs2node;
+ string line;
while(*in) {
- string line;
+ getline(*in,line);
++lc;
- getline(*in, line);
if (is_json_input) {
if (line.empty() || line[0] == '#') continue;
string ref;
@@ -342,6 +342,7 @@ int main(int argc, char **argv) {
edge->feature_values_ = tr->scores_;
Hypergraph::Node* node = &hg.nodes_[head];
hg.ConnectEdgeToHeadNode(edge, node);
+ node->node_hash = lc;
}
}
}