summaryrefslogtreecommitdiff
path: root/decoder/bottom_up_parser.cc
diff options
context:
space:
mode:
authorCHRISTOPHER DYER <cdyer@CHRISTOPHERs-MacBook-Pro.local>2015-02-03 21:24:07 -0500
committerCHRISTOPHER DYER <cdyer@CHRISTOPHERs-MacBook-Pro.local>2015-02-03 21:24:07 -0500
commit24cee877f2bb4b490495ea578564d0266b1addd3 (patch)
tree9b3456a1f832363294b715be62180c59fc7d7139 /decoder/bottom_up_parser.cc
parente2d9eb0ba94acd728a0706fa4209a36f67dd6d80 (diff)
support multiple sparse features on lattice edges
Diffstat (limited to 'decoder/bottom_up_parser.cc')
-rw-r--r--decoder/bottom_up_parser.cc41
1 files changed, 19 insertions, 22 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index b30f1ec6..a614b8b3 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -38,13 +38,13 @@ class PassiveChart {
const int j,
const RuleBin* rules,
const Hypergraph::TailNodeVector& tail,
- const float lattice_cost);
+ const SparseVector<double>& lattice_feats);
void ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
- const float lattice_cost);
+ const SparseVector<double>& lattice_feats);
void ApplyUnaryRules(const int i, const int j);
void TopoSortUnaries();
@@ -59,7 +59,6 @@ class PassiveChart {
const WordID goal_cat_; // category that is being searched for at [0,n]
TRulePtr goal_rule_;
int goal_idx_; // index of goal node, if found
- const int lc_fid_;
vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars
static WordID kGOAL; // [Goal]
@@ -74,18 +73,18 @@ class ActiveChart {
act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {}
struct ActiveItem {
- ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) :
- gptr_(g), ant_nodes_(a), lattice_cost(lcost) {}
+ ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, const SparseVector<double>& lfeats) :
+ gptr_(g), ant_nodes_(a), lattice_feats(lfeats) {}
explicit ActiveItem(const GrammarIter* g) :
- gptr_(g), ant_nodes_(), lattice_cost(0.0) {}
+ gptr_(g), ant_nodes_(), lattice_feats() {}
- void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const {
+ void ExtendTerminal(int symbol, const SparseVector<double>& src_feats, vector<ActiveItem>* out_cell) const {
if (symbol == kEPS) {
- out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_cost + src_cost));
+ out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_feats + src_feats));
} else {
const GrammarIter* ni = gptr_->Extend(symbol);
if (ni)
- out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost));
+ out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_feats + src_feats));
}
}
void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const {
@@ -96,12 +95,12 @@ class ActiveChart {
for (unsigned i = 0; i < ant_nodes_.size(); ++i)
na[i] = ant_nodes_[i];
na[ant_nodes_.size()] = node_index;
- out_cell->push_back(ActiveItem(ni, na, lattice_cost));
+ out_cell->push_back(ActiveItem(ni, na, lattice_feats));
}
const GrammarIter* gptr_;
Hypergraph::TailNodeVector ant_nodes_;
- float lattice_cost; // TODO? use SparseVector<double>
+ SparseVector<double> lattice_feats;
};
inline const vector<ActiveItem>& operator()(int i, int j) const { return act_chart_(i,j); }
@@ -134,12 +133,12 @@ class ActiveChart {
for (vector<LatticeArc>::const_iterator ai = out_arcs.begin();
ai != out_arcs.end(); ++ai) {
const WordID& f = ai->label;
- const double& c = ai->cost;
+ const SparseVector<double>& c = ai->features;
const int& len = ai->dist2next;
//cerr << "F: " << TD::Convert(f) << " dest=" << i << "," << (j+len-1) << endl;
const vector<ActiveItem>& ec = act_chart_(i, j-1);
//cerr << " SRC=" << i << "," << (j-1) << " [ec=" << ec.size() << "]" << endl;
- //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_cost << endl; }
+ //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_feats << endl; }
for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di)
di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1));
}
@@ -163,7 +162,6 @@ PassiveChart::PassiveChart(const string& goal,
goal_cat_(TD::Convert(goal) * -1),
goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")),
goal_idx_(-1),
- lc_fid_(FD::Convert("LatticeCost")),
unaries_() {
act_chart_.resize(grammars_.size());
for (unsigned i = 0; i < grammars_.size(); ++i) {
@@ -232,7 +230,7 @@ void PassiveChart::ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
- const float lattice_cost) {
+ const SparseVector<double>& lattice_feats) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
// cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
new_edge->prev_i_ = r->prev_i;
@@ -240,8 +238,7 @@ void PassiveChart::ApplyRule(const int i,
new_edge->i_ = i;
new_edge->j_ = j;
new_edge->feature_values_ = r->GetFeatureValues();
- if (lattice_cost && lc_fid_)
- new_edge->feature_values_.set_value(lc_fid_, lattice_cost);
+ new_edge->feature_values_ += lattice_feats;
Cat2NodeMap& c2n = nodemap_(i,j);
const bool is_goal = (r->GetLHS() == kGOAL);
const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
@@ -265,12 +262,12 @@ void PassiveChart::ApplyRules(const int i,
const int j,
const RuleBin* rules,
const Hypergraph::TailNodeVector& tail,
- const float lattice_cost) {
+ const SparseVector<double>& lattice_feats) {
const int n = rules->GetNumRules();
//cerr << i << " " << j << ": NUM RULES: " << n << endl;
for (int k = 0; k < n; ++k) {
//cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl;
- ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost);
+ ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_feats);
}
}
@@ -283,7 +280,7 @@ void PassiveChart::ApplyUnaryRules(const int i, const int j) {
if (unaries_[ri]->f()[0] == cat) {
//cerr << " --MATCH\n";
const Hypergraph::TailNodeVector ant(1, nodes[di]);
- ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes
+ ApplyRule(i, j, unaries_[ri], ant, SparseVector<double>()); // may update nodes
}
}
}
@@ -313,7 +310,7 @@ bool PassiveChart::Parse() {
ai != cell.end(); ++ai) {
const RuleBin* rules = (ai->gptr_->GetRules());
if (!rules) continue;
- ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost);
+ ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_feats);
}
}
}
@@ -331,7 +328,7 @@ bool PassiveChart::Parse() {
const Hypergraph::Node& node = forest_->nodes_[dh[di]];
if (node.cat_ == goal_cat_) {
Hypergraph::TailNodeVector ant(1, node.id_);
- ApplyRule(0, input_.size(), goal_rule_, ant, 0);
+ ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector<double>());
}
}
}