summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-02-17 23:41:29 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-02-17 23:41:29 -0500
commitb6a372a529093975464e23cb01cb2a9ee3e7b0c7 (patch)
tree38f17b74c0afb6b14a4a00d013298e6fe9073b66
parentb53636855b8370fd3a6f61041d351c3d14f952c6 (diff)
more spans
-rw-r--r--decoder/ff_spans.cc39
-rw-r--r--decoder/ff_spans.h5
2 files changed, 43 insertions, 1 deletions
diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc
index b454c9fd..06593727 100644
--- a/decoder/ff_spans.cc
+++ b/decoder/ff_spans.cc
@@ -3,15 +3,41 @@
#include <sstream>
#include <cassert>
+#include "filelib.h"
#include "sentence_metadata.h"
#include "lattice.h"
#include "fdict.h"
+#include "verbose.h"
using namespace std;
SpanFeatures::SpanFeatures(const string& param) :
kS(TD::Convert("S") * -1),
- kX(TD::Convert("X") * -1) {}
+ kX(TD::Convert("X") * -1) {
+ if (param.size() > 0) {
+ int lc = 0;
+ if (!SILENT) { cerr << "Reading word map for SpanFeatures from " << param << endl; }
+ ReadFile rf(param);
+ istream& in = *rf.stream();
+ string line;
+ vector<WordID> v;
+ while(in) {
+ ++lc;
+ getline(in, line);
+ if (line.empty()) continue;
+ v.clear();
+ TD::ConvertSentence(line, &v);
+ if (v.size() != 2) {
+ cerr << "Error reading line " << lc << ": " << line << endl;
+ abort();
+ }
+ word2class_[v[0]] = v[1];
+ }
+ word2class_[TD::Convert("<s>")] = TD::Convert("BOS");
+ word2class_[TD::Convert("</s>")] = TD::Convert("EOS");
+ oov_ = TD::Convert("OOV");
+ }
+}
void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
@@ -37,6 +63,13 @@ void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
+WordID SpanFeatures::MapIfNecessary(const WordID& w) const {
+ if (word2class_.empty()) return w;
+ map<WordID,WordID>::const_iterator it = word2class_.find(w);
+ if (it == word2class_.end()) return oov_;
+ return it->second;
+}
+
void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
const Lattice& lattice = smeta.GetSourceLattice();
const WordID eos = TD::Convert("</s>");
@@ -48,8 +81,10 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
WordID bword = bos;
if (i > 0)
bword = lattice[i-1][0].label;
+ bword = MapIfNecessary(bword);
if (i < lattice.size())
word = lattice[i][0].label; // rather arbitrary for lattices
+ word = MapIfNecessary(word);
ostringstream sfid;
sfid << "ES:" << TD::Convert(word);
end_span_ids_[i] = FD::Convert(sfid.str());
@@ -62,10 +97,12 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
WordID bword = bos;
if (i > 0)
bword = lattice[i-1][0].label;
+ bword = MapIfNecessary(bword);
for (int j = 0; j <= lattice.size(); ++j) {
WordID word = eos;
if (j < lattice.size())
word = lattice[j][0].label;
+ word = MapIfNecessary(word);
ostringstream pf;
pf << "SS:" << TD::Convert(bword) << "_" << TD::Convert(word);
span_feats_(i,j) = FD::Convert(pf.str());
diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h
index 0446d062..5e90b7e0 100644
--- a/decoder/ff_spans.h
+++ b/decoder/ff_spans.h
@@ -2,8 +2,10 @@
#define _FF_SPANS_H_
#include <vector>
+#include <map>
#include "ff.h"
#include "array2d.h"
+#include "wordid.h"
class SpanFeatures : public FeatureFunction {
public:
@@ -17,11 +19,14 @@ class SpanFeatures : public FeatureFunction {
void* context) const;
virtual void PrepareForInput(const SentenceMetadata& smeta);
private:
+ WordID MapIfNecessary(const WordID& w) const;
const int kS;
const int kX;
Array2D<int> span_feats_;
std::vector<int> end_span_ids_;
std::vector<int> beg_span_ids_;
+ std::map<WordID, WordID> word2class_; // optional projection to coarser class
+ WordID oov_;
};
#endif