From 328873498127f44ff257db9e9a9551c6561587c0 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 22 Feb 2011 17:59:31 -0500 Subject: support span scores --- decoder/ff_spans.cc | 79 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 60 insertions(+), 19 deletions(-) (limited to 'decoder/ff_spans.cc') diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index 989eed56..511ab728 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -4,6 +4,7 @@ #include #include "filelib.h" +#include "stringlib.h" #include "sentence_metadata.h" #include "lattice.h" #include "fdict.h" @@ -12,12 +13,18 @@ using namespace std; SpanFeatures::SpanFeatures(const string& param) : - kS(TD::Convert("S") * -1), - kX(TD::Convert("X") * -1) { - if (param.size() > 0) { + kS(TD::Convert("S") * -1), + kX(TD::Convert("X") * -1), + use_collapsed_features_(false) { + string mapfile = param; + string valfile; + vector toks; + Tokenize(param, ' ', &toks); + if (toks.size() == 2) { mapfile = param[0]; valfile = param[1]; } + if (mapfile.size() > 0) { int lc = 0; if (!SILENT) { cerr << "Reading word map for SpanFeatures from " << param << endl; } - ReadFile rf(param); + ReadFile rf(mapfile); istream& in = *rf.stream(); string line; vector v; @@ -37,6 +44,27 @@ SpanFeatures::SpanFeatures(const string& param) : word2class_[TD::Convert("")] = TD::Convert("EOS"); oov_ = TD::Convert("OOV"); } + + if (valfile.size() > 0) { + use_collapsed_features_ = true; + fid_beg_ = FD::Convert("SpanBegin"); + fid_end_ = FD::Convert("SpanEnd"); + fid_span_s_ = FD::Convert("SSpanContext"); + fid_span_ = FD::Convert("XSpanContext"); + ReadFile rf(valfile); + if (!SILENT) { cerr << " Loading span scores from " << valfile << endl; } + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (line.size() == 0 || line[0] == '#') { continue; } + istringstream in(line); + string feat_name; + double weight; + in >> feat_name >> weight; + feat2val_[feat_name] = weight; + } + } } void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, @@ -45,24 +73,24 @@ void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, SparseVector* features, SparseVector* estimated_features, void* context) const { -// char& res = *static_cast(context); -// res = edge.j_ - edge.i_; -// assert(res >= 0); assert(edge.j_ < end_span_ids_.size()); assert(edge.j_ >= 0); - features->set_value(end_span_ids_[edge.j_], 1); assert(edge.i_ < beg_span_ids_.size()); assert(edge.i_ >= 0); - features->set_value(beg_span_ids_[edge.i_], 1); - if (edge.rule_->lhs_ == kS) - features->set_value(span_feats_(edge.i_,edge.j_).second, 1); - else - features->set_value(span_feats_(edge.i_,edge.j_).first, 1); - if (edge.Arity() == 2) { - const TRule& rule = *edge.rule_; - if (rule.f_[0] == kS && rule.f_[1] == kX) { -// char x_width = *static_cast(ant_contexts[1]); - } + if (use_collapsed_features_) { + features->set_value(fid_end_, end_span_vals_[edge.j_]); + features->set_value(fid_beg_, beg_span_vals_[edge.i_]); + if (edge.rule_->lhs_ == kS) + features->set_value(fid_span_s_, span_vals_(edge.i_,edge.j_).second); + else + features->set_value(fid_span_, span_vals_(edge.i_,edge.j_).first); + } else { // non-collapsed features: + features->set_value(end_span_ids_[edge.j_], 1); + features->set_value(beg_span_ids_[edge.i_], 1); + if (edge.rule_->lhs_ == kS) + features->set_value(span_feats_(edge.i_,edge.j_).second, 1); + else + features->set_value(span_feats_(edge.i_,edge.j_).first, 1); } } @@ -79,6 +107,12 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { const WordID bos = TD::Convert(""); beg_span_ids_.resize(lattice.size() + 1); end_span_ids_.resize(lattice.size() + 1); + span_feats_.resize(lattice.size() + 1, lattice.size() + 1); + if (use_collapsed_features_) { + beg_span_vals_.resize(lattice.size() + 1); + end_span_vals_.resize(lattice.size() + 1); + span_vals_.resize(lattice.size() + 1, lattice.size() + 1); + } for (int i = 0; i <= lattice.size(); ++i) { WordID word = eos; WordID bword = bos; @@ -94,8 +128,11 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { ostringstream bfid; bfid << "BS:" << TD::Convert(bword); beg_span_ids_[i] = FD::Convert(bfid.str()); + if (use_collapsed_features_) { + end_span_vals_[i] = feat2val_[sfid.str()]; + beg_span_vals_[i] = feat2val_[bfid.str()]; + } } - span_feats_.resize(lattice.size() + 1, lattice.size() + 1); for (int i = 0; i <= lattice.size(); ++i) { WordID bword = bos; if (i > 0) @@ -110,6 +147,10 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { pf << "SS:" << TD::Convert(bword) << "_" << TD::Convert(word); span_feats_(i,j).first = FD::Convert(pf.str()); span_feats_(i,j).second = FD::Convert("S_" + pf.str()); + if (use_collapsed_features_) { + span_vals_(i,j).first = feat2val_[pf.str()]; + span_vals_(i,j).second = feat2val_["S_" + pf.str()]; + } } } } -- cgit v1.2.3