summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Chris Dyer <redpony@gmail.com> 2014-02-14 22:30:05 -0500
committer Chris Dyer <redpony@gmail.com> 2014-02-14 22:30:05 -0500
commit 9e2f7fcfa76213f5e41abb4f4c9a264ebe8f9d8c (patch)
tree 94905f2dcb4616fa329c7ee294d76d7321b73df5
parent 0964b959fc2be173a2adbd52d2d1d143f7458496 (diff)
fix unary handling in scfg parser
-rw-r--r-- decoder/bottom_up_parser.cc 80
-rw-r--r-- tests/system_tests/multigram/cdec.ini 1
-rw-r--r-- tests/system_tests/multigram/g1.scfg 4
-rw-r--r-- tests/system_tests/multigram/g2.scfg 1
-rw-r--r-- tests/system_tests/multigram/gold.statistics 3
-rw-r--r-- tests/system_tests/multigram/gold.stdout 1
-rw-r--r-- tests/system_tests/multigram/input.txt 1
-rw-r--r-- tests/system_tests/multigram/weights 1
8 files changed, 81 insertions, 11 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index 606b8d7e..8738c8f1 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -45,6 +45,7 @@ class PassiveChart {
const float lattice_cost);
void ApplyUnaryRules(const int i, const int j);
+ void TopoSortUnaries();
const vector<GrammarPtr>& grammars_;
const Lattice& input_;
@@ -57,6 +58,7 @@ class PassiveChart {
TRulePtr goal_rule_;
int goal_idx_; // index of goal node, if found
const int lc_fid_;
+ vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars
static WordID kGOAL; // [Goal]
};
@@ -159,21 +161,78 @@ PassiveChart::PassiveChart(const string& goal,
goal_cat_(TD::Convert(goal) * -1),
goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")),
goal_idx_(-1),
- lc_fid_(FD::Convert("LatticeCost")) {
+ lc_fid_(FD::Convert("LatticeCost")),
+ unaries_() {
act_chart_.resize(grammars_.size());
- for (unsigned i = 0; i < grammars_.size(); ++i)
+ for (unsigned i = 0; i < grammars_.size(); ++i) {
act_chart_[i] = new ActiveChart(forest, *this);
+ const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules();
+ for (unsigned j = 0; j < u.size(); ++j)
+ unaries_.push_back(u[j]);
+ }
+ TopoSortUnaries();
if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;
if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl;
}
+static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) { // Depth-first visit for the unary-rule sort: appends to u, in post-order, the rules stored in g under node; returns false only when node is already on the current DFS path (a cycle).
+ if (mark[node] == 1) { // 1 = visit in progress, so reaching node again closes a cycle (operator[] inserts 0 for nodes not seen yet)
+ cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n";
+ return false; // cycle detected; the calling frame drops the rule that led here
+ } else if (mark[node] == 2) {
+ return true; // 2 = node was already fully processed, nothing more to emit
+ }
+ mark[node] = 1; // in progress until all rules under this node have been followed
+ const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node); // rules stored in g under this node, if any
+ if (nit != g.end()) {
+ const vector<TRulePtr>& edges = nit->second;
+ vector<bool> okay(edges.size(), true); // okay[i] becomes false when following edges[i] runs into a cycle
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); // recurse into the node named by the rule's lhs_
+ if (!okay[i]) {
+ cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; // second message for the same cycle: the callee already logged the offending node
+ }
+ }
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ if (okay[i]) u.push_back(edges[i]); // emitted only after every target was visited (post-order); cycle-closing rules are left out
+ //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl;
+ }
+ }
+ mark[node] = 2; // done
+ return true; // true even if some rules under node were dropped: only a direct hit on an in-progress node reports a cycle
+}
+
+void PassiveChart::TopoSortUnaries() { // Reorders unaries_: groups the rules by f()[0], runs TopoSortVisit over every group key, then stores the collected rules in reverse post-order.
+ vector<TRulePtr> u(unaries_.size()); u.clear(); // constructs unaries_.size() default elements and immediately empties u again; NOTE(review): presumably meant as pre-allocation, u.reserve() would state that directly
+ map<int, vector<TRulePtr> > g; // g[c] = rules whose f()[0] equals c
+ map<int, int> mark; // per-node visit state shared with TopoSortVisit: 0/absent = unseen, 1 = in progress, 2 = done
+ //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl;
+ mark[goal_cat_] = 2; // goal category is pre-marked as done, so it is never expanded and rules keyed under it are not emitted; NOTE(review): confirm this is intended
+ for (unsigned i = 0; i < unaries_.size(); ++i) {
+ //cerr << "Adding: " << unaries_[i]->AsString() << endl;
+ g[unaries_[i]->f()[0]].push_back(unaries_[i]);
+ }
+ //m[unaries_[i]->lhs_].push_back(unaries_[i]);
+ for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) {
+ //cerr << "PROC: " << TD::Convert(-it->first) << endl;
+ if (mark[it->first] > 0) { // operator[] may insert a 0 entry into mark; g itself is not modified during this iteration
+ //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n";
+ } else {
+ TopoSortVisit(it->first, u, g, mark); // return value ignored: cycle-closing rules are logged and dropped inside the visit
+ }
+ }
+ unaries_.clear();
+ for (int i = u.size() - 1; i >= 0; --i) // signed index so the countdown stops at -1, including when u is empty
+ unaries_.push_back(u[i]); // copy back in reverse of the order in which TopoSortVisit emitted the rules
+}
+
void PassiveChart::ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
const float lattice_cost) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
- //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
+ // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
new_edge->prev_i_ = r->prev_i;
new_edge->prev_j_ = r->prev_j;
new_edge->i_ = i;
@@ -215,15 +274,14 @@ void PassiveChart::ApplyRules(const int i,
void PassiveChart::ApplyUnaryRules(const int i, const int j) {
const vector<int>& nodes = chart_(i,j); // reference is important!
- for (unsigned gi = 0; gi < grammars_.size(); ++gi) {
- if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue;
- for (unsigned di = 0; di < nodes.size(); ++di) {
- const WordID& cat = forest_->nodes_[nodes[di]].cat_;
- const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat);
- for (unsigned ri = 0; ri < unaries.size(); ++ri) {
- // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl;
+ for (unsigned di = 0; di < nodes.size(); ++di) {
+ const WordID& cat = forest_->nodes_[nodes[di]].cat_;
+ for (unsigned ri = 0; ri < unaries_.size(); ++ri) {
+ //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl;
+ if (unaries_[ri]->f()[0] == cat) {
+ //cerr << " --MATCH\n";
const Hypergraph::TailNodeVector ant(1, nodes[di]);
- ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes
+ ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes
}
}
}
diff --git a/tests/system_tests/multigram/cdec.ini b/tests/system_tests/multigram/cdec.ini
new file mode 100644
index 00000000..f31becb8
--- /dev/null
+++ b/tests/system_tests/multigram/cdec.ini
@@ -0,0 +1 @@
+formalism=scfg
diff --git a/tests/system_tests/multigram/g1.scfg b/tests/system_tests/multigram/g1.scfg
new file mode 100644
index 00000000..a3a59699
--- /dev/null
+++ b/tests/system_tests/multigram/g1.scfg
@@ -0,0 +1,4 @@
+[X] ||| [Z] ||| [1] ||| Top=1
+[Y] ||| foo ||| foo ||| F1=1
+[Z] ||| [Z] [Y] ||| [1] [2] ||| W1=1
+[Z] ||| [Y] ||| [1] ||| W2=1
diff --git a/tests/system_tests/multigram/g2.scfg b/tests/system_tests/multigram/g2.scfg
new file mode 100644
index 00000000..40962517
--- /dev/null
+++ b/tests/system_tests/multigram/g2.scfg
@@ -0,0 +1 @@
+[Y] ||| bar ||| bar ||| F2=1
diff --git a/tests/system_tests/multigram/gold.statistics b/tests/system_tests/multigram/gold.statistics
new file mode 100644
index 00000000..ef23a685
--- /dev/null
+++ b/tests/system_tests/multigram/gold.statistics
@@ -0,0 +1,3 @@
+-lm_nodes 11
+-lm_edges 12
+-lm_paths 2
diff --git a/tests/system_tests/multigram/gold.stdout b/tests/system_tests/multigram/gold.stdout
new file mode 100644
index 00000000..d675fa44
--- /dev/null
+++ b/tests/system_tests/multigram/gold.stdout
@@ -0,0 +1 @@
+foo bar
diff --git a/tests/system_tests/multigram/input.txt b/tests/system_tests/multigram/input.txt
new file mode 100644
index 00000000..2aef01a0
--- /dev/null
+++ b/tests/system_tests/multigram/input.txt
@@ -0,0 +1 @@
+<seg id="0" grammar="g1.scfg" grammar1="g2.scfg">foo bar</seg>
diff --git a/tests/system_tests/multigram/weights b/tests/system_tests/multigram/weights
new file mode 100644
index 00000000..a6b6698a
--- /dev/null
+++ b/tests/system_tests/multigram/weights
@@ -0,0 +1 @@
+Glue -1