diff options
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/tree2string_translator.cc | 48 | ||||
-rw-r--r-- | decoder/tree_fragment.h | 1 |
2 files changed, 37 insertions, 12 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index cd6ee550..09eca147 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -65,17 +65,17 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { struct ParserState { ParserState() : in_iter(), node() {} cdec::TreeFragment::iterator in_iter; - ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) : + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n) : in_iter(it), - root_type(rt), + input_node_idx(it.node_idx()), node(n) {} ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : in_iter(it), future_work(p.future_work), - root_type(p.root_type), + input_node_idx(p.input_node_idx), node(n) {} vector<ParserState> future_work; - int root_type; // lhs of top level NT + int input_node_idx; // lhs of top level NT Tree2StringGrammarNode* node; }; @@ -90,23 +90,40 @@ struct Tree2StringTranslatorImpl { const vector<double>& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); - const int kS = -TD::Convert("S"); Hypergraph hg; + hg.ReserveNodes(input_tree.nodes.size()); + vector<int> tree2hg(input_tree.nodes.size() + 1, -1); queue<ParserState> q; - q.push(ParserState(input_tree.begin(), &root, kS)); + q.push(ParserState(input_tree.begin(), &root)); + unsigned tree_top = q.front().input_node_idx; while(!q.empty()) { ParserState& s = q.front(); if (s.in_iter.at_end()) { // completed a traversal of a subtree - cerr << "I traversed a subtree of the input...\n"; + //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << + // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { - // TODO: build hypergraph - for (auto& r : s.node->rules) - cerr << "I can build: " << r->AsString() << endl; + TailNodeVector tail; + int& node_id = tree2hg[s.input_node_idx]; + if (node_id < 0) + node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_; + for (auto& n : s.future_work) { + int& nix = tree2hg[n.input_node_idx]; + if (nix < 0) + nix = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK))->id_; + tail.push_back(nix); + } + for (auto& r : s.node->rules) { + assert(tail.size() == r->Arity()); + HG::Edge* new_edge = hg.AddEdge(r, tail); + new_edge->feature_values_ = r->GetFeatureValues(); + // TODO: set i and j + hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); + } for (auto& w : s.future_work) q.push(w); } else { - cerr << "I can't build anything :(\n"; + //cerr << "I can't build anything :(\n"; } } else { // more input tree to match unsigned sym = *s.in_iter; @@ -125,7 +142,7 @@ struct Tree2StringTranslatorImpl { if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; ParserState new_s(++var, &nit2->second, s); - ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK)); + ParserState new_work(s.in_iter.remainder(), &root); new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q q.push(new_s); } @@ -145,7 +162,14 @@ struct Tree2StringTranslatorImpl { } q.pop(); } + int goal = tree2hg[tree_top]; + if (goal < 0) return false; + //cerr << "Goal node: " << goal << endl; + hg.TopologicallySortNodesAndEdges(goal); + hg.Reweight(weights); + //hg.PrintGraphviz(); minus_lm_forest->swap(hg); + return true; } }; diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index b83afc27..ceb7fa60 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -107,6 +107,7 @@ class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, uns bool operator!=(const BreadthFirstIterator& other) const { return (tf_ != other.tf_) || (q_ != other.q_); } + unsigned node_idx() const { return q_.front().node; } const BreadthFirstIterator& operator++() { TFIState& s = q_.front(); if (s.state == 0) { |