summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/tree2string_translator.cc48
-rw-r--r--decoder/tree_fragment.h1
2 files changed, 37 insertions, 12 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index cd6ee550..09eca147 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -65,17 +65,17 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) {
struct ParserState {
ParserState() : in_iter(), node() {}
cdec::TreeFragment::iterator in_iter;
- ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) :
+ ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n) :
in_iter(it),
- root_type(rt),
+ input_node_idx(it.node_idx()),
node(n) {}
ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) :
in_iter(it),
future_work(p.future_work),
- root_type(p.root_type),
+ input_node_idx(p.input_node_idx),
node(n) {}
vector<ParserState> future_work;
- int root_type; // lhs of top level NT
+ int input_node_idx; // lhs of top level NT
Tree2StringGrammarNode* node;
};
@@ -90,23 +90,40 @@ struct Tree2StringTranslatorImpl {
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
cdec::TreeFragment input_tree(input, false);
- const int kS = -TD::Convert("S");
Hypergraph hg;
+ hg.ReserveNodes(input_tree.nodes.size());
+ vector<int> tree2hg(input_tree.nodes.size() + 1, -1);
queue<ParserState> q;
- q.push(ParserState(input_tree.begin(), &root, kS));
+ q.push(ParserState(input_tree.begin(), &root));
+ unsigned tree_top = q.front().input_node_idx;
while(!q.empty()) {
ParserState& s = q.front();
if (s.in_iter.at_end()) { // completed a traversal of a subtree
- cerr << "I traversed a subtree of the input...\n";
+ //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" <<
+ // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl;
if (s.node->rules.size()) {
- // TODO: build hypergraph
- for (auto& r : s.node->rules)
- cerr << "I can build: " << r->AsString() << endl;
+ TailNodeVector tail;
+ int& node_id = tree2hg[s.input_node_idx];
+ if (node_id < 0)
+ node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_;
+ for (auto& n : s.future_work) {
+ int& nix = tree2hg[n.input_node_idx];
+ if (nix < 0)
+ nix = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK))->id_;
+ tail.push_back(nix);
+ }
+ for (auto& r : s.node->rules) {
+ assert(tail.size() == r->Arity());
+ HG::Edge* new_edge = hg.AddEdge(r, tail);
+ new_edge->feature_values_ = r->GetFeatureValues();
+ // TODO: set i and j
+ hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]);
+ }
for (auto& w : s.future_work)
q.push(w);
} else {
- cerr << "I can't build anything :(\n";
+ //cerr << "I can't build anything :(\n";
}
} else { // more input tree to match
unsigned sym = *s.in_iter;
@@ -125,7 +142,7 @@ struct Tree2StringTranslatorImpl {
if (nit2 != s.node->next.end()) {
//cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
ParserState new_s(++var, &nit2->second, s);
- ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK));
+ ParserState new_work(s.in_iter.remainder(), &root);
new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q
q.push(new_s);
}
@@ -145,7 +162,14 @@ struct Tree2StringTranslatorImpl {
}
q.pop();
}
+ int goal = tree2hg[tree_top];
+ if (goal < 0) return false;
+ //cerr << "Goal node: " << goal << endl;
+ hg.TopologicallySortNodesAndEdges(goal);
+ hg.Reweight(weights);
+ //hg.PrintGraphviz();
minus_lm_forest->swap(hg);
+ return true;
}
};
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
index b83afc27..ceb7fa60 100644
--- a/decoder/tree_fragment.h
+++ b/decoder/tree_fragment.h
@@ -107,6 +107,7 @@ class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, uns
bool operator!=(const BreadthFirstIterator& other) const {
return (tf_ != other.tf_) || (q_ != other.q_);
}
+ unsigned node_idx() const { return q_.front().node; }
const BreadthFirstIterator& operator++() {
TFIState& s = q_.front();
if (s.state == 0) {