From f96bf4df7e4a34b42373723cbe38e6c7425e3239 Mon Sep 17 00:00:00 2001 From: redpony Date: Mon, 28 Jun 2010 20:40:28 +0000 Subject: rule shape features git-svn-id: https://ws10smt.googlecode.com/svn/trunk@46 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/Makefile.am | 1 + decoder/cdec_ff.cc | 2 + decoder/ff_ruleshape.cc | 104 ++++++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_ruleshape.h | 31 +++++++++++++++ 4 files changed, 138 insertions(+) create mode 100644 decoder/ff_ruleshape.cc create mode 100644 decoder/ff_ruleshape.h (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index a385197c..44d6adc8 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -69,6 +69,7 @@ libcdec_a_SOURCES = \ ttables.cc \ ff.cc \ ff_lm.cc \ + ff_ruleshape.cc \ ff_wordalign.cc \ ff_csplit.cc \ ff_tagger.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index d0b93795..3b83bab3 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -6,6 +6,7 @@ #include "ff_wordalign.h" #include "ff_tagger.h" #include "ff_factory.h" +#include "ff_ruleshape.h" boost::shared_ptr global_ff_registry; @@ -17,6 +18,7 @@ void register_feature_functions() { global_ff_registry->Register("WordPenalty", new FFFactory); global_ff_registry->Register("SourceWordPenalty", new FFFactory); global_ff_registry->Register("ArityPenalty", new FFFactory); + global_ff_registry->Register("RuleShape", new FFFactory); global_ff_registry->Register("RelativeSentencePosition", new FFFactory); global_ff_registry->Register("Model2BinaryFeatures", new FFFactory); global_ff_registry->Register("MarkovJump", new FFFactory); diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc new file mode 100644 index 00000000..d473704a --- /dev/null +++ b/decoder/ff_ruleshape.cc @@ -0,0 +1,104 @@ +#include "ff_ruleshape.h" + +#include "fdict.h" +#include + +using namespace std; + +inline bool IsBitSet(int i, int bit) { + const int mask = 1 << bit; + return (i & mask); +} + +inline char BitAsChar(bool bit) { + return (bit ? '1' : '0'); +} + +RuleShapeFeatures::RuleShapeFeatures(const string& param) { + bool first = true; + for (int i = 0; i < 32; ++i) { + for (int j = 0; j < 32; ++j) { + ostringstream os; + os << "Shape_S"; + Node* cur = &fidtree_; + for (int k = 0; k < 5; ++k) { + bool bit = IsBitSet(i,k); + cur = &cur->next_[bit]; + os << BitAsChar(bit); + } + os << "_T"; + for (int k = 0; k < 5; ++k) { + bool bit = IsBitSet(j,k); + cur = &cur->next_[bit]; + os << BitAsChar(bit); + } + if (first) { first = false; cerr << " Example feature: " << os.str() << endl; } + cur->fid_ = FD::Convert(os.str()); + } + } +} + +void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + const Node* cur = &fidtree_; + TRule& rule = *edge.rule_; + int pos = 0; // feature position + int i = 0; + while(i < rule.f_.size()) { + WordID sym = rule.f_[i]; + if (pos % 2 == 0) { + if (sym > 0) { // is terminal + cur = Advance(cur, true); + while (i < rule.f_.size() && rule.f_[i] > 0) ++i; // consume lexical string + } else { + cur = Advance(cur, false); + } + ++pos; + } else { // expecting a NT + if (sym < 1) { + cur = Advance(cur, true); + ++i; + ++pos; + } else { + cerr << "BAD RULE: " << rule.AsString() << endl; + exit(1); + } + } + } + for (; pos < 5; ++pos) + cur = Advance(cur, false); + assert(pos == 5); // this will fail if you are using using > binary rules! + + i = 0; + while(i < rule.e_.size()) { + WordID sym = rule.e_[i]; + if (pos % 2 == 1) { + if (sym > 0) { // is terminal + cur = Advance(cur, true); + while (i < rule.e_.size() && rule.e_[i] > 0) ++i; // consume lexical string + } else { + cur = Advance(cur, false); + } + ++pos; + } else { // expecting a NT + if (sym < 1) { + cur = Advance(cur, true); + ++i; + ++pos; + } else { + cerr << "BAD RULE: " << rule.AsString() << endl; + exit(1); + } + } + } + for (;pos < 10; ++pos) + cur = Advance(cur, false); + assert(pos == 10); // this will fail if you are using using > binary rules! + + features->set_value(cur->fid_, 1.0); +} + diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h new file mode 100644 index 00000000..23c9827e --- /dev/null +++ b/decoder/ff_ruleshape.h @@ -0,0 +1,31 @@ +#ifndef _FF_RULESHAPE_H_ +#define _FF_RULESHAPE_H_ + +#include +#include "ff.h" + +class RuleShapeFeatures : public FeatureFunction { + public: + RuleShapeFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + struct Node { + int fid_; + Node() : fid_(-1) {} + std::map next_; + }; + Node fidtree_; + static const Node* Advance(const Node* cur, bool val) { + std::map::const_iterator it = cur->next_.find(val); + if (it == cur->next_.end()) return NULL; + return &it->second; + } +}; + +#endif -- cgit v1.2.3