diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/bottom_up_parser.cc | 80 | 
1 files changed, 69 insertions, 11 deletions
| diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index 606b8d7e..8738c8f1 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -45,6 +45,7 @@ class PassiveChart {                   const float lattice_cost);    void ApplyUnaryRules(const int i, const int j); +  void TopoSortUnaries();    const vector<GrammarPtr>& grammars_;    const Lattice& input_; @@ -57,6 +58,7 @@ class PassiveChart {    TRulePtr goal_rule_;    int goal_idx_;             // index of goal node, if found    const int lc_fid_; +  vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars    static WordID kGOAL;       // [Goal]  }; @@ -159,21 +161,78 @@ PassiveChart::PassiveChart(const string& goal,      goal_cat_(TD::Convert(goal) * -1),      goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")),      goal_idx_(-1), -    lc_fid_(FD::Convert("LatticeCost")) { +    lc_fid_(FD::Convert("LatticeCost")), +    unaries_() {    act_chart_.resize(grammars_.size()); -  for (unsigned i = 0; i < grammars_.size(); ++i) +  for (unsigned i = 0; i < grammars_.size(); ++i) {      act_chart_[i] = new ActiveChart(forest, *this); +    const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules(); +    for (unsigned j = 0; j < u.size(); ++j) +      unaries_.push_back(u[j]); +  } +  TopoSortUnaries();    if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;    if (!SILENT) cerr << "  Goal category: [" << goal << ']' << endl;  } +static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) { +  if (mark[node] == 1) { +    cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; +    return false; // cycle detected +  } else if (mark[node] == 2) { +    return true; // already been  +  } +  mark[node] = 1; +  const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node); +  if (nit != g.end()) { +    const vector<TRulePtr>& edges = nit->second; +    vector<bool> okay(edges.size(), true); +    for (unsigned i = 0; i < edges.size(); ++i) { +      okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); +      if (!okay[i]) { +        cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; +      } +   } +    for (unsigned i = 0; i < edges.size(); ++i) { +      if (okay[i]) u.push_back(edges[i]); +      //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; +    } +  } +  mark[node] = 2; +  return true; +} + +void PassiveChart::TopoSortUnaries() { +  vector<TRulePtr> u(unaries_.size()); u.clear(); +  map<int, vector<TRulePtr> > g; +  map<int, int> mark; +  //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; +  mark[goal_cat_] = 2; +  for (unsigned i = 0; i < unaries_.size(); ++i) { +    //cerr << "Adding: " << unaries_[i]->AsString() << endl; +    g[unaries_[i]->f()[0]].push_back(unaries_[i]); +  } +    //m[unaries_[i]->lhs_].push_back(unaries_[i]); +  for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) { +    //cerr << "PROC: " << TD::Convert(-it->first) << endl; +    if (mark[it->first] > 0) { +      //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; +    } else { +      TopoSortVisit(it->first, u, g, mark); +    } +  } +  unaries_.clear(); +  for (int i = u.size() - 1; i >= 0; --i) +    unaries_.push_back(u[i]); +} +  void PassiveChart::ApplyRule(const int i,                               const int j,                               const TRulePtr& r,                               const Hypergraph::TailNodeVector& ant_nodes,                               const float lattice_cost) {    Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); -  //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; +  // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;    new_edge->prev_i_ = r->prev_i;    new_edge->prev_j_ = r->prev_j;    new_edge->i_ = i; @@ -215,15 +274,14 @@ void PassiveChart::ApplyRules(const int i,  void PassiveChart::ApplyUnaryRules(const int i, const int j) {    const vector<int>& nodes = chart_(i,j);  // reference is important! -  for (unsigned gi = 0; gi < grammars_.size(); ++gi) { -    if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue; -    for (unsigned di = 0; di < nodes.size(); ++di) { -      const WordID& cat = forest_->nodes_[nodes[di]].cat_; -      const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); -      for (unsigned ri = 0; ri < unaries.size(); ++ri) { -        // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; +  for (unsigned di = 0; di < nodes.size(); ++di) { +    const WordID& cat = forest_->nodes_[nodes[di]].cat_; +    for (unsigned ri = 0; ri < unaries_.size(); ++ri) { +      //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; +      if (unaries_[ri]->f()[0] == cat) { +        //cerr << "  --MATCH\n";          const Hypergraph::TailNodeVector ant(1, nodes[di]); -        ApplyRule(i, j, unaries[ri], ant, 0);  // may update nodes +        ApplyRule(i, j, unaries_[ri], ant, 0);  // may update nodes        }      }    } | 
