summary | refs | log | tree | commit | diff
path: root/decoder
diff options
context:
space:
mode:
author: redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-06-28 20:40:28 +0000
committer: redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-06-28 20:40:28 +0000
commit: f96bf4df7e4a34b42373723cbe38e6c7425e3239 (patch)
tree: 0b57bda1e4e72ce6e679cad996e0150372dd0b25 /decoder
parent: 07e5fa39b4cf87f72e4e12604f26cc6c235d69d7 (diff)
rule shape features
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@46 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
-rw-r--r-- decoder/Makefile.am 1
-rw-r--r-- decoder/cdec_ff.cc 2
-rw-r--r-- decoder/ff_ruleshape.cc 104
-rw-r--r-- decoder/ff_ruleshape.h 31
4 files changed, 138 insertions, 0 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index a385197c..44d6adc8 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -69,6 +69,7 @@ libcdec_a_SOURCES = \
ttables.cc \
ff.cc \
ff_lm.cc \
+ ff_ruleshape.cc \
ff_wordalign.cc \
ff_csplit.cc \
ff_tagger.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index d0b93795..3b83bab3 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -6,6 +6,7 @@
#include "ff_wordalign.h"
#include "ff_tagger.h"
#include "ff_factory.h"
+#include "ff_ruleshape.h"
boost::shared_ptr<FFRegistry> global_ff_registry;
@@ -17,6 +18,7 @@ void register_feature_functions() {
global_ff_registry->Register("WordPenalty", new FFFactory<WordPenalty>);
global_ff_registry->Register("SourceWordPenalty", new FFFactory<SourceWordPenalty>);
global_ff_registry->Register("ArityPenalty", new FFFactory<ArityPenalty>);
+ global_ff_registry->Register("RuleShape", new FFFactory<RuleShapeFeatures>);
global_ff_registry->Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
global_ff_registry->Register("Model2BinaryFeatures", new FFFactory<Model2BinaryFeatures>);
global_ff_registry->Register("MarkovJump", new FFFactory<MarkovJump>);
diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc
new file mode 100644
index 00000000..d473704a
--- /dev/null
+++ b/decoder/ff_ruleshape.cc
@@ -0,0 +1,104 @@
+#include "ff_ruleshape.h"
+
+#include "fdict.h"
+#include <sstream>
+
+using namespace std;
+
+// Returns true iff bit number `bit` (0 = least significant) is set in `i`.
+inline bool IsBitSet(int i, int bit) {
+  return (i & (1 << bit)) != 0;
+}
+
+// Renders a boolean as the character '1' (true) or '0' (false).
+inline char BitAsChar(bool bit) {
+  if (bit) return '1';
+  return '0';
+}
+
+// Builds the feature-id lookup tree that TraversalFeaturesImpl walks.
+//
+// For every pair (i, j) with 0 <= i, j < 32 a path of ten bool-keyed steps is
+// created under fidtree_: the five low bits of i (least significant first)
+// followed by the five low bits of j. The node at the end of each path stores
+// the value FD::Convert returns for the name "Shape_S<bits of i>_T<bits of j>",
+// e.g. "Shape_S00000_T00000". The first name generated is echoed to stderr.
+//
+// param is not used.
+RuleShapeFeatures::RuleShapeFeatures(const string& param) {
+  bool example_shown = false;
+  for (int src_bits = 0; src_bits < 32; ++src_bits) {
+    for (int trg_bits = 0; trg_bits < 32; ++trg_bits) {
+      string name("Shape_S");
+      Node* node = &fidtree_;
+      // Source half of the path / name.
+      for (int b = 0; b < 5; ++b) {
+        const bool on = IsBitSet(src_bits, b);
+        node = &node->next_[on];  // operator[] creates the child on first visit
+        name += BitAsChar(on);
+      }
+      name += "_T";
+      // Target half of the path / name.
+      for (int b = 0; b < 5; ++b) {
+        const bool on = IsBitSet(trg_bits, b);
+        node = &node->next_[on];
+        name += BitAsChar(on);
+      }
+      if (!example_shown) {
+        example_shown = true;
+        cerr << " Example feature: " << name << endl;
+      }
+      node->fid_ = FD::Convert(name);
+    }
+  }
+}
+
+// Fires one indicator feature describing the "shape" of the rule on `edge`.
+//
+// The source side (rule.f_) and then the target side (rule.e_) are each
+// reduced to five boolean slots laid out as T NT T NT T:
+//   - a T slot is true when a run of one or more terminals (symbols > 0)
+//     occupies it,
+//   - an NT slot is true when a non-terminal (symbol < 1) occupies it,
+//   - slots that are never reached are false.
+// The ten slot values select a path through fidtree_, and the feature id
+// stored at the end of that path is set to 1.0 in *features.
+//
+// smeta, ant_contexts, estimated_features and context are not used.
+//
+// A rule side that needs more than five slots (i.e. has more than two
+// non-terminals) is reported on stderr and the process exits with status 1.
+// This check is always active: relying on assert alone would, in NDEBUG
+// builds, either pick a wrong feature or walk past the end of the tree and
+// dereference the NULL that Advance returns.
+void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                              const Hypergraph::Edge& edge,
+                                              const vector<const void*>& ant_contexts,
+                                              SparseVector<double>* features,
+                                              SparseVector<double>* estimated_features,
+                                              void* context) const {
+  const Node* cur = &fidtree_;
+  TRule& rule = *edge.rule_;
+  int pos = 0;  // next slot to fill: 0-4 are source slots, 5-9 target slots
+
+  // ---- source side: slots 0..4, even slots are terminal slots ----
+  const int src_len = static_cast<int>(rule.f_.size());
+  int i = 0;
+  while (i < src_len) {
+    if (pos >= 5) {  // symbols remain but every source slot is already used
+      cerr << "BAD RULE (RuleShape supports at most binary rules): " << rule.AsString() << endl;
+      exit(1);
+    }
+    const WordID sym = rule.f_[i];
+    if (pos % 2 == 0) {  // terminal slot
+      if (sym > 0) {
+        cur = Advance(cur, true);
+        while (i < src_len && rule.f_[i] > 0) ++i;  // consume lexical string
+      } else {
+        // No terminals here; leave the NT for the following NT slot.
+        cur = Advance(cur, false);
+      }
+      ++pos;
+    } else {  // expecting a NT
+      if (sym < 1) {
+        cur = Advance(cur, true);
+        ++i;
+        ++pos;
+      } else {
+        cerr << "BAD RULE: " << rule.AsString() << endl;
+        exit(1);
+      }
+    }
+  }
+  for (; pos < 5; ++pos)  // unused trailing source slots are false
+    cur = Advance(cur, false);
+  assert(pos == 5);  // guaranteed by the slot-count check inside the loop
+
+  // ---- target side: slots 5..9, odd slots are terminal slots ----
+  const int trg_len = static_cast<int>(rule.e_.size());
+  i = 0;
+  while (i < trg_len) {
+    if (pos >= 10) {  // symbols remain but every target slot is already used
+      cerr << "BAD RULE (RuleShape supports at most binary rules): " << rule.AsString() << endl;
+      exit(1);
+    }
+    const WordID sym = rule.e_[i];
+    if (pos % 2 == 1) {  // terminal slot
+      if (sym > 0) {
+        cur = Advance(cur, true);
+        while (i < trg_len && rule.e_[i] > 0) ++i;  // consume lexical string
+      } else {
+        cur = Advance(cur, false);
+      }
+      ++pos;
+    } else {  // expecting a NT
+      if (sym < 1) {
+        cur = Advance(cur, true);
+        ++i;
+        ++pos;
+      } else {
+        cerr << "BAD RULE: " << rule.AsString() << endl;
+        exit(1);
+      }
+    }
+  }
+  for (; pos < 10; ++pos)  // unused trailing target slots are false
+    cur = Advance(cur, false);
+  assert(pos == 10);  // guaranteed by the slot-count check inside the loop
+
+  features->set_value(cur->fid_, 1.0);
+}
+
diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h
new file mode 100644
index 00000000..23c9827e
--- /dev/null
+++ b/decoder/ff_ruleshape.h
@@ -0,0 +1,31 @@
+#ifndef _FF_RULESHAPE_H_
+#define _FF_RULESHAPE_H_
+
+#include <map>
+#include <string>
+#include <vector>
+#include "ff.h"
+
+// Feature function that emits one indicator feature per rule, chosen by
+// walking a bool-keyed tree (fidtree_) whose leaves hold feature ids.
+class RuleShapeFeatures : public FeatureFunction {
+ public:
+  RuleShapeFeatures(const std::string& param);
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const Hypergraph::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* context) const;
+ private:
+  // One node of the lookup tree. fid_ is -1 until a feature id is assigned;
+  // next_ holds the children reached by a false / true step.
+  // NOTE(review): std::map<bool, Node> is instantiated while Node is still an
+  // incomplete type; this relies on the library tolerating that -- confirm.
+  struct Node {
+    int fid_;
+    Node() : fid_(-1) {}
+    std::map<bool, Node> next_;
+  };
+  Node fidtree_;  // root of the lookup tree
+  // Returns the child of `cur` reached by `val`, or NULL if there is none.
+  // `cur` itself must not be NULL.
+  static const Node* Advance(const Node* cur, bool val) {
+    const std::map<bool, Node>& children = cur->next_;
+    const std::map<bool, Node>::const_iterator hit = children.find(val);
+    if (hit != children.end()) return &hit->second;
+    return NULL;
+  }
+};
+
+#endif