From 8015250ddd3983320b6e54ca7f1914a465bc8a59 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 14 Feb 2014 22:30:05 -0500 Subject: fix unary handling in scfg parser --- decoder/bottom_up_parser.cc | 80 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 11 deletions(-) (limited to 'decoder/bottom_up_parser.cc') diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index 606b8d7e..8738c8f1 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -45,6 +45,7 @@ class PassiveChart { const float lattice_cost); void ApplyUnaryRules(const int i, const int j); + void TopoSortUnaries(); const vector& grammars_; const Lattice& input_; @@ -57,6 +58,7 @@ class PassiveChart { TRulePtr goal_rule_; int goal_idx_; // index of goal node, if found const int lc_fid_; + vector unaries_; // topologically sorted list of unary rules from all grammars static WordID kGOAL; // [Goal] }; @@ -159,21 +161,78 @@ PassiveChart::PassiveChart(const string& goal, goal_cat_(TD::Convert(goal) * -1), goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), goal_idx_(-1), - lc_fid_(FD::Convert("LatticeCost")) { + lc_fid_(FD::Convert("LatticeCost")), + unaries_() { act_chart_.resize(grammars_.size()); - for (unsigned i = 0; i < grammars_.size(); ++i) + for (unsigned i = 0; i < grammars_.size(); ++i) { act_chart_[i] = new ActiveChart(forest, *this); + const vector& u = grammars_[i]->GetAllUnaryRules(); + for (unsigned j = 0; j < u.size(); ++j) + unaries_.push_back(u[j]); + } + TopoSortUnaries(); if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl; } +static bool TopoSortVisit(int node, vector& u, const map >& g, map& mark) { + if (mark[node] == 1) { + cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; + return false; // cycle detected + } else if (mark[node] == 2) { + return true; // already been + } + mark[node] = 1; + const map >::const_iterator nit = g.find(node); + if (nit != g.end()) { + const vector& edges = nit->second; + vector okay(edges.size(), true); + for (unsigned i = 0; i < edges.size(); ++i) { + okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); + if (!okay[i]) { + cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; + } + } + for (unsigned i = 0; i < edges.size(); ++i) { + if (okay[i]) u.push_back(edges[i]); + //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; + } + } + mark[node] = 2; + return true; +} + +void PassiveChart::TopoSortUnaries() { + vector u(unaries_.size()); u.clear(); + map > g; + map mark; + //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; + mark[goal_cat_] = 2; + for (unsigned i = 0; i < unaries_.size(); ++i) { + //cerr << "Adding: " << unaries_[i]->AsString() << endl; + g[unaries_[i]->f()[0]].push_back(unaries_[i]); + } + //m[unaries_[i]->lhs_].push_back(unaries_[i]); + for (map >::iterator it = g.begin(); it != g.end(); ++it) { + //cerr << "PROC: " << TD::Convert(-it->first) << endl; + if (mark[it->first] > 0) { + //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; + } else { + TopoSortVisit(it->first, u, g, mark); + } + } + unaries_.clear(); + for (int i = u.size() - 1; i >= 0; --i) + unaries_.push_back(u[i]); +} + void PassiveChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, const float lattice_cost) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); - //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; + // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; new_edge->prev_j_ = r->prev_j; new_edge->i_ = i; @@ -215,15 +274,14 @@ void PassiveChart::ApplyRules(const int i, void PassiveChart::ApplyUnaryRules(const int i, const int j) { const vector& nodes = chart_(i,j); // reference is important! - for (unsigned gi = 0; gi < grammars_.size(); ++gi) { - if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue; - for (unsigned di = 0; di < nodes.size(); ++di) { - const WordID& cat = forest_->nodes_[nodes[di]].cat_; - const vector& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); - for (unsigned ri = 0; ri < unaries.size(); ++ri) { - // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; + for (unsigned di = 0; di < nodes.size(); ++di) { + const WordID& cat = forest_->nodes_[nodes[di]].cat_; + for (unsigned ri = 0; ri < unaries_.size(); ++ri) { + //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; + if (unaries_[ri]->f()[0] == cat) { + //cerr << " --MATCH\n"; const Hypergraph::TailNodeVector ant(1, nodes[di]); - ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes + ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes } } } -- cgit v1.2.3