summaryrefslogtreecommitdiff
path: root/decoder/ff_spans.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-02-22 17:59:31 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-02-22 17:59:31 -0500
commit949ad1a3938d98d4366fbf3e487acc8fc4134ce1 (patch)
tree36f8c2ea5a2616462d9e7b7268324e67c5f5cbea /decoder/ff_spans.cc
parent37ed2c59d9f5504476592cd4bd89725a1ddcadba (diff)
support span scores
Diffstat (limited to 'decoder/ff_spans.cc')
-rw-r--r--decoder/ff_spans.cc79
1 files changed, 60 insertions, 19 deletions
diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc
index 989eed56..511ab728 100644
--- a/decoder/ff_spans.cc
+++ b/decoder/ff_spans.cc
@@ -4,6 +4,7 @@
#include <cassert>
#include "filelib.h"
+#include "stringlib.h"
#include "sentence_metadata.h"
#include "lattice.h"
#include "fdict.h"
@@ -12,12 +13,18 @@
using namespace std;
SpanFeatures::SpanFeatures(const string& param) :
- kS(TD::Convert("S") * -1),
- kX(TD::Convert("X") * -1) {
- if (param.size() > 0) {
+ kS(TD::Convert("S") * -1),
+ kX(TD::Convert("X") * -1),
+ use_collapsed_features_(false) {
+ string mapfile = param;
+ string valfile;
+ vector<string> toks;
+ Tokenize(param, ' ', &toks);
+ if (toks.size() == 2) { mapfile = param[0]; valfile = param[1]; }
+ if (mapfile.size() > 0) {
int lc = 0;
if (!SILENT) { cerr << "Reading word map for SpanFeatures from " << param << endl; }
- ReadFile rf(param);
+ ReadFile rf(mapfile);
istream& in = *rf.stream();
string line;
vector<WordID> v;
@@ -37,6 +44,27 @@ SpanFeatures::SpanFeatures(const string& param) :
word2class_[TD::Convert("</s>")] = TD::Convert("EOS");
oov_ = TD::Convert("OOV");
}
+
+ if (valfile.size() > 0) {
+ use_collapsed_features_ = true;
+ fid_beg_ = FD::Convert("SpanBegin");
+ fid_end_ = FD::Convert("SpanEnd");
+ fid_span_s_ = FD::Convert("SSpanContext");
+ fid_span_ = FD::Convert("XSpanContext");
+ ReadFile rf(valfile);
+ if (!SILENT) { cerr << " Loading span scores from " << valfile << endl; }
+ istream& in = *rf.stream();
+ string line;
+ while(in) {
+ getline(in, line);
+ if (line.size() == 0 || line[0] == '#') { continue; }
+ istringstream in(line);
+ string feat_name;
+ double weight;
+ in >> feat_name >> weight;
+ feat2val_[feat_name] = weight;
+ }
+ }
}
void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
@@ -45,24 +73,24 @@ void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
SparseVector<double>* features,
SparseVector<double>* estimated_features,
void* context) const {
-// char& res = *static_cast<char*>(context);
-// res = edge.j_ - edge.i_;
-// assert(res >= 0);
assert(edge.j_ < end_span_ids_.size());
assert(edge.j_ >= 0);
- features->set_value(end_span_ids_[edge.j_], 1);
assert(edge.i_ < beg_span_ids_.size());
assert(edge.i_ >= 0);
- features->set_value(beg_span_ids_[edge.i_], 1);
- if (edge.rule_->lhs_ == kS)
- features->set_value(span_feats_(edge.i_,edge.j_).second, 1);
- else
- features->set_value(span_feats_(edge.i_,edge.j_).first, 1);
- if (edge.Arity() == 2) {
- const TRule& rule = *edge.rule_;
- if (rule.f_[0] == kS && rule.f_[1] == kX) {
-// char x_width = *static_cast<const char*>(ant_contexts[1]);
- }
+ if (use_collapsed_features_) {
+ features->set_value(fid_end_, end_span_vals_[edge.j_]);
+ features->set_value(fid_beg_, beg_span_vals_[edge.i_]);
+ if (edge.rule_->lhs_ == kS)
+ features->set_value(fid_span_s_, span_vals_(edge.i_,edge.j_).second);
+ else
+ features->set_value(fid_span_, span_vals_(edge.i_,edge.j_).first);
+ } else { // non-collapsed features:
+ features->set_value(end_span_ids_[edge.j_], 1);
+ features->set_value(beg_span_ids_[edge.i_], 1);
+ if (edge.rule_->lhs_ == kS)
+ features->set_value(span_feats_(edge.i_,edge.j_).second, 1);
+ else
+ features->set_value(span_feats_(edge.i_,edge.j_).first, 1);
}
}
@@ -79,6 +107,12 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
const WordID bos = TD::Convert("<s>");
beg_span_ids_.resize(lattice.size() + 1);
end_span_ids_.resize(lattice.size() + 1);
+ span_feats_.resize(lattice.size() + 1, lattice.size() + 1);
+ if (use_collapsed_features_) {
+ beg_span_vals_.resize(lattice.size() + 1);
+ end_span_vals_.resize(lattice.size() + 1);
+ span_vals_.resize(lattice.size() + 1, lattice.size() + 1);
+ }
for (int i = 0; i <= lattice.size(); ++i) {
WordID word = eos;
WordID bword = bos;
@@ -94,8 +128,11 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
ostringstream bfid;
bfid << "BS:" << TD::Convert(bword);
beg_span_ids_[i] = FD::Convert(bfid.str());
+ if (use_collapsed_features_) {
+ end_span_vals_[i] = feat2val_[sfid.str()];
+ beg_span_vals_[i] = feat2val_[bfid.str()];
+ }
}
- span_feats_.resize(lattice.size() + 1, lattice.size() + 1);
for (int i = 0; i <= lattice.size(); ++i) {
WordID bword = bos;
if (i > 0)
@@ -110,6 +147,10 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
pf << "SS:" << TD::Convert(bword) << "_" << TD::Convert(word);
span_feats_(i,j).first = FD::Convert(pf.str());
span_feats_(i,j).second = FD::Convert("S_" + pf.str());
+ if (use_collapsed_features_) {
+ span_vals_(i,j).first = feat2val_[pf.str()];
+ span_vals_(i,j).second = feat2val_["S_" + pf.str()];
+ }
}
}
}