From dc6930c00b4b276883280cff1ed6dcd9ddef03c7 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 8 Dec 2009 21:38:55 -0500 Subject: LICENSE fixes, full support of lattice decoding --- src/bottom_up_parser.cc | 53 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 17 deletions(-) (limited to 'src/bottom_up_parser.cc') diff --git a/src/bottom_up_parser.cc b/src/bottom_up_parser.cc index 349ed2de..b3315b8a 100644 --- a/src/bottom_up_parser.cc +++ b/src/bottom_up_parser.cc @@ -24,8 +24,18 @@ class PassiveChart { inline int GetGoalIndex() const { return goal_idx_; } private: - void ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail); - void ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes); + void ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const float lattice_cost); + + void ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost); + void ApplyUnaryRules(const int i, const int j); const vector& grammars_; @@ -38,6 +48,7 @@ class PassiveChart { const WordID goal_cat_; // category that is being searched for at [0,n] TRulePtr goal_rule_; int goal_idx_; // index of goal node, if found + const int lc_fid_; static WordID kGOAL; // [Goal] }; @@ -51,12 +62,12 @@ class ActiveChart { act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {} struct ActiveItem { - ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, double lcost) : + ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) : gptr_(g), ant_nodes_(a), lattice_cost(lcost) {} explicit ActiveItem(const GrammarIter* g) : - gptr_(g), ant_nodes_(), lattice_cost() {} + gptr_(g), ant_nodes_(), lattice_cost(0.0) {} - void ExtendTerminal(int symbol, double src_cost, vector* out_cell) const { + void ExtendTerminal(int symbol, float src_cost, vector* out_cell) const { const GrammarIter* ni = gptr_->Extend(symbol); if (ni) out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); } @@ -73,14 +84,14 @@ class ActiveChart { const GrammarIter* gptr_; Hypergraph::TailNodeVector ant_nodes_; - double lattice_cost; // TODO? use SparseVector + float lattice_cost; // TODO? use SparseVector }; inline const vector& operator()(int i, int j) const { return act_chart_(i,j); } void SeedActiveChart(const Grammar& g) { int size = act_chart_.width(); for (int i = 0; i < size; ++i) - if (g.HasRuleForSpan(i,i)) + if (g.HasRuleForSpan(i,i,0)) act_chart_(i,i).push_back(ActiveItem(g.GetRoot())); } @@ -132,7 +143,8 @@ PassiveChart::PassiveChart(const string& goal, nodemap_(input.size()+1, input.size()+1), goal_cat_(TD::Convert(goal) * -1), goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), - goal_idx_(-1) { + goal_idx_(-1), + lc_fid_(FD::Convert("LatticeCost")) { act_chart_.resize(grammars_.size()); for (int i = 0; i < grammars_.size(); ++i) act_chart_[i] = new ActiveChart(forest, *this); @@ -140,13 +152,19 @@ PassiveChart::PassiveChart(const string& goal, cerr << " Goal category: [" << goal << ']' << endl; } -void PassiveChart::ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes) { +void PassiveChart::ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); new_edge->prev_i_ = r->prev_i; new_edge->prev_j_ = r->prev_j; new_edge->i_ = i; new_edge->j_ = j; new_edge->feature_values_ = r->GetFeatureValues(); + if (lattice_cost) + new_edge->feature_values_.set_value(lc_fid_, lattice_cost); Cat2NodeMap& c2n = nodemap_(i,j); const bool is_goal = (r->GetLHS() == kGOAL); const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); @@ -169,23 +187,24 @@ void PassiveChart::ApplyRule(const int i, const int j, TRulePtr r, const Hypergr void PassiveChart::ApplyRules(const int i, const int j, const RuleBin* rules, - const Hypergraph::TailNodeVector& tail) { + const Hypergraph::TailNodeVector& tail, + const float lattice_cost) { const int n = rules->GetNumRules(); for (int k = 0; k < n; ++k) - ApplyRule(i, j, rules->GetIthRule(k), tail); + ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost); } void PassiveChart::ApplyUnaryRules(const int i, const int j) { const vector& nodes = chart_(i,j); // reference is important! for (int gi = 0; gi < grammars_.size(); ++gi) { - if (!grammars_[gi]->HasRuleForSpan(i,j)) continue; + if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue; for (int di = 0; di < nodes.size(); ++di) { const WordID& cat = forest_->nodes_[nodes[di]].cat_; const vector& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); for (int ri = 0; ri < unaries.size(); ++ri) { // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; const Hypergraph::TailNodeVector ant(1, nodes[di]); - ApplyRule(i, j, unaries[ri], ant); // may update nodes + ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes } } } @@ -205,7 +224,7 @@ bool PassiveChart::Parse() { int j = i + l; for (int gi = 0; gi < grammars_.size(); ++gi) { const Grammar& g = *grammars_[gi]; - if (g.HasRuleForSpan(i, j)) { + if (g.HasRuleForSpan(i, j, input_.Distance(i, j))) { act_chart_[gi]->AdvanceDotsForAllItemsInCell(i, j, input_); const vector& cell = (*act_chart_[gi])(i,j); @@ -213,7 +232,7 @@ bool PassiveChart::Parse() { ai != cell.end(); ++ai) { const RuleBin* rules = (ai->gptr_->GetRules()); if (!rules) continue; - ApplyRules(i, j, rules, ai->ant_nodes_); + ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost); } } } @@ -222,7 +241,7 @@ bool PassiveChart::Parse() { for (int gi = 0; gi < grammars_.size(); ++gi) { const Grammar& g = *grammars_[gi]; // deal with non-terminals that were just proved - if (g.HasRuleForSpan(i, j)) + if (g.HasRuleForSpan(i, j, input_.Distance(i,j))) act_chart_[gi]->ExtendActiveItems(i, i, j); } } @@ -231,7 +250,7 @@ bool PassiveChart::Parse() { const Hypergraph::Node& node = forest_->nodes_[dh[di]]; if (node.cat_ == goal_cat_) { Hypergraph::TailNodeVector ant(1, node.id_); - ApplyRule(0, input_.size(), goal_rule_, ant); + ApplyRule(0, input_.size(), goal_rule_, ant, 0); } } } -- cgit v1.2.3