From 448b451aa481b1509566ddb11abc3476466def6a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 30 Nov 2014 21:30:09 -0500 Subject: implementation of Rico Sennrich's CKY+ variant; it currently doesn't support span limits so it is not enabled, but it seems to be functional. --- decoder/bottom_up_parser-rs.cc | 341 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 341 insertions(+) create mode 100644 decoder/bottom_up_parser-rs.cc (limited to 'decoder/bottom_up_parser-rs.cc') diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc new file mode 100644 index 00000000..fbde7e24 --- /dev/null +++ b/decoder/bottom_up_parser-rs.cc @@ -0,0 +1,341 @@ +#include "bottom_up_parser-rs.h" + +#include +#include + +#include "node_state_hash.h" +#include "nt_span.h" +#include "hg.h" +#include "array2d.h" +#include "tdict.h" +#include "verbose.h" + +using namespace std; + +static WordID kEPS = 0; + +struct RSActiveItem; +class RSChart { + public: + RSChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest); + ~RSChart(); + + void AddToChart(const RSActiveItem& x, int i, int j); + void ConsumeTerminal(const RSActiveItem& x, int i, int j, int k); + void ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k); + bool Parse(); + inline bool GoalFound() const { return goal_idx_ >= 0; } + inline int GetGoalIndex() const { return goal_idx_; } + + private: + void ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const float lattice_cost); + + // returns true if a new node was added to the chart + // false otherwise + bool ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost); + + void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx); + void TopoSortUnaries(); + + const vector& grammars_; + const Lattice& input_; + Hypergraph* forest_; + Array2D> chart_; // chart_(i,j) is the list of nodes (represented + // by their index in forest_->nodes_) derived spanning i,j + typedef map Cat2NodeMap; + Array2D nodemap_; + const WordID goal_cat_; // category that is being searched for at [0,n] + TRulePtr goal_rule_; + int goal_idx_; // index of goal node, if found + const int lc_fid_; + vector unaries_; // topologically sorted list of unary rules from all grammars + + static WordID kGOAL; // [Goal] +}; + +WordID RSChart::kGOAL = 0; + +// "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span" +struct RSActiveItem { + explicit RSActiveItem(const GrammarIter* g, int i) : + gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {} + void ExtendTerminal(int symbol, float src_cost) { + lattice_cost += src_cost; + if (symbol != kEPS) + gptr_ = gptr_->Extend(symbol); + } + void ExtendNonTerminal(const Hypergraph* hg, int node_index) { + gptr_ = gptr_->Extend(hg->nodes_[node_index].cat_); + ant_nodes_.push_back(node_index); + } + // returns false if the extension has failed + explicit operator bool() const { + return gptr_; + } + const GrammarIter* gptr_; + Hypergraph::TailNodeVector ant_nodes_; + float lattice_cost; // TODO: use SparseVector to encode input features + short i_; +}; + +// some notes on the implementation +// "X" in Rico's Algorithm 2 roughly looks like it is just a pointer into a grammar +// trie, but it is actually a full "dotted item" since it needs to contain the information +// to build the hypergraph (i.e., it must remember the antecedent nodes and where they are, +// also any information about the path costs). + +RSChart::RSChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest) : + grammars_(grammars), + input_(input), + forest_(forest), + chart_(input.size()+1, input.size()+1), + nodemap_(input.size()+1, input.size()+1), + goal_cat_(TD::Convert(goal) * -1), + goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), + goal_idx_(-1), + lc_fid_(FD::Convert("LatticeCost")), + unaries_() { + for (unsigned i = 0; i < grammars_.size(); ++i) { + const vector& u = grammars_[i]->GetAllUnaryRules(); + for (unsigned j = 0; j < u.size(); ++j) + unaries_.push_back(u[j]); + } + TopoSortUnaries(); + if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; + if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl; +} + +static bool TopoSortVisit(int node, vector& u, const map >& g, map& mark) { + if (mark[node] == 1) { + cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; + return false; // cycle detected + } else if (mark[node] == 2) { + return true; // already been + } + mark[node] = 1; + const map >::const_iterator nit = g.find(node); + if (nit != g.end()) { + const vector& edges = nit->second; + vector okay(edges.size(), true); + for (unsigned i = 0; i < edges.size(); ++i) { + okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); + if (!okay[i]) { + cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; + } + } + for (unsigned i = 0; i < edges.size(); ++i) { + if (okay[i]) u.push_back(edges[i]); + //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; + } + } + mark[node] = 2; + return true; +} + +void RSChart::TopoSortUnaries() { + vector u(unaries_.size()); u.clear(); + map > g; + map mark; + //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; + mark[goal_cat_] = 2; + for (unsigned i = 0; i < unaries_.size(); ++i) { + //cerr << "Adding: " << unaries_[i]->AsString() << endl; + g[unaries_[i]->f()[0]].push_back(unaries_[i]); + } + //m[unaries_[i]->lhs_].push_back(unaries_[i]); + for (map >::iterator it = g.begin(); it != g.end(); ++it) { + //cerr << "PROC: " << TD::Convert(-it->first) << endl; + if (mark[it->first] > 0) { + //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; + } else { + TopoSortVisit(it->first, u, g, mark); + } + } + unaries_.clear(); + for (int i = u.size() - 1; i >= 0; --i) + unaries_.push_back(u[i]); +} + +bool RSChart::ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost) { + Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); + //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; + new_edge->prev_i_ = r->prev_i; + new_edge->prev_j_ = r->prev_j; + new_edge->i_ = i; + new_edge->j_ = j; + new_edge->feature_values_ = r->GetFeatureValues(); + if (lattice_cost && lc_fid_) + new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + Cat2NodeMap& c2n = nodemap_(i,j); + const bool is_goal = (r->GetLHS() == kGOAL); + const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); + Hypergraph::Node* node = NULL; + bool added_node = false; + if (ni == c2n.end()) { + //cerr << "(" << i << "," << j << ") => " << TD::Convert(-r->GetLHS()) << endl; + added_node = true; + node = forest_->AddNode(r->GetLHS()); + c2n[r->GetLHS()] = node->id_; + if (is_goal) { + assert(goal_idx_ == -1); + goal_idx_ = node->id_; + } else { + chart_(i,j).push_back(node->id_); + } + } else { + node = &forest_->nodes_[ni->second]; + } + forest_->ConnectEdgeToHeadNode(new_edge, node); + return added_node; +} + +void RSChart::ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const float lattice_cost) { + const int n = rules->GetNumRules(); + //cerr << i << " " << j << ": NUM RULES: " << n << endl; + for (int k = 0; k < n; ++k) { + //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; + TRulePtr rule = rules->GetIthRule(k); + // apply rule, and if we create a new node, apply any necessary + // unary rules + if (ApplyRule(i, j, rule, tail, lattice_cost)) { + unsigned nodeidx = nodemap_(i,j)[rule->lhs_]; + ApplyUnaryRules(i, j, rule->lhs_, nodeidx); + } + } +} + +void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx) { + for (unsigned ri = 0; ri < unaries_.size(); ++ri) { + //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; + if (unaries_[ri]->f()[0] == cat) { + //cerr << " --MATCH\n"; + WordID new_lhs = unaries_[ri]->GetLHS(); + const Hypergraph::TailNodeVector ant(1, nodeidx); + if (ApplyRule(i, j, unaries_[ri], ant, 0)) { + //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl; + unsigned nodeidx = nodemap_(i,j)[new_lhs]; + ApplyUnaryRules(i, j, new_lhs, nodeidx); + } + } + } +} + +void RSChart::AddToChart(const RSActiveItem& x, int i, int j) { + // deal with completed rules + const RuleBin* rb = x.gptr_->GetRules(); + if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost); + + //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n"; + // continue looking for extensions of the rule to the right + for (unsigned k = j+1; k <= input_.size(); ++k) { + ConsumeTerminal(x, i, j, k); + ConsumeNonTerminal(x, i, j, k); + } +} + +void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) { + //cerr << "ConsumeT(" << i << "," << j << "," << k << "):\n"; + + const unsigned check_edge_len = k - j; + // long-term TODO preindex this search so i->len->words is constant time rather than fan out + for (auto& in_edge : input_[j]) { + if (in_edge.dist2next == check_edge_len) { + //cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl; + RSActiveItem copy = x; + copy.ExtendTerminal(in_edge.label, in_edge.cost); + if (copy) AddToChart(copy, i, k); + } + } +} + +void RSChart::ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k) { + //cerr << "ConsumeNT(" << i << "," << j << "," << k << "):\n"; + for (auto& nodeidx : chart_(j,k)) { + //cerr << " Found completed NT in (" << j << "," << k << ") of type " << TD::Convert(-forest_->nodes_[nodeidx].cat_) << endl; + RSActiveItem copy = x; + copy.ExtendNonTerminal(forest_, nodeidx); + if (copy) AddToChart(copy, i, k); + } +} + +bool RSChart::Parse() { + size_t in_size_2 = input_.size() * input_.size(); + forest_->nodes_.reserve(in_size_2 * 2); + size_t res = min(static_cast(2000000), static_cast(in_size_2 * 1000)); + forest_->edges_.reserve(res); + goal_idx_ = -1; + const int N = input_.size(); + for (int i = N - 1; i >= 0; --i) { + for (int j = i + 1; j <= N; ++j) { + for (unsigned gi = 0; gi < grammars_.size(); ++gi) { + RSActiveItem item(grammars_[gi]->GetRoot(), i); + ConsumeTerminal(item, i, i, j); + } + for (unsigned gi = 0; gi < grammars_.size(); ++gi) { + RSActiveItem item(grammars_[gi]->GetRoot(), i); + ConsumeNonTerminal(item, i, i, j); + } + } + } + + // look for goal + const vector& dh = chart_(0, input_.size()); + for (unsigned di = 0; di < dh.size(); ++di) { + const Hypergraph::Node& node = forest_->nodes_[dh[di]]; + if (node.cat_ == goal_cat_) { + Hypergraph::TailNodeVector ant(1, node.id_); + ApplyRule(0, input_.size(), goal_rule_, ant, 0); + } + } + if (!SILENT) cerr << endl; + + if (GoalFound()) + forest_->PruneUnreachable(forest_->nodes_.size() - 1); + return GoalFound(); +} + +RSChart::~RSChart() {} + +RSExhaustiveBottomUpParser::RSExhaustiveBottomUpParser( + const string& goal_sym, + const vector& grammars) : + goal_sym_(goal_sym), + grammars_(grammars) {} + +bool RSExhaustiveBottomUpParser::Parse(const Lattice& input, + Hypergraph* forest) const { + kEPS = TD::Convert("*EPS*"); + RSChart chart(goal_sym_, grammars_, input, forest); + const bool result = chart.Parse(); + + if (result) { + for (auto& node : forest->nodes_) { + Span prev; + const Span s = forest->NodeSpan(node.id_, &prev); + node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r); + } + } + return result; +} -- cgit v1.2.3