summaryrefslogtreecommitdiff
path: root/decoder/bottom_up_parser-rs.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/bottom_up_parser-rs.cc')
-rw-r--r--decoder/bottom_up_parser-rs.cc29
1 files changed, 14 insertions, 15 deletions
diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc
index fbde7e24..863d7e2f 100644
--- a/decoder/bottom_up_parser-rs.cc
+++ b/decoder/bottom_up_parser-rs.cc
@@ -35,7 +35,7 @@ class RSChart {
const int j,
const RuleBin* rules,
const Hypergraph::TailNodeVector& tail,
- const float lattice_cost);
+ const SparseVector<double>& lattice_feats);
// returns true if a new node was added to the chart
// false otherwise
@@ -43,7 +43,7 @@ class RSChart {
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
- const float lattice_cost);
+ const SparseVector<double>& lattice_feats);
void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx);
void TopoSortUnaries();
@@ -69,9 +69,9 @@ WordID RSChart::kGOAL = 0;
// "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span"
struct RSActiveItem {
explicit RSActiveItem(const GrammarIter* g, int i) :
- gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {}
- void ExtendTerminal(int symbol, float src_cost) {
- lattice_cost += src_cost;
+ gptr_(g), ant_nodes_(), lattice_feats(), i_(i) {}
+ void ExtendTerminal(int symbol, const SparseVector<double>& src_feats) {
+ lattice_feats += src_feats;
if (symbol != kEPS)
gptr_ = gptr_->Extend(symbol);
}
@@ -85,7 +85,7 @@ struct RSActiveItem {
}
const GrammarIter* gptr_;
Hypergraph::TailNodeVector ant_nodes_;
- float lattice_cost; // TODO: use SparseVector<double> to encode input features
+ SparseVector<double> lattice_feats;
short i_;
};
@@ -174,7 +174,7 @@ bool RSChart::ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
- const float lattice_cost) {
+ const SparseVector<double>& lattice_feats) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
//cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
new_edge->prev_i_ = r->prev_i;
@@ -182,8 +182,7 @@ bool RSChart::ApplyRule(const int i,
new_edge->i_ = i;
new_edge->j_ = j;
new_edge->feature_values_ = r->GetFeatureValues();
- if (lattice_cost && lc_fid_)
- new_edge->feature_values_.set_value(lc_fid_, lattice_cost);
+ new_edge->feature_values_ += lattice_feats;
Cat2NodeMap& c2n = nodemap_(i,j);
const bool is_goal = (r->GetLHS() == kGOAL);
const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
@@ -211,7 +210,7 @@ void RSChart::ApplyRules(const int i,
const int j,
const RuleBin* rules,
const Hypergraph::TailNodeVector& tail,
- const float lattice_cost) {
+ const SparseVector<double>& lattice_feats) {
const int n = rules->GetNumRules();
//cerr << i << " " << j << ": NUM RULES: " << n << endl;
for (int k = 0; k < n; ++k) {
@@ -219,7 +218,7 @@ void RSChart::ApplyRules(const int i,
TRulePtr rule = rules->GetIthRule(k);
// apply rule, and if we create a new node, apply any necessary
// unary rules
- if (ApplyRule(i, j, rule, tail, lattice_cost)) {
+ if (ApplyRule(i, j, rule, tail, lattice_feats)) {
unsigned nodeidx = nodemap_(i,j)[rule->lhs_];
ApplyUnaryRules(i, j, rule->lhs_, nodeidx);
}
@@ -233,7 +232,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig
//cerr << " --MATCH\n";
WordID new_lhs = unaries_[ri]->GetLHS();
const Hypergraph::TailNodeVector ant(1, nodeidx);
- if (ApplyRule(i, j, unaries_[ri], ant, 0)) {
+ if (ApplyRule(i, j, unaries_[ri], ant, SparseVector<double>())) {
//cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl;
unsigned nodeidx = nodemap_(i,j)[new_lhs];
ApplyUnaryRules(i, j, new_lhs, nodeidx);
@@ -245,7 +244,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig
void RSChart::AddToChart(const RSActiveItem& x, int i, int j) {
// deal with completed rules
const RuleBin* rb = x.gptr_->GetRules();
- if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost);
+ if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_feats);
//cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n";
// continue looking for extensions of the rule to the right
@@ -264,7 +263,7 @@ void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) {
if (in_edge.dist2next == check_edge_len) {
//cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl;
RSActiveItem copy = x;
- copy.ExtendTerminal(in_edge.label, in_edge.cost);
+ copy.ExtendTerminal(in_edge.label, in_edge.features);
if (copy) AddToChart(copy, i, k);
}
}
@@ -306,7 +305,7 @@ bool RSChart::Parse() {
const Hypergraph::Node& node = forest_->nodes_[dh[di]];
if (node.cat_ == goal_cat_) {
Hypergraph::TailNodeVector ant(1, node.id_);
- ApplyRule(0, input_.size(), goal_rule_, ant, 0);
+ ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector<double>());
}
}
if (!SILENT) cerr << endl;