From 949ad1a3938d98d4366fbf3e487acc8fc4134ce1 Mon Sep 17 00:00:00 2001
From: Chris Dyer <cdyer@cs.cmu.edu>
Date: Tue, 22 Feb 2011 17:59:31 -0500
Subject: support span scores

---
 decoder/ff_spans.cc | 79 ++++++++++++++++++++++++++++++++++++++++-------------
 decoder/ff_spans.h  | 12 ++++++++
 2 files changed, 72 insertions(+), 19 deletions(-)

diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc
index 989eed56..511ab728 100644
--- a/decoder/ff_spans.cc
+++ b/decoder/ff_spans.cc
@@ -4,6 +4,7 @@
 #include <cassert>
 
 #include "filelib.h"
+#include "stringlib.h"
 #include "sentence_metadata.h"
 #include "lattice.h"
 #include "fdict.h"
@@ -12,12 +13,18 @@
 using namespace std;
 
 SpanFeatures::SpanFeatures(const string& param) :
-  kS(TD::Convert("S") * -1),
-  kX(TD::Convert("X") * -1) {
-  if (param.size() > 0) {
+    kS(TD::Convert("S") * -1),
+    kX(TD::Convert("X") * -1),
+    use_collapsed_features_(false) {
+  string mapfile = param;  // default: the whole param names the word-map file
+  string valfile;          // stays empty unless param carries a second token
+  vector<string> toks;
+  Tokenize(param, ' ', &toks);
+  if (toks.size() == 2) { mapfile = toks[0]; valfile = toks[1]; }  // "MAPFILE VALFILE"; use tokens, not chars of param
+  if (mapfile.size() > 0) {
     int lc = 0;
     if (!SILENT) { cerr << "Reading word map for SpanFeatures from " << param << endl; }
-    ReadFile rf(param);
+    ReadFile rf(mapfile);
     istream& in = *rf.stream();
     string line;
     vector<WordID> v;
@@ -37,6 +44,27 @@ SpanFeatures::SpanFeatures(const string& param) :
     word2class_[TD::Convert("</s>")] = TD::Convert("EOS");
     oov_ = TD::Convert("OOV");
   }
+
+  // Optional span-score file: each data line is "<feature-name> <value>".
+  if (valfile.size() > 0) {
+    use_collapsed_features_ = true;
+    fid_beg_ = FD::Convert("SpanBegin");
+    fid_end_ = FD::Convert("SpanEnd");
+    fid_span_s_ = FD::Convert("SSpanContext");
+    fid_span_ = FD::Convert("XSpanContext");
+    ReadFile rf(valfile);
+    if (!SILENT) { cerr << "  Loading span scores from " << valfile << endl; }
+    istream& in = *rf.stream();
+    string line;
+    while (getline(in, line)) {  // test the read itself so a failed read never re-runs the body
+      if (line.size() == 0 || line[0] == '#') { continue; }  // blank line or '#' comment
+      istringstream fields(line);  // named so it does not shadow the file stream `in`
+      string feat_name;
+      double weight;
+      if (!(fields >> feat_name >> weight)) { continue; }  // malformed line: never store an unset weight
+      feat2val_[feat_name] = weight;
+    }
+  }
 }
 
 void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
@@ -45,24 +73,24 @@ void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                          SparseVector<double>* features,
                                          SparseVector<double>* estimated_features,
                                          void* context) const {
-//  char& res = *static_cast<char*>(context);
-//  res = edge.j_ - edge.i_;
-//  assert(res >= 0);
   assert(edge.j_ < end_span_ids_.size());
   assert(edge.j_ >= 0);
-  features->set_value(end_span_ids_[edge.j_], 1);
   assert(edge.i_ < beg_span_ids_.size());
   assert(edge.i_ >= 0);
-  features->set_value(beg_span_ids_[edge.i_], 1);
-  if (edge.rule_->lhs_ == kS)
-    features->set_value(span_feats_(edge.i_,edge.j_).second, 1);
-  else
-    features->set_value(span_feats_(edge.i_,edge.j_).first, 1);
-  if (edge.Arity() == 2) {
-    const TRule& rule = *edge.rule_;
-    if (rule.f_[0] == kS && rule.f_[1] == kX) {
-//      char x_width = *static_cast<const char*>(ant_contexts[1]);
-    }
+  if (use_collapsed_features_) {
+    features->set_value(fid_end_, end_span_vals_[edge.j_]);
+    features->set_value(fid_beg_, beg_span_vals_[edge.i_]);
+    if (edge.rule_->lhs_ == kS)
+      features->set_value(fid_span_s_, span_vals_(edge.i_,edge.j_).second);
+    else
+      features->set_value(fid_span_, span_vals_(edge.i_,edge.j_).first);
+  } else {  // non-collapsed features:
+    features->set_value(end_span_ids_[edge.j_], 1);
+    features->set_value(beg_span_ids_[edge.i_], 1);
+    if (edge.rule_->lhs_ == kS)
+      features->set_value(span_feats_(edge.i_,edge.j_).second, 1);
+    else
+      features->set_value(span_feats_(edge.i_,edge.j_).first, 1);
   }
 }
 
@@ -79,6 +107,12 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
   const WordID bos = TD::Convert("<s>");
   beg_span_ids_.resize(lattice.size() + 1);
   end_span_ids_.resize(lattice.size() + 1);
+  span_feats_.resize(lattice.size() + 1, lattice.size() + 1);
+  if (use_collapsed_features_) {
+    beg_span_vals_.resize(lattice.size() + 1);
+    end_span_vals_.resize(lattice.size() + 1);
+    span_vals_.resize(lattice.size() + 1, lattice.size() + 1);
+  }
   for (int i = 0; i <= lattice.size(); ++i) {
     WordID word = eos;
     WordID bword = bos;
@@ -94,8 +128,11 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
     ostringstream bfid;
     bfid << "BS:" << TD::Convert(bword);
     beg_span_ids_[i] = FD::Convert(bfid.str());
+    if (use_collapsed_features_) {
+      end_span_vals_[i] = feat2val_[sfid.str()];
+      beg_span_vals_[i] = feat2val_[bfid.str()];
+    }
   }
-  span_feats_.resize(lattice.size() + 1, lattice.size() + 1);
   for (int i = 0; i <= lattice.size(); ++i) {
     WordID bword = bos;
     if (i > 0)
@@ -110,6 +147,10 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
       pf << "SS:" << TD::Convert(bword) << "_" << TD::Convert(word);
       span_feats_(i,j).first = FD::Convert(pf.str());
       span_feats_(i,j).second = FD::Convert("S_" + pf.str());
+      if (use_collapsed_features_) {
+        span_vals_(i,j).first = feat2val_[pf.str()];
+        span_vals_(i,j).second = feat2val_["S_" + pf.str()];
+      }
     }
   } 
 }
diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h
index 4bf5d7e7..b93faec5 100644
--- a/decoder/ff_spans.h
+++ b/decoder/ff_spans.h
@@ -26,6 +26,18 @@ class SpanFeatures : public FeatureFunction {
   std::vector<int> end_span_ids_;
   std::vector<int> beg_span_ids_;
   std::map<WordID, WordID> word2class_;  // optional projection to coarser class
+
+  // collapsed feature values
+  bool use_collapsed_features_;
+  int fid_beg_;
+  int fid_end_;
+  int fid_span_s_;
+  int fid_span_;
+  std::map<std::string, double> feat2val_;
+  std::vector<double> end_span_vals_;
+  std::vector<double> beg_span_vals_;
+  Array2D<std::pair<double,double> > span_vals_;
+
   WordID oov_;
 };
 
-- 
cgit v1.2.3