From c7b2a39958912d7b85a384a871609e6db73042c7 Mon Sep 17 00:00:00 2001 From: CHRISTOPHER DYER Date: Tue, 3 Feb 2015 21:24:07 -0500 Subject: support multiple sparse features on lattice edges --- decoder/bottom_up_parser-rs.cc | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) (limited to 'decoder/bottom_up_parser-rs.cc') diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc index fbde7e24..863d7e2f 100644 --- a/decoder/bottom_up_parser-rs.cc +++ b/decoder/bottom_up_parser-rs.cc @@ -35,7 +35,7 @@ class RSChart { const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost); + const SparseVector& lattice_feats); // returns true if a new node was added to the chart // false otherwise @@ -43,7 +43,7 @@ class RSChart { const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost); + const SparseVector& lattice_feats); void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx); void TopoSortUnaries(); @@ -69,9 +69,9 @@ WordID RSChart::kGOAL = 0; // "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span" struct RSActiveItem { explicit RSActiveItem(const GrammarIter* g, int i) : - gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {} - void ExtendTerminal(int symbol, float src_cost) { - lattice_cost += src_cost; + gptr_(g), ant_nodes_(), lattice_feats(), i_(i) {} + void ExtendTerminal(int symbol, const SparseVector& src_feats) { + lattice_feats += src_feats; if (symbol != kEPS) gptr_ = gptr_->Extend(symbol); } @@ -85,7 +85,7 @@ struct RSActiveItem { } const GrammarIter* gptr_; Hypergraph::TailNodeVector ant_nodes_; - float lattice_cost; // TODO: use SparseVector to encode input features + SparseVector lattice_feats; short i_; }; @@ -174,7 +174,7 @@ bool RSChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost) { + const SparseVector& lattice_feats) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; @@ -182,8 +182,7 @@ bool RSChart::ApplyRule(const int i, new_edge->i_ = i; new_edge->j_ = j; new_edge->feature_values_ = r->GetFeatureValues(); - if (lattice_cost && lc_fid_) - new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + new_edge->feature_values_ += lattice_feats; Cat2NodeMap& c2n = nodemap_(i,j); const bool is_goal = (r->GetLHS() == kGOAL); const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); @@ -211,7 +210,7 @@ void RSChart::ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost) { + const SparseVector& lattice_feats) { const int n = rules->GetNumRules(); //cerr << i << " " << j << ": NUM RULES: " << n << endl; for (int k = 0; k < n; ++k) { @@ -219,7 +218,7 @@ void RSChart::ApplyRules(const int i, TRulePtr rule = rules->GetIthRule(k); // apply rule, and if we create a new node, apply any necessary // unary rules - if (ApplyRule(i, j, rule, tail, lattice_cost)) { + if (ApplyRule(i, j, rule, tail, lattice_feats)) { unsigned nodeidx = nodemap_(i,j)[rule->lhs_]; ApplyUnaryRules(i, j, rule->lhs_, nodeidx); } @@ -233,7 +232,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig //cerr << " --MATCH\n"; WordID new_lhs = unaries_[ri]->GetLHS(); const Hypergraph::TailNodeVector ant(1, nodeidx); - if (ApplyRule(i, j, unaries_[ri], ant, 0)) { + if (ApplyRule(i, j, unaries_[ri], ant, SparseVector())) { //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl; unsigned nodeidx = nodemap_(i,j)[new_lhs]; ApplyUnaryRules(i, j, new_lhs, nodeidx); @@ -245,7 +244,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig void RSChart::AddToChart(const RSActiveItem& x, int i, int j) { // deal with completed rules const RuleBin* rb = x.gptr_->GetRules(); - if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost); + if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_feats); //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n"; // continue looking for extensions of the rule to the right @@ -264,7 +263,7 @@ void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) { if (in_edge.dist2next == check_edge_len) { //cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl; RSActiveItem copy = x; - copy.ExtendTerminal(in_edge.label, in_edge.cost); + copy.ExtendTerminal(in_edge.label, in_edge.features); if (copy) AddToChart(copy, i, k); } } @@ -306,7 +305,7 @@ bool RSChart::Parse() { const Hypergraph::Node& node = forest_->nodes_[dh[di]]; if (node.cat_ == goal_cat_) { Hypergraph::TailNodeVector ant(1, node.id_); - ApplyRule(0, input_.size(), goal_rule_, ant, 0); + ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector()); } } if (!SILENT) cerr << endl; -- cgit v1.2.3