diff options
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/ff_spans.cc | 39 | ||||
-rw-r--r-- | decoder/ff_spans.h | 5 |
2 files changed, 43 insertions, 1 deletions
diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index b454c9fd..06593727 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -3,15 +3,41 @@ #include <sstream> #include <cassert> +#include "filelib.h" #include "sentence_metadata.h" #include "lattice.h" #include "fdict.h" +#include "verbose.h" using namespace std; SpanFeatures::SpanFeatures(const string& param) : kS(TD::Convert("S") * -1), - kX(TD::Convert("X") * -1) {} + kX(TD::Convert("X") * -1) { + if (param.size() > 0) { + int lc = 0; + if (!SILENT) { cerr << "Reading word map for SpanFeatures from " << param << endl; } + ReadFile rf(param); + istream& in = *rf.stream(); + string line; + vector<WordID> v; + while(in) { + ++lc; + getline(in, line); + if (line.empty()) continue; + v.clear(); + TD::ConvertSentence(line, &v); + if (v.size() != 2) { + cerr << "Error reading line " << lc << ": " << line << endl; + abort(); + } + word2class_[v[0]] = v[1]; + } + word2class_[TD::Convert("<s>")] = TD::Convert("BOS"); + word2class_[TD::Convert("</s>")] = TD::Convert("EOS"); + oov_ = TD::Convert("OOV"); + } +} void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, @@ -37,6 +63,13 @@ void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } +WordID SpanFeatures::MapIfNecessary(const WordID& w) const { + if (word2class_.empty()) return w; + map<WordID,WordID>::const_iterator it = word2class_.find(w); + if (it == word2class_.end()) return oov_; + return it->second; +} + void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { const Lattice& lattice = smeta.GetSourceLattice(); const WordID eos = TD::Convert("</s>"); @@ -48,8 +81,10 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { WordID bword = bos; if (i > 0) bword = lattice[i-1][0].label; + bword = MapIfNecessary(bword); if (i < lattice.size()) word = lattice[i][0].label; // rather arbitrary for lattices + word = MapIfNecessary(word); ostringstream sfid; sfid << "ES:" << TD::Convert(word); end_span_ids_[i] = FD::Convert(sfid.str()); @@ -62,10 +97,12 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { WordID bword = bos; if (i > 0) bword = lattice[i-1][0].label; + bword = MapIfNecessary(bword); for (int j = 0; j <= lattice.size(); ++j) { WordID word = eos; if (j < lattice.size()) word = lattice[j][0].label; + word = MapIfNecessary(word); ostringstream pf; pf << "SS:" << TD::Convert(bword) << "_" << TD::Convert(word); span_feats_(i,j) = FD::Convert(pf.str()); diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h index 0446d062..5e90b7e0 100644 --- a/decoder/ff_spans.h +++ b/decoder/ff_spans.h @@ -2,8 +2,10 @@ #define _FF_SPANS_H_ #include <vector> +#include <map> #include "ff.h" #include "array2d.h" +#include "wordid.h" class SpanFeatures : public FeatureFunction { public: @@ -17,11 +19,14 @@ class SpanFeatures : public FeatureFunction { void* context) const; virtual void PrepareForInput(const SentenceMetadata& smeta); private: + WordID MapIfNecessary(const WordID& w) const; const int kS; const int kX; Array2D<int> span_feats_; std::vector<int> end_span_ids_; std::vector<int> beg_span_ids_; + std::map<WordID, WordID> word2class_; // optional projection to coarser class + WordID oov_; }; #endif |