summary | refs | log | tree | commit | diff
path: root/decoder
diff options
context:
space:
mode:
author Chris Dyer <redpony@gmail.com> 2014-02-14 22:30:05 -0500
committer Chris Dyer <redpony@gmail.com> 2014-02-14 22:30:05 -0500
commit 8015250ddd3983320b6e54ca7f1914a465bc8a59 (patch)
tree ffb8f421a97a71058bac910bb1ca2e4b5a116c53 /decoder
parent 2b772ed8c1dcfecbb473f63cb0ef65b1dfb574dd (diff)
fix unary handling in scfg parser
Diffstat (limited to 'decoder')
-rw-r--r-- decoder/bottom_up_parser.cc 80
1 files changed, 69 insertions, 11 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index 606b8d7e..8738c8f1 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -45,6 +45,7 @@ class PassiveChart {
const float lattice_cost);
void ApplyUnaryRules(const int i, const int j);
+ void TopoSortUnaries();
const vector<GrammarPtr>& grammars_;
const Lattice& input_;
@@ -57,6 +58,7 @@ class PassiveChart {
TRulePtr goal_rule_;
int goal_idx_; // index of goal node, if found
const int lc_fid_;
+ vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars
static WordID kGOAL; // [Goal]
};
@@ -159,21 +161,78 @@ PassiveChart::PassiveChart(const string& goal,
goal_cat_(TD::Convert(goal) * -1),
goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")),
goal_idx_(-1),
- lc_fid_(FD::Convert("LatticeCost")) {
+ lc_fid_(FD::Convert("LatticeCost")),
+ unaries_() {
act_chart_.resize(grammars_.size());
- for (unsigned i = 0; i < grammars_.size(); ++i)
+ for (unsigned i = 0; i < grammars_.size(); ++i) {
act_chart_[i] = new ActiveChart(forest, *this);
+ const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules();
+ for (unsigned j = 0; j < u.size(); ++j)
+ unaries_.push_back(u[j]);
+ }
+ TopoSortUnaries();
if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;
if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl;
}
+static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) { // depth-first visit of `node`: appends its usable outgoing rules to `u` in post-order; returns false only when `node` is already on the current DFS path
+ if (mark[node] == 1) { // 1 = visit in progress; operator[] yields 0 for a node never seen before
+ cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n";
+ return false; // back edge: the caller drops the rule that led here
+ } else if (mark[node] == 2) {
+ return true; // 2 = already finished; its rules were handled by an earlier visit
+ }
+ mark[node] = 1; // mark as in progress before recursing so a path back to `node` is detected
+ const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node);
+ if (nit != g.end()) { // nodes with no entry in `g` have no outgoing rules
+ const vector<TRulePtr>& edges = nit->second;
+ vector<bool> okay(edges.size(), true); // okay[i] is false when edges[i] would close a cycle
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); // follow the rule to its lhs_ category
+ if (!okay[i]) {
+ cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl;
+ }
+ }
+ for (unsigned i = 0; i < edges.size(); ++i) { // second pass, so this node's rules land in `u` after everything reachable from it
+ if (okay[i]) u.push_back(edges[i]);
+ //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl;
+ }
+ }
+ mark[node] = 2; // finished
+ return true; // true even if some outgoing rules were dropped above
+}
+
+void PassiveChart::TopoSortUnaries() { // rebuilds unaries_ as the reverse of the post-order produced by TopoSortVisit; rules rejected there as cycle-forming are left out
+ vector<TRulePtr> u(unaries_.size()); u.clear(); // NOTE(review): size-then-clear looks like a stand-in for u.reserve(unaries_.size()); u starts empty either way
+ map<int, vector<TRulePtr> > g; // rules grouped by f()[0], the symbol ApplyUnaryRules compares against a node's category
+ map<int, int> mark; // visit state read and written by TopoSortVisit (1 = in progress, 2 = finished)
+ //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl;
+ mark[goal_cat_] = 2; // pre-mark the goal as finished: it is never expanded, so rules with f()[0] == goal_cat_ do not reach the result
+ for (unsigned i = 0; i < unaries_.size(); ++i) {
+ //cerr << "Adding: " << unaries_[i]->AsString() << endl;
+ g[unaries_[i]->f()[0]].push_back(unaries_[i]);
+ }
+ //m[unaries_[i]->lhs_].push_back(unaries_[i]);
+ for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) { // start a visit from every key not already reached
+ //cerr << "PROC: " << TD::Convert(-it->first) << endl;
+ if (mark[it->first] > 0) { // already visited, or the pre-marked goal
+ //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n";
+ } else {
+ TopoSortVisit(it->first, u, g, mark); // return value ignored: a top-level call starts from an unmarked node
+ }
+ }
+ unaries_.clear();
+ for (int i = u.size() - 1; i >= 0; --i) // copy back in reverse, so a rule precedes the rules keyed on its lhs_
+ unaries_.push_back(u[i]);
+}
+
void PassiveChart::ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
const float lattice_cost) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
- //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
+ // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
new_edge->prev_i_ = r->prev_i;
new_edge->prev_j_ = r->prev_j;
new_edge->i_ = i;
@@ -215,15 +274,14 @@ void PassiveChart::ApplyRules(const int i,
void PassiveChart::ApplyUnaryRules(const int i, const int j) {
const vector<int>& nodes = chart_(i,j); // reference is important!
- for (unsigned gi = 0; gi < grammars_.size(); ++gi) {
- if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue;
- for (unsigned di = 0; di < nodes.size(); ++di) {
- const WordID& cat = forest_->nodes_[nodes[di]].cat_;
- const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat);
- for (unsigned ri = 0; ri < unaries.size(); ++ri) {
- // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl;
+ for (unsigned di = 0; di < nodes.size(); ++di) {
+ const WordID& cat = forest_->nodes_[nodes[di]].cat_;
+ for (unsigned ri = 0; ri < unaries_.size(); ++ri) {
+ //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl;
+ if (unaries_[ri]->f()[0] == cat) {
+ //cerr << " --MATCH\n";
const Hypergraph::TailNodeVector ant(1, nodes[di]);
- ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes
+ ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes
}
}
}