summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-09-17 17:06:40 +0100
committerChris Dyer <cdyer@cs.cmu.edu>2011-09-17 17:06:40 +0100
commit7f9efe25a3c7e09141394a132ab82cdd8237e00a (patch)
treed0c22e06ccc76b4d4afd65fd269298a00d65382e /decoder
parentc1d44af44ce582a34320c8d81e3b62f71416de34 (diff)
source span size features
Diffstat (limited to 'decoder')
-rw-r--r--decoder/ff_source_syntax.cc62
-rw-r--r--decoder/ff_source_syntax.h17
2 files changed, 79 insertions, 0 deletions
diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc
index ffe07f03..2df31c3a 100644
--- a/decoder/ff_source_syntax.cc
+++ b/decoder/ff_source_syntax.cc
@@ -157,3 +157,65 @@ void SourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) {
impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength());
}
+struct SourceSpanSizeFeaturesImpl {
+ SourceSpanSizeFeaturesImpl() {}
+
+ void InitializeGrids(unsigned src_len) {
+ fids.clear();
+ fids.resize(src_len, src_len + 1);
+ }
+
+ int FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) {
+ int& fid = fids(i,j)[&rule];
+ if (fid <= 0) {
+ ostringstream os;
+ os << "SSS:";
+ unsigned ntc = 0;
+ for (unsigned k = 0; k < rule.f_.size(); ++k) {
+ if (k > 0) os << '_';
+ int fj = rule.f_[k];
+ if (fj <= 0) {
+ os << '[' << TD::Convert(-fj) << ants[ntc++] << ']';
+ } else {
+ os << TD::Convert(fj);
+ }
+ }
+ fid = FD::Convert(os.str());
+ }
+ if (fid > 0)
+ feats->set_value(fid, 1.0);
+ return SpanSizeTransform(j - i);
+ }
+
+ mutable Array2D<map<const TRule*, int> > fids;
+};
+
+SourceSpanSizeFeatures::SourceSpanSizeFeatures(const string& param) :
+ FeatureFunction(sizeof(char)) {
+ impl = new SourceSpanSizeFeaturesImpl;
+}
+
+SourceSpanSizeFeatures::~SourceSpanSizeFeatures() {
+ delete impl;
+ impl = NULL;
+}
+
+void SourceSpanSizeFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ int ants[8];
+ for (unsigned i = 0; i < ant_contexts.size(); ++i)
+ ants[i] = *static_cast<const char*>(ant_contexts[i]);
+
+ *static_cast<char*>(context) =
+ impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features);
+}
+
+void SourceSpanSizeFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ impl->InitializeGrids(smeta.GetSourceLength());
+}
+
+
diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h
index 1e890736..279563e1 100644
--- a/decoder/ff_source_syntax.h
+++ b/decoder/ff_source_syntax.h
@@ -21,4 +21,21 @@ class SourceSyntaxFeatures : public FeatureFunction {
SourceSyntaxFeaturesImpl* impl;
};
+struct SourceSpanSizeFeaturesImpl;
+class SourceSpanSizeFeatures : public FeatureFunction {
+ public:
+ SourceSpanSizeFeatures(const std::string& param);
+ ~SourceSpanSizeFeatures();
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ SourceSpanSizeFeaturesImpl* impl;
+};
+
#endif