diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/ff_spans.cc | 39 | ||||
| -rw-r--r-- | decoder/ff_spans.h | 5 | 
2 files changed, 43 insertions, 1 deletions
diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index b454c9fd..06593727 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -3,15 +3,41 @@  #include <sstream>  #include <cassert> +#include "filelib.h"  #include "sentence_metadata.h"  #include "lattice.h"  #include "fdict.h" +#include "verbose.h"  using namespace std;  SpanFeatures::SpanFeatures(const string& param) :    kS(TD::Convert("S") * -1), -  kX(TD::Convert("X") * -1) {} +  kX(TD::Convert("X") * -1) { +  if (param.size() > 0) { +    int lc = 0; +    if (!SILENT) { cerr << "Reading word map for SpanFeatures from " << param << endl; } +    ReadFile rf(param); +    istream& in = *rf.stream(); +    string line; +    vector<WordID> v; +    while(in) { +      ++lc; +      getline(in, line); +      if (line.empty()) continue; +      v.clear(); +      TD::ConvertSentence(line, &v); +      if (v.size() != 2) { +        cerr << "Error reading line " << lc << ": " << line << endl; +        abort(); +      } +      word2class_[v[0]] = v[1]; +    } +    word2class_[TD::Convert("<s>")] = TD::Convert("BOS"); +    word2class_[TD::Convert("</s>")] = TD::Convert("EOS"); +    oov_ = TD::Convert("OOV"); +  } +}  void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,                                           const Hypergraph::Edge& edge, @@ -37,6 +63,13 @@ void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,    }  } +WordID SpanFeatures::MapIfNecessary(const WordID& w) const { +  if (word2class_.empty()) return w; +  map<WordID,WordID>::const_iterator it = word2class_.find(w); +  if (it == word2class_.end()) return oov_; +  return it->second; +} +  void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {    const Lattice& lattice = smeta.GetSourceLattice();    const WordID eos = TD::Convert("</s>"); @@ -48,8 +81,10 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {      WordID bword = bos;      if (i > 0)        bword = lattice[i-1][0].label; +    bword = MapIfNecessary(bword);      if (i < lattice.size())        word = lattice[i][0].label;  // rather arbitrary for lattices +    word = MapIfNecessary(word);      ostringstream sfid;      sfid << "ES:" << TD::Convert(word);      end_span_ids_[i] = FD::Convert(sfid.str()); @@ -62,10 +97,12 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {      WordID bword = bos;      if (i > 0)        bword = lattice[i-1][0].label; +    bword = MapIfNecessary(bword);      for (int j = 0; j <= lattice.size(); ++j) {        WordID word = eos;        if (j < lattice.size())          word = lattice[j][0].label; +      word = MapIfNecessary(word);        ostringstream pf;        pf << "SS:" << TD::Convert(bword) << "_" << TD::Convert(word);        span_feats_(i,j) = FD::Convert(pf.str()); diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h index 0446d062..5e90b7e0 100644 --- a/decoder/ff_spans.h +++ b/decoder/ff_spans.h @@ -2,8 +2,10 @@  #define _FF_SPANS_H_  #include <vector> +#include <map>  #include "ff.h"  #include "array2d.h" +#include "wordid.h"  class SpanFeatures : public FeatureFunction {   public: @@ -17,11 +19,14 @@ class SpanFeatures : public FeatureFunction {                                       void* context) const;    virtual void PrepareForInput(const SentenceMetadata& smeta);   private: +  WordID MapIfNecessary(const WordID& w) const;    const int kS;    const int kX;    Array2D<int> span_feats_;    std::vector<int> end_span_ids_;    std::vector<int> beg_span_ids_; +  std::map<WordID, WordID> word2class_;  // optional projection to coarser class +  WordID oov_;  };  #endif  | 
