diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-09-17 17:06:40 +0100 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-09-17 17:06:40 +0100 |
commit | 10cfa1082059db646148af1884117082335a48e7 (patch) | |
tree | ba128eeadab0e9ff589449bcd43b7ae1ca03c4e2 | |
parent | ce830ec51477f345c811987e11a9ed4322edcac0 (diff) |
source span size features
-rw-r--r-- | decoder/ff_source_syntax.cc | 62 | ||||
-rw-r--r-- | decoder/ff_source_syntax.h | 17 |
2 files changed, 79 insertions, 0 deletions
diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index ffe07f03..2df31c3a 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -157,3 +157,65 @@ void SourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) { impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); } +struct SourceSpanSizeFeaturesImpl { + SourceSpanSizeFeaturesImpl() {} + + void InitializeGrids(unsigned src_len) { + fids.clear(); + fids.resize(src_len, src_len + 1); + } + + int FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) { + int& fid = fids(i,j)[&rule]; + if (fid <= 0) { + ostringstream os; + os << "SSS:"; + unsigned ntc = 0; + for (unsigned k = 0; k < rule.f_.size(); ++k) { + if (k > 0) os << '_'; + int fj = rule.f_[k]; + if (fj <= 0) { + os << '[' << TD::Convert(-fj) << ants[ntc++] << ']'; + } else { + os << TD::Convert(fj); + } + } + fid = FD::Convert(os.str()); + } + if (fid > 0) + feats->set_value(fid, 1.0); + return SpanSizeTransform(j - i); + } + + mutable Array2D<map<const TRule*, int> > fids; +}; + +SourceSpanSizeFeatures::SourceSpanSizeFeatures(const string& param) : + FeatureFunction(sizeof(char)) { + impl = new SourceSpanSizeFeaturesImpl; +} + +SourceSpanSizeFeatures::~SourceSpanSizeFeatures() { + delete impl; + impl = NULL; +} + +void SourceSpanSizeFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + int ants[8]; + for (unsigned i = 0; i < ant_contexts.size(); ++i) + ants[i] = *static_cast<const char*>(ant_contexts[i]); + + *static_cast<char*>(context) = + impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); +} + +void SourceSpanSizeFeatures::PrepareForInput(const SentenceMetadata& smeta) { + impl->InitializeGrids(smeta.GetSourceLength()); +} + + diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h index 1e890736..279563e1 100644 --- a/decoder/ff_source_syntax.h +++ b/decoder/ff_source_syntax.h @@ -21,4 +21,21 @@ class SourceSyntaxFeatures : public FeatureFunction { SourceSyntaxFeaturesImpl* impl; }; +struct SourceSpanSizeFeaturesImpl; +class SourceSpanSizeFeatures : public FeatureFunction { + public: + SourceSpanSizeFeatures(const std::string& param); + ~SourceSpanSizeFeatures(); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + SourceSpanSizeFeaturesImpl* impl; +}; + #endif |