diff options
author | CHRISTOPHER DYER <cdyer@CHRISTOPHERs-MacBook-Pro.local> | 2015-02-03 21:24:07 -0500 |
---|---|---|
committer | CHRISTOPHER DYER <cdyer@CHRISTOPHERs-MacBook-Pro.local> | 2015-02-03 21:24:07 -0500 |
commit | c7b2a39958912d7b85a384a871609e6db73042c7 (patch) | |
tree | 5405ac792a173edaa50e67750c2b10722a47962a /decoder/bottom_up_parser.cc | |
parent | afd65846cf1456a8b49e8482b9a40777014f6883 (diff) |
support multiple sparse features on lattice edges
Diffstat (limited to 'decoder/bottom_up_parser.cc')
-rw-r--r-- | decoder/bottom_up_parser.cc | 41 |
1 files changed, 19 insertions, 22 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index b30f1ec6..a614b8b3 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -38,13 +38,13 @@ class PassiveChart { const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost); + const SparseVector<double>& lattice_feats); void ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost); + const SparseVector<double>& lattice_feats); void ApplyUnaryRules(const int i, const int j); void TopoSortUnaries(); @@ -59,7 +59,6 @@ class PassiveChart { const WordID goal_cat_; // category that is being searched for at [0,n] TRulePtr goal_rule_; int goal_idx_; // index of goal node, if found - const int lc_fid_; vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars static WordID kGOAL; // [Goal] @@ -74,18 +73,18 @@ class ActiveChart { act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {} struct ActiveItem { - ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) : - gptr_(g), ant_nodes_(a), lattice_cost(lcost) {} + ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, const SparseVector<double>& lfeats) : + gptr_(g), ant_nodes_(a), lattice_feats(lfeats) {} explicit ActiveItem(const GrammarIter* g) : - gptr_(g), ant_nodes_(), lattice_cost(0.0) {} + gptr_(g), ant_nodes_(), lattice_feats() {} - void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const { + void ExtendTerminal(int symbol, const SparseVector<double>& src_feats, vector<ActiveItem>* out_cell) const { if (symbol == kEPS) { - out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_cost + src_cost)); + out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_feats + src_feats)); } else { const GrammarIter* ni = gptr_->Extend(symbol); if (ni) - out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); + out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_feats + src_feats)); } } void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const { @@ -96,12 +95,12 @@ class ActiveChart { for (unsigned i = 0; i < ant_nodes_.size(); ++i) na[i] = ant_nodes_[i]; na[ant_nodes_.size()] = node_index; - out_cell->push_back(ActiveItem(ni, na, lattice_cost)); + out_cell->push_back(ActiveItem(ni, na, lattice_feats)); } const GrammarIter* gptr_; Hypergraph::TailNodeVector ant_nodes_; - float lattice_cost; // TODO? use SparseVector<double> + SparseVector<double> lattice_feats; }; inline const vector<ActiveItem>& operator()(int i, int j) const { return act_chart_(i,j); } @@ -134,12 +133,12 @@ class ActiveChart { for (vector<LatticeArc>::const_iterator ai = out_arcs.begin(); ai != out_arcs.end(); ++ai) { const WordID& f = ai->label; - const double& c = ai->cost; + const SparseVector<double>& c = ai->features; const int& len = ai->dist2next; //cerr << "F: " << TD::Convert(f) << " dest=" << i << "," << (j+len-1) << endl; const vector<ActiveItem>& ec = act_chart_(i, j-1); //cerr << " SRC=" << i << "," << (j-1) << " [ec=" << ec.size() << "]" << endl; - //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_cost << endl; } + //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_feats << endl; } for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di) di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1)); } @@ -163,7 +162,6 @@ PassiveChart::PassiveChart(const string& goal, goal_cat_(TD::Convert(goal) * -1), goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), goal_idx_(-1), - lc_fid_(FD::Convert("LatticeCost")), unaries_() { act_chart_.resize(grammars_.size()); for (unsigned i = 0; i < grammars_.size(); ++i) { @@ -232,7 +230,7 @@ void PassiveChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost) { + const SparseVector<double>& lattice_feats) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; @@ -240,8 +238,7 @@ void PassiveChart::ApplyRule(const int i, new_edge->i_ = i; new_edge->j_ = j; new_edge->feature_values_ = r->GetFeatureValues(); - if (lattice_cost && lc_fid_) - new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + new_edge->feature_values_ += lattice_feats; Cat2NodeMap& c2n = nodemap_(i,j); const bool is_goal = (r->GetLHS() == kGOAL); const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); @@ -265,12 +262,12 @@ void PassiveChart::ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost) { + const SparseVector<double>& lattice_feats) { const int n = rules->GetNumRules(); //cerr << i << " " << j << ": NUM RULES: " << n << endl; for (int k = 0; k < n; ++k) { //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; - ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost); + ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_feats); } } @@ -283,7 +280,7 @@ void PassiveChart::ApplyUnaryRules(const int i, const int j) { if (unaries_[ri]->f()[0] == cat) { //cerr << " --MATCH\n"; const Hypergraph::TailNodeVector ant(1, nodes[di]); - ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes + ApplyRule(i, j, unaries_[ri], ant, SparseVector<double>()); // may update nodes } } } @@ -313,7 +310,7 @@ bool PassiveChart::Parse() { ai != cell.end(); ++ai) { const RuleBin* rules = (ai->gptr_->GetRules()); if (!rules) continue; - ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost); + ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_feats); } } } @@ -331,7 +328,7 @@ bool PassiveChart::Parse() { const Hypergraph::Node& node = forest_->nodes_[dh[di]]; if (node.cat_ == goal_cat_) { Hypergraph::TailNodeVector ant(1, node.id_); - ApplyRule(0, input_.size(), goal_rule_, ant, 0); + ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector<double>()); } } } |