summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-06-01 01:26:55 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2011-06-01 01:26:55 -0400
commit2872eae2f95fb702c3f0a4a11eb7a70efb246dd9 (patch)
tree312e620c0530b3ca10105b252ea667a892eeed27
parentc135e494b7a49d20b51f9c181f104d9adb0099af (diff)
rule bigram features
-rw-r--r--decoder/ff_spans.cc39
-rw-r--r--decoder/ff_spans.h15
2 files changed, 54 insertions, 0 deletions
diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc
index 89335682..e1da088d 100644
--- a/decoder/ff_spans.cc
+++ b/decoder/ff_spans.cc
@@ -182,6 +182,45 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
}
}
+RuleNgramFeatures::RuleNgramFeatures(const std::string& param) {
+}
+
+void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+// std::map<const TRule*, SparseVector<double> >
+ rule2_feats_.clear();
+}
+
+void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ map<const TRule*, SparseVector<double> >::iterator it = rule2_feats_.find(edge.rule_.get());
+ if (it == rule2_feats_.end()) {
+ const TRule& rule = *edge.rule_;
+ it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first;
+ SparseVector<double>& f = it->second;
+ string prev = "<r>";
+ for (int i = 0; i < rule.f_.size(); ++i) {
+ WordID w = rule.f_[i];
+ if (w < 0) w = -w;
+ assert(w > 0);
+ const string& cur = TD::Convert(w);
+ ostringstream os;
+ os << "RB:" << prev << '_' << cur;
+ const int fid = FD::Convert(os.str());
+ if (fid <= 0) return;
+ f.add_value(fid, 1.0);
+ prev = cur;
+ }
+ ostringstream os;
+ os << "RB:" << prev << '_' << "</r>";
+ f.set_value(FD::Convert(os.str()), 1.0);
+ }
+ (*features) += it->second;
+}
+
inline bool IsArity2RuleReordered(const TRule& rule) {
const vector<WordID>& e = rule.e_;
for (int i = 0; i < e.size(); ++i) {
diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h
index 24e0dede..b22c4d03 100644
--- a/decoder/ff_spans.h
+++ b/decoder/ff_spans.h
@@ -44,6 +44,21 @@ class SpanFeatures : public FeatureFunction {
WordID oov_;
};
+class RuleNgramFeatures : public FeatureFunction {
+ public:
+ RuleNgramFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ mutable std::map<const TRule*, SparseVector<double> > rule2_feats_;
+};
+
class CMR2008ReorderingFeatures : public FeatureFunction {
public:
CMR2008ReorderingFeatures(const std::string& param);