diff options
-rw-r--r-- | decoder/tree2string_translator.cc | 121 | ||||
-rw-r--r-- | decoder/trule.h | 3 | ||||
-rw-r--r-- | training/utils/grammar_convert.cc | 5 |
3 files changed, 89 insertions, 40 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 3fbf1ee5..29caaf8f 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -30,12 +30,15 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { unsigned xc = 0; while (line[pos - 1] == ' ') { --pos; xc++; } cdec::TreeFragment rule_src(line.substr(0, pos), true); - Tree2StringGrammarNode* cur = root; + // TODO transducer_state should (optionally?) be read from input + const unsigned transducer_state = 0; + Tree2StringGrammarNode* cur = &root->next[transducer_state]; ostringstream os; int lhs = -(rule_src.root & cdec::ALL_MASK); // build source RHS for SCFG projection // TODO - this is buggy - it will generate a well-formed SCFG rule - // but it will not generate source strings correctly + // so it will not generate source strings correctly + // it will, however, generate target translations appropriately vector<int> frhs; for (auto sym : rule_src) { //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; @@ -59,40 +62,65 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { while(line[pos] == ' ') { ++pos; } os << " ||| " << line.substr(pos); TRulePtr rule(new TRule(os.str())); + // TODO the transducer_state you end up in after using this rule (for each NT) + // needs to be read and encoded somehow in the rule (for use XXX) cur->rules.push_back(rule); //cerr << "RULE: " << rule->AsString() << "\n\n"; } } +// represents where in an input parse tree the transducer must continue +// and what state it is in +struct TransducerState { + TransducerState() : input_node_idx(), transducer_state() {} + TransducerState(unsigned n, unsigned q) : input_node_idx(n), transducer_state(q) {} + bool operator==(const TransducerState& o) const { + return input_node_idx == o.input_node_idx && + transducer_state == o.transducer_state; + } + unsigned input_node_idx; + unsigned transducer_state; +}; + +// represents the state of the composition algorithm struct ParserState { ParserState() : in_iter(), node() {} cdec::TreeFragment::iterator in_iter; - ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n) : + ParserState(const cdec::TreeFragment::iterator& it, unsigned q, Tree2StringGrammarNode* n) : in_iter(it), - input_node_idx(it.node_idx()), + task(it.node_idx(), q), node(n) {} ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : in_iter(it), future_work(p.future_work), - input_node_idx(p.input_node_idx), + task(p.task), node(n) {} bool operator==(const ParserState& o) const { - return node == o.node && input_node_idx == o.input_node_idx && + return node == o.node && task == o.task && future_work == o.future_work && in_iter == o.in_iter; } - vector<unsigned> future_work; - int input_node_idx; // lhs of top level NT - Tree2StringGrammarNode* node; // pointer into grammar + vector<TransducerState> future_work; + TransducerState task; // subtree root where and in what state did the transducer start? + Tree2StringGrammarNode* node; // pointer into grammar trie }; namespace std { template<> + struct hash<TransducerState> { + size_t operator()(const TransducerState& q) const { + size_t h = boost::hash_value(q.transducer_state); + boost::hash_combine(h, boost::hash_value(q.input_node_idx)); + return h; + } + }; + template<> struct hash<ParserState> { size_t operator()(const ParserState& s) const { - size_t h = boost::hash_range(s.future_work.begin(), s.future_work.end()); - boost::hash_combine(h, boost::hash_value(s.node)); - boost::hash_combine(h, boost::hash_value(s.input_node_idx)); - //boost::hash_combine(h, ); + size_t h = boost::hash_value(s.node); + for (auto& w : s.future_work) + boost::hash_combine(h, hash<TransducerState>()(w)); + boost::hash_combine(h, hash<TransducerState>()(s.task)); + // TODO hash with iterator return h; } }; @@ -144,6 +172,9 @@ struct Tree2StringTranslatorImpl { os << ')'; cdec::TreeFragment rule_src(os.str(), true); Tree2StringGrammarNode* cur = root.back().get(); + // do we need all transducer states here??? a list??? no pass through rules??? + unsigned transducer_state = 0; + cur = &cur->next[transducer_state]; for (auto sym : rule_src) cur = &cur->next[sym]; TRulePtr rule(new TRule(rhse, rhsf, lhs)); @@ -167,15 +198,19 @@ struct Tree2StringTranslatorImpl { if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; hg.ReserveNodes(input_tree.nodes.size()); - vector<int> tree2hg(input_tree.nodes.size() + 1, -1); + unordered_map<TransducerState, unsigned> x2hg(input_tree.nodes.size() * 5); queue<ParserState> q; unordered_set<ParserState> unique; // only create items one time for (auto& g : root) { - q.push(ParserState(input_tree.begin(), g.get())); - unique.insert(q.back()); + unsigned q_0 = 0; // TODO initialize q_0 properly once multi-state transducers are supported + auto rit = g->next.find(q_0); + if (rit != g->next.end()) { // does this g have this transducer state? + q.push(ParserState(input_tree.begin(), q_0, &rit->second)); + unique.insert(q.back()); + } } if (q.size() == 0) return false; - unsigned tree_top = q.front().input_node_idx; + const TransducerState tree_top = q.front().task; while(!q.empty()) { ParserState& s = q.front(); @@ -183,21 +218,24 @@ struct Tree2StringTranslatorImpl { //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { - int& node_id = tree2hg[s.input_node_idx]; - if (node_id < 0) { - HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK)); - new_node->node_hash = s.input_node_idx + 1; - node_id = new_node->id_; + auto it = x2hg.find(s.task); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.task.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash<TransducerState>()(s.task); + it = x2hg.insert(make_pair(s.task, new_node->id_)).first; } + const unsigned node_id = it->second; TailNodeVector tail; - for (auto n : s.future_work) { - int& nix = tree2hg[n]; - if (nix < 0) { - HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK)); - new_node->node_hash = n + 1; - nix = new_node->id_; + for (const auto& n : s.future_work) { + auto it = x2hg.find(n); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash<TransducerState>()(n); + it = x2hg.insert(make_pair(n, new_node->id_)).first; } - tail.push_back(nix); + tail.push_back(it->second); } for (auto& r : s.node->rules) { assert(tail.size() == r->Arity()); @@ -206,11 +244,14 @@ struct Tree2StringTranslatorImpl { // TODO: set i and j hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); } - for (auto n : s.future_work) { - const auto it = input_tree.begin(n); // start tree iterator at node n + for (const auto& n : s.future_work) { + const auto it = input_tree.begin(n.input_node_idx); // start tree iterator at node n for (auto& g : root) { - ParserState s(it, g.get()); - if (unique.insert(s).second) q.push(s); + auto rit = g->next.find(n.transducer_state); + if (rit != g->next.end()) { // does this g have this transducer state? + const ParserState s(it, n.transducer_state, &rit->second); + if (unique.insert(s).second) q.push(s); + } } } } else { @@ -234,9 +275,13 @@ struct Tree2StringTranslatorImpl { if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; ++var; - const unsigned new_work = s.in_iter.child_node(); + // TODO: find out from rule what the new target state is (the 0 in the next line) + // if it is associated with the rule, we won't know until we match the whole input + // so the 0 may be okay (if this is the case, which is probably the easiest thing, + // then the state must be dealt with when the future work becomes real work) + const TransducerState new_task(s.in_iter.child_node(), 0); ParserState new_s(var, &nit2->second, s); - new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q + new_s.future_work.push_back(new_task); // if this traversal of the input succeeds, future_work goes on the q if (unique.insert(new_s).second) q.push(new_s); } //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } @@ -259,10 +304,10 @@ struct Tree2StringTranslatorImpl { } q.pop(); } - int goal = tree2hg[tree_top]; - if (goal < 0) return false; + const auto goal_it = x2hg.find(tree_top); + if (goal_it == x2hg.end()) return false; //cerr << "Goal node: " << goal << endl; - hg.TopologicallySortNodesAndEdges(goal); + hg.TopologicallySortNodesAndEdges(goal_it->second); hg.Reweight(weights); // there might be nodes that cannot be derived diff --git a/decoder/trule.h b/decoder/trule.h index 7dced5a1..cc370757 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -42,6 +42,9 @@ class TRule { scores_.set_value(feat_ids[i], feat_vals[i]); } + TRule(WordID lhs, const WordID* src, int src_size, const WordID* trg, int trg_size, int arity, int pi, int pj) : + e_(trg, trg + trg_size), f_(src, src + src_size), lhs_(lhs), arity_(arity), prev_i(pi), prev_j(pj) {} + bool IsGoal() const; explicit TRule(const std::vector<WordID>& e) : e_(e), lhs_(0), prev_i(-1), prev_j(-1) {} diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc index 607a7cb9..58d1957c 100644 --- a/training/utils/grammar_convert.cc +++ b/training/utils/grammar_convert.cc @@ -292,10 +292,10 @@ int main(int argc, char **argv) { int lc = 0; Hypergraph hg; map<WordID, int> lhs2node; + string line; while(*in) { - string line; + getline(*in,line); ++lc; - getline(*in, line); if (is_json_input) { if (line.empty() || line[0] == '#') continue; string ref; @@ -342,6 +342,7 @@ int main(int argc, char **argv) { edge->feature_values_ = tr->scores_; Hypergraph::Node* node = &hg.nodes_[head]; hg.ConnectEdgeToHeadNode(edge, node); + node->node_hash = lc; } } } |