diff options
Diffstat (limited to 'decoder/bottom_up_parser.cc')
-rw-r--r-- | decoder/bottom_up_parser.cc | 279 |
1 files changed, 279 insertions, 0 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc new file mode 100644 index 00000000..b3315b8a --- /dev/null +++ b/decoder/bottom_up_parser.cc @@ -0,0 +1,279 @@ +#include "bottom_up_parser.h" + +#include <map> + +#include "hg.h" +#include "array2d.h" +#include "tdict.h" + +using namespace std; + +class ActiveChart; +class PassiveChart { + public: + PassiveChart(const string& goal, + const vector<GrammarPtr>& grammars, + const Lattice& input, + Hypergraph* forest); + ~PassiveChart(); + + inline const vector<int>& operator()(int i, int j) const { return chart_(i,j); } + bool Parse(); + inline int size() const { return chart_.width(); } + inline bool GoalFound() const { return goal_idx_ >= 0; } + inline int GetGoalIndex() const { return goal_idx_; } + + private: + void ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const float lattice_cost); + + void ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost); + + void ApplyUnaryRules(const int i, const int j); + + const vector<GrammarPtr>& grammars_; + const Lattice& input_; + Hypergraph* forest_; + Array2D<vector<int> > chart_; // chart_(i,j) is the list of nodes derived spanning i,j + typedef map<int, int> Cat2NodeMap; + Array2D<Cat2NodeMap> nodemap_; + vector<ActiveChart*> act_chart_; + const WordID goal_cat_; // category that is being searched for at [0,n] + TRulePtr goal_rule_; + int goal_idx_; // index of goal node, if found + const int lc_fid_; + + static WordID kGOAL; // [Goal] +}; + +WordID PassiveChart::kGOAL = 0; + +class ActiveChart { + public: + ActiveChart(const Hypergraph* hg, const PassiveChart& psv_chart) : + hg_(hg), + act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {} + + struct ActiveItem { + ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) : + gptr_(g), ant_nodes_(a), lattice_cost(lcost) {} + explicit ActiveItem(const GrammarIter* g) : + gptr_(g), ant_nodes_(), lattice_cost(0.0) {} + + void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const { + const GrammarIter* ni = gptr_->Extend(symbol); + if (ni) out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); + } + void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const { + int symbol = hg->nodes_[node_index].cat_; + const GrammarIter* ni = gptr_->Extend(symbol); + if (!ni) return; + Hypergraph::TailNodeVector na(ant_nodes_.size() + 1); + for (int i = 0; i < ant_nodes_.size(); ++i) + na[i] = ant_nodes_[i]; + na[ant_nodes_.size()] = node_index; + out_cell->push_back(ActiveItem(ni, na, lattice_cost)); + } + + const GrammarIter* gptr_; + Hypergraph::TailNodeVector ant_nodes_; + float lattice_cost; // TODO? use SparseVector<double> + }; + + inline const vector<ActiveItem>& operator()(int i, int j) const { return act_chart_(i,j); } + void SeedActiveChart(const Grammar& g) { + int size = act_chart_.width(); + for (int i = 0; i < size; ++i) + if (g.HasRuleForSpan(i,i,0)) + act_chart_(i,i).push_back(ActiveItem(g.GetRoot())); + } + + void ExtendActiveItems(int i, int k, int j) { + //cerr << " LOOK(" << i << "," << k << ") for completed items in (" << k << "," << j << ")\n"; + vector<ActiveItem>& cell = act_chart_(i,j); + const vector<ActiveItem>& icell = act_chart_(i,k); + const vector<int>& idxs = psv_chart_(k, j); + //if (!idxs.empty()) { cerr << "FOUND IN (" << k << "," << j << ")\n"; } + for (vector<ActiveItem>::const_iterator di = icell.begin(); di != icell.end(); ++di) { + for (vector<int>::const_iterator ni = idxs.begin(); ni != idxs.end(); ++ni) { + di->ExtendNonTerminal(hg_, *ni, &cell); + } + } + } + + void AdvanceDotsForAllItemsInCell(int i, int j, const vector<vector<LatticeArc> >& input) { + //cerr << "ADVANCE(" << i << "," << j << ")\n"; + for (int k=i+1; k < j; ++k) + ExtendActiveItems(i, k, j); + + const vector<LatticeArc>& out_arcs = input[j-1]; + for (vector<LatticeArc>::const_iterator ai = out_arcs.begin(); + ai != out_arcs.end(); ++ai) { + const WordID& f = ai->label; + const double& c = ai->cost; + const int& len = ai->dist2next; + //VLOG(1) << "F: " << TD::Convert(f) << endl; + const vector<ActiveItem>& ec = act_chart_(i, j-1); + for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di) + di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1)); + } + } + + private: + const Hypergraph* hg_; + Array2D<vector<ActiveItem> > act_chart_; + const PassiveChart& psv_chart_; +}; + +PassiveChart::PassiveChart(const string& goal, + const vector<GrammarPtr>& grammars, + const Lattice& input, + Hypergraph* forest) : + grammars_(grammars), + input_(input), + forest_(forest), + chart_(input.size()+1, input.size()+1), + nodemap_(input.size()+1, input.size()+1), + goal_cat_(TD::Convert(goal) * -1), + goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), + goal_idx_(-1), + lc_fid_(FD::Convert("LatticeCost")) { + act_chart_.resize(grammars_.size()); + for (int i = 0; i < grammars_.size(); ++i) + act_chart_[i] = new ActiveChart(forest, *this); + if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; + cerr << " Goal category: [" << goal << ']' << endl; +} + +void PassiveChart::ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost) { + Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); + new_edge->prev_i_ = r->prev_i; + new_edge->prev_j_ = r->prev_j; + new_edge->i_ = i; + new_edge->j_ = j; + new_edge->feature_values_ = r->GetFeatureValues(); + if (lattice_cost) + new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + Cat2NodeMap& c2n = nodemap_(i,j); + const bool is_goal = (r->GetLHS() == kGOAL); + const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); + Hypergraph::Node* node = NULL; + if (ni == c2n.end()) { + node = forest_->AddNode(r->GetLHS(), ""); + c2n[r->GetLHS()] = node->id_; + if (is_goal) { + assert(goal_idx_ == -1); + goal_idx_ = node->id_; + } else { + chart_(i,j).push_back(node->id_); + } + } else { + node = &forest_->nodes_[ni->second]; + } + forest_->ConnectEdgeToHeadNode(new_edge, node); +} + +void PassiveChart::ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const float lattice_cost) { + const int n = rules->GetNumRules(); + for (int k = 0; k < n; ++k) + ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost); +} + +void PassiveChart::ApplyUnaryRules(const int i, const int j) { + const vector<int>& nodes = chart_(i,j); // reference is important! + for (int gi = 0; gi < grammars_.size(); ++gi) { + if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue; + for (int di = 0; di < nodes.size(); ++di) { + const WordID& cat = forest_->nodes_[nodes[di]].cat_; + const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); + for (int ri = 0; ri < unaries.size(); ++ri) { + // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; + const Hypergraph::TailNodeVector ant(1, nodes[di]); + ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes + } + } + } +} + +bool PassiveChart::Parse() { + forest_->nodes_.reserve(input_.size() * input_.size() * 2); + forest_->edges_.reserve(input_.size() * input_.size() * 1000); // TODO: reservation?? + goal_idx_ = -1; + for (int gi = 0; gi < grammars_.size(); ++gi) + act_chart_[gi]->SeedActiveChart(*grammars_[gi]); + + cerr << " "; + for (int l=1; l<input_.size()+1; ++l) { + cerr << '.'; + for (int i=0; i<input_.size() + 1 - l; ++i) { + int j = i + l; + for (int gi = 0; gi < grammars_.size(); ++gi) { + const Grammar& g = *grammars_[gi]; + if (g.HasRuleForSpan(i, j, input_.Distance(i, j))) { + act_chart_[gi]->AdvanceDotsForAllItemsInCell(i, j, input_); + + const vector<ActiveChart::ActiveItem>& cell = (*act_chart_[gi])(i,j); + for (vector<ActiveChart::ActiveItem>::const_iterator ai = cell.begin(); + ai != cell.end(); ++ai) { + const RuleBin* rules = (ai->gptr_->GetRules()); + if (!rules) continue; + ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost); + } + } + } + ApplyUnaryRules(i,j); + + for (int gi = 0; gi < grammars_.size(); ++gi) { + const Grammar& g = *grammars_[gi]; + // deal with non-terminals that were just proved + if (g.HasRuleForSpan(i, j, input_.Distance(i,j))) + act_chart_[gi]->ExtendActiveItems(i, i, j); + } + } + const vector<int>& dh = chart_(0, input_.size()); + for (int di = 0; di < dh.size(); ++di) { + const Hypergraph::Node& node = forest_->nodes_[dh[di]]; + if (node.cat_ == goal_cat_) { + Hypergraph::TailNodeVector ant(1, node.id_); + ApplyRule(0, input_.size(), goal_rule_, ant, 0); + } + } + } + cerr << endl; + + if (GoalFound()) + forest_->PruneUnreachable(forest_->nodes_.size() - 1); + return GoalFound(); +} + +PassiveChart::~PassiveChart() { + for (int i = 0; i < act_chart_.size(); ++i) + delete act_chart_[i]; +} + +ExhaustiveBottomUpParser::ExhaustiveBottomUpParser( + const string& goal_sym, + const vector<GrammarPtr>& grammars) : + goal_sym_(goal_sym), + grammars_(grammars) {} + +bool ExhaustiveBottomUpParser::Parse(const Lattice& input, + Hypergraph* forest) const { + PassiveChart chart(goal_sym_, grammars_, input, forest); + return chart.Parse(); +} |