summaryrefslogtreecommitdiff
path: root/src/bottom_up_parser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/bottom_up_parser.cc')
-rw-r--r--src/bottom_up_parser.cc53
1 files changed, 36 insertions, 17 deletions
diff --git a/src/bottom_up_parser.cc b/src/bottom_up_parser.cc
index 349ed2de..b3315b8a 100644
--- a/src/bottom_up_parser.cc
+++ b/src/bottom_up_parser.cc
@@ -24,8 +24,18 @@ class PassiveChart {
inline int GetGoalIndex() const { return goal_idx_; }
private:
- void ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail);
- void ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes);
+ void ApplyRules(const int i,
+ const int j,
+ const RuleBin* rules,
+ const Hypergraph::TailNodeVector& tail,
+ const float lattice_cost);
+
+ void ApplyRule(const int i,
+ const int j,
+ const TRulePtr& r,
+ const Hypergraph::TailNodeVector& ant_nodes,
+ const float lattice_cost);
+
void ApplyUnaryRules(const int i, const int j);
const vector<GrammarPtr>& grammars_;
@@ -38,6 +48,7 @@ class PassiveChart {
const WordID goal_cat_; // category that is being searched for at [0,n]
TRulePtr goal_rule_;
int goal_idx_; // index of goal node, if found
+ const int lc_fid_;
static WordID kGOAL; // [Goal]
};
@@ -51,12 +62,12 @@ class ActiveChart {
act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {}
struct ActiveItem {
- ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, double lcost) :
+ ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) :
gptr_(g), ant_nodes_(a), lattice_cost(lcost) {}
explicit ActiveItem(const GrammarIter* g) :
- gptr_(g), ant_nodes_(), lattice_cost() {}
+ gptr_(g), ant_nodes_(), lattice_cost(0.0) {}
- void ExtendTerminal(int symbol, double src_cost, vector<ActiveItem>* out_cell) const {
+ void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const {
const GrammarIter* ni = gptr_->Extend(symbol);
if (ni) out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost));
}
@@ -73,14 +84,14 @@ class ActiveChart {
const GrammarIter* gptr_;
Hypergraph::TailNodeVector ant_nodes_;
- double lattice_cost; // TODO? use SparseVector<double>
+ float lattice_cost; // TODO? use SparseVector<double>
};
inline const vector<ActiveItem>& operator()(int i, int j) const { return act_chart_(i,j); }
void SeedActiveChart(const Grammar& g) {
int size = act_chart_.width();
for (int i = 0; i < size; ++i)
- if (g.HasRuleForSpan(i,i))
+ if (g.HasRuleForSpan(i,i,0))
act_chart_(i,i).push_back(ActiveItem(g.GetRoot()));
}
@@ -132,7 +143,8 @@ PassiveChart::PassiveChart(const string& goal,
nodemap_(input.size()+1, input.size()+1),
goal_cat_(TD::Convert(goal) * -1),
goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")),
- goal_idx_(-1) {
+ goal_idx_(-1),
+ lc_fid_(FD::Convert("LatticeCost")) {
act_chart_.resize(grammars_.size());
for (int i = 0; i < grammars_.size(); ++i)
act_chart_[i] = new ActiveChart(forest, *this);
@@ -140,13 +152,19 @@ PassiveChart::PassiveChart(const string& goal,
cerr << " Goal category: [" << goal << ']' << endl;
}
-void PassiveChart::ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes) {
+void PassiveChart::ApplyRule(const int i,
+ const int j,
+ const TRulePtr& r,
+ const Hypergraph::TailNodeVector& ant_nodes,
+ const float lattice_cost) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
new_edge->prev_i_ = r->prev_i;
new_edge->prev_j_ = r->prev_j;
new_edge->i_ = i;
new_edge->j_ = j;
new_edge->feature_values_ = r->GetFeatureValues();
+ if (lattice_cost)
+ new_edge->feature_values_.set_value(lc_fid_, lattice_cost);
Cat2NodeMap& c2n = nodemap_(i,j);
const bool is_goal = (r->GetLHS() == kGOAL);
const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
@@ -169,23 +187,24 @@ void PassiveChart::ApplyRule(const int i, const int j, TRulePtr r, const Hypergr
void PassiveChart::ApplyRules(const int i,
const int j,
const RuleBin* rules,
- const Hypergraph::TailNodeVector& tail) {
+ const Hypergraph::TailNodeVector& tail,
+ const float lattice_cost) {
const int n = rules->GetNumRules();
for (int k = 0; k < n; ++k)
- ApplyRule(i, j, rules->GetIthRule(k), tail);
+ ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost);
}
void PassiveChart::ApplyUnaryRules(const int i, const int j) {
const vector<int>& nodes = chart_(i,j); // reference is important!
for (int gi = 0; gi < grammars_.size(); ++gi) {
- if (!grammars_[gi]->HasRuleForSpan(i,j)) continue;
+ if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue;
for (int di = 0; di < nodes.size(); ++di) {
const WordID& cat = forest_->nodes_[nodes[di]].cat_;
const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat);
for (int ri = 0; ri < unaries.size(); ++ri) {
// cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl;
const Hypergraph::TailNodeVector ant(1, nodes[di]);
- ApplyRule(i, j, unaries[ri], ant); // may update nodes
+ ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes
}
}
}
@@ -205,7 +224,7 @@ bool PassiveChart::Parse() {
int j = i + l;
for (int gi = 0; gi < grammars_.size(); ++gi) {
const Grammar& g = *grammars_[gi];
- if (g.HasRuleForSpan(i, j)) {
+ if (g.HasRuleForSpan(i, j, input_.Distance(i, j))) {
act_chart_[gi]->AdvanceDotsForAllItemsInCell(i, j, input_);
const vector<ActiveChart::ActiveItem>& cell = (*act_chart_[gi])(i,j);
@@ -213,7 +232,7 @@ bool PassiveChart::Parse() {
ai != cell.end(); ++ai) {
const RuleBin* rules = (ai->gptr_->GetRules());
if (!rules) continue;
- ApplyRules(i, j, rules, ai->ant_nodes_);
+ ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost);
}
}
}
@@ -222,7 +241,7 @@ bool PassiveChart::Parse() {
for (int gi = 0; gi < grammars_.size(); ++gi) {
const Grammar& g = *grammars_[gi];
// deal with non-terminals that were just proved
- if (g.HasRuleForSpan(i, j))
+ if (g.HasRuleForSpan(i, j, input_.Distance(i,j)))
act_chart_[gi]->ExtendActiveItems(i, i, j);
}
}
@@ -231,7 +250,7 @@ bool PassiveChart::Parse() {
const Hypergraph::Node& node = forest_->nodes_[dh[di]];
if (node.cat_ == goal_cat_) {
Hypergraph::TailNodeVector ant(1, node.id_);
- ApplyRule(0, input_.size(), goal_rule_, ant);
+ ApplyRule(0, input_.size(), goal_rule_, ant, 0);
}
}
}