//TODO: when using many nonterminals, group passive edges for a span (treat all as a single X for the active items). #include "bottom_up_parser.h" #include #include #include "hg.h" #include "array2d.h" #include "tdict.h" using namespace std; struct ParserStats { ParserStats() : active_items(), passive_items() {} void Reset() { active_items=0; passive_items=0; } void Report() { cerr << " ACTIVE ITEMS: " << active_items << "\tPASSIVE ITEMS: " << passive_items << endl; } int active_items; int passive_items; void NotifyActive(int , int ) { ++active_items; } void NotifyPassive(int , int ) { ++passive_items; } }; ParserStats stats; class ActiveChart; class PassiveChart { public: PassiveChart(const string& goal, const vector& grammars, const Lattice& input, Hypergraph* forest); ~PassiveChart(); inline const vector& operator()(int i, int j) const { return chart_(i,j); } bool Parse(); inline int size() const { return chart_.width(); } inline bool GoalFound() const { return goal_idx_ >= 0; } inline int GetGoalIndex() const { return goal_idx_; } private: void ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, const float lattice_cost); void ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, const float lattice_cost); void ApplyUnaryRules(const int i, const int j); const vector& grammars_; const Lattice& input_; Hypergraph* forest_; Array2D > chart_; // chart_(i,j) is the list of nodes derived spanning i,j typedef map Cat2NodeMap; Array2D nodemap_; vector act_chart_; const WordID goal_cat_; // category that is being searched for at [0,n] TRulePtr goal_rule_; int goal_idx_; // index of goal node, if found const int lc_fid_; static WordID kGOAL; // [Goal] }; WordID PassiveChart::kGOAL = 0; class ActiveChart { public: ActiveChart(const Hypergraph* hg, const PassiveChart& psv_chart) : hg_(hg), act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {} struct ActiveItem { ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) : gptr_(g), ant_nodes_(a), lattice_cost(lcost) {} explicit ActiveItem(const GrammarIter* g) : gptr_(g), ant_nodes_(), lattice_cost(0.0) {} void ExtendTerminal(int symbol, float src_cost, vector* out_cell) const { const GrammarIter* ni = gptr_->Extend(symbol); if (ni) { stats.NotifyActive(-1,-1); // TRACKING STATS out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); } } void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector* out_cell) const { int symbol = hg->nodes_[node_index].cat_; const GrammarIter* ni = gptr_->Extend(symbol); if (!ni) return; stats.NotifyActive(-1,-1); // TRACKING STATS Hypergraph::TailNodeVector na(ant_nodes_.size() + 1); for (int i = 0; i < ant_nodes_.size(); ++i) na[i] = ant_nodes_[i]; na[ant_nodes_.size()] = node_index; out_cell->push_back(ActiveItem(ni, na, lattice_cost)); } const GrammarIter* gptr_; Hypergraph::TailNodeVector ant_nodes_; float lattice_cost; // TODO? use SparseVector }; inline const vector& operator()(int i, int j) const { return act_chart_(i,j); } void SeedActiveChart(const Grammar& g) { int size = act_chart_.width(); for (int i = 0; i < size; ++i) if (g.HasRuleForSpan(i,i,0)) act_chart_(i,i).push_back(ActiveItem(g.GetRoot())); } void ExtendActiveItems(int i, int k, int j) { //cerr << " LOOK(" << i << "," << k << ") for completed items in (" << k << "," << j << ")\n"; vector& cell = act_chart_(i,j); const vector& icell = act_chart_(i,k); const vector& idxs = psv_chart_(k, j); //if (!idxs.empty()) { cerr << "FOUND IN (" << k << "," << j << ")\n"; } for (vector::const_iterator di = icell.begin(); di != icell.end(); ++di) { for (vector::const_iterator ni = idxs.begin(); ni != idxs.end(); ++ni) { di->ExtendNonTerminal(hg_, *ni, &cell); } } } void AdvanceDotsForAllItemsInCell(int i, int j, const vector >& input) { //cerr << "ADVANCE(" << i << "," << j << ")\n"; for (int k=i+1; k < j; ++k) ExtendActiveItems(i, k, j); const vector& out_arcs = input[j-1]; for (vector::const_iterator ai = out_arcs.begin(); ai != out_arcs.end(); ++ai) { const WordID& f = ai->label; const double& c = ai->cost; const int& len = ai->dist2next; //VLOG(1) << "F: " << TD::Convert(f) << endl; const vector& ec = act_chart_(i, j-1); for (vector::const_iterator di = ec.begin(); di != ec.end(); ++di) di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1)); } } private: const Hypergraph* hg_; Array2D > act_chart_; const PassiveChart& psv_chart_; }; PassiveChart::PassiveChart(const string& goal, const vector& grammars, const Lattice& input, Hypergraph* forest) : grammars_(grammars), input_(input), forest_(forest), chart_(input.size()+1, input.size()+1), nodemap_(input.size()+1, input.size()+1), goal_cat_(TD::Convert(goal) * -1), goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), goal_idx_(-1), lc_fid_(FD::Convert("LatticeCost")) { act_chart_.resize(grammars_.size()); for (int i = 0; i < grammars_.size(); ++i) act_chart_[i] = new ActiveChart(forest, *this); if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; cerr << " Goal category: [" << goal << ']' << endl; } void PassiveChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, const float lattice_cost) { stats.NotifyPassive(i,j); // TRACKING STATS Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); new_edge->prev_i_ = r->prev_i; new_edge->prev_j_ = r->prev_j; new_edge->i_ = i; new_edge->j_ = j; new_edge->feature_values_ = r->GetFeatureValues(); if (lattice_cost) new_edge->feature_values_.set_value(lc_fid_, lattice_cost); Cat2NodeMap& c2n = nodemap_(i,j); const bool is_goal = (r->GetLHS() == kGOAL); const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); Hypergraph::Node* node = NULL; if (ni == c2n.end()) { node = forest_->AddNode(r->GetLHS()); c2n[r->GetLHS()] = node->id_; if (is_goal) { assert(goal_idx_ == -1); goal_idx_ = node->id_; } else { chart_(i,j).push_back(node->id_); } } else { node = &forest_->nodes_[ni->second]; } forest_->ConnectEdgeToHeadNode(new_edge, node); } void PassiveChart::ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, const float lattice_cost) { const int n = rules->GetNumRules(); for (int k = 0; k < n; ++k) ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost); } void PassiveChart::ApplyUnaryRules(const int i, const int j) { const vector& nodes = chart_(i,j); // reference is important! for (int gi = 0; gi < grammars_.size(); ++gi) { if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue; for (int di = 0; di < nodes.size(); ++di) { const WordID& cat = forest_->nodes_[nodes[di]].cat_; const vector& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); for (int ri = 0; ri < unaries.size(); ++ri) { // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; const Hypergraph::TailNodeVector ant(1, nodes[di]); ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes } } } } bool PassiveChart::Parse() { forest_->nodes_.reserve(input_.size() * input_.size() * 2); forest_->edges_.reserve(input_.size() * input_.size() * 1000); // TODO: reservation?? goal_idx_ = -1; for (int gi = 0; gi < grammars_.size(); ++gi) act_chart_[gi]->SeedActiveChart(*grammars_[gi]); cerr << " "; for (int l=1; lAdvanceDotsForAllItemsInCell(i, j, input_); const vector& cell = (*act_chart_[gi])(i,j); for (vector::const_iterator ai = cell.begin(); ai != cell.end(); ++ai) { const RuleBin* rules = (ai->gptr_->GetRules()); if (!rules) continue; ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost); } } } ApplyUnaryRules(i,j); for (int gi = 0; gi < grammars_.size(); ++gi) { const Grammar& g = *grammars_[gi]; // deal with non-terminals that were just proved if (g.HasRuleForSpan(i, j, input_.Distance(i,j))) act_chart_[gi]->ExtendActiveItems(i, i, j); } } const vector& dh = chart_(0, input_.size()); for (int di = 0; di < dh.size(); ++di) { const Hypergraph::Node& node = forest_->nodes_[dh[di]]; if (node.cat_ == goal_cat_) { Hypergraph::TailNodeVector ant(1, node.id_); ApplyRule(0, input_.size(), goal_rule_, ant, 0); } } } cerr << endl; if (GoalFound()) forest_->PruneUnreachable(forest_->nodes_.size() - 1); return GoalFound(); } PassiveChart::~PassiveChart() { for (int i = 0; i < act_chart_.size(); ++i) delete act_chart_[i]; } ExhaustiveBottomUpParser::ExhaustiveBottomUpParser( const string& goal_sym, const vector& grammars) : goal_sym_(goal_sym), grammars_(grammars) {} bool ExhaustiveBottomUpParser::Parse(const Lattice& input, Hypergraph* forest) const { stats.Reset(); PassiveChart chart(goal_sym_, grammars_, input, forest); const bool result = chart.Parse(); stats.Report(); return result; }