summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2010-12-15 00:26:55 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2010-12-15 00:26:55 -0500
commitf2814314a1245fa0da3cba248cbe59b7f7cd87a8 (patch)
tree3eec12f65b780f4246b9527fd880429e2418cf05
parentfb38d57a76ec3e39de13f1fe2b4fd4abb7a5457c (diff)
feature to detect self-transition before/after identity translations
-rw-r--r--decoder/cdec_ff.cc1
-rw-r--r--decoder/ff_wordalign.cc62
-rw-r--r--decoder/ff_wordalign.h17
3 files changed, 80 insertions, 0 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index e87ab5ab..729d1214 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -60,6 +60,7 @@ void register_feature_functions() {
ff_registry.Register("Tagger_BigramIdentity", new FFFactory<Tagger_BigramIdentity>);
ff_registry.Register("LexicalPairIdentity", new FFFactory<LexicalPairIdentity>);
ff_registry.Register("OutputIdentity", new FFFactory<OutputIdentity>);
+ ff_registry.Register("IdentityCycleDetector", new FFFactory<IdentityCycleDetector>);
ff_registry.Register("InputIdentity", new FFFactory<InputIdentity>);
ff_registry.Register("LexicalTranslationTrigger", new FFFactory<LexicalTranslationTrigger>);
ff_registry.Register("WordPairFeatures", new FFFactory<WordPairFeatures>);
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index 338f1a72..ef3310b4 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -556,6 +556,68 @@ void BlunsomSynchronousParseHack::TraversalFeaturesImpl(const SentenceMetadata&
SetStateMask(it->second, it->second + yield.size(), state);
}
+IdentityCycleDetector::IdentityCycleDetector(const std::string& param) : FeatureFunction(2) {
+ length_min_ = 3;
+ if (!param.empty())
+ length_min_ = atoi(param.c_str());
+ assert(length_min_ >= 0);
+ ostringstream os;
+ os << "IdentityCycle_LenGT" << length_min_;
+ fid_ = FD::Convert(os.str());
+}
+
+inline bool IsIdentityTranslation(const void* state) {
+ return static_cast<const unsigned char*>(state)[0];
+}
+
+inline int SourceIndex(const void* state) {
+ unsigned char i = static_cast<const unsigned char*>(state)[1];
+ if (i == 255) return -1;
+ return i;
+}
+
+void IdentityCycleDetector::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ unsigned char* out_state = static_cast<unsigned char*>(context);
+ unsigned char& out_is_identity = out_state[0];
+ unsigned char& out_src_index = out_state[1];
+
+ if (edge.Arity() == 0) {
+ assert(edge.rule_->EWords() == 1);
+ assert(edge.rule_->FWords() == 1);
+ out_src_index = edge.i_;
+ out_is_identity = false;
+ if (edge.rule_->e_[0] == edge.rule_->f_[0]) {
+ const WordID word = edge.rule_->e_[0];
+ static map<WordID, bool> big_enough;
+ map<WordID,bool>::iterator it = big_enough_.find(word);
+ if (it == big_enough_.end()) {
+ out_is_identity = big_enough_[word] = strlen(TD::Convert(word)) >= length_min_;
+ } else {
+ out_is_identity = it->second;
+ }
+ }
+ } else if (edge.Arity() == 1) {
+ memcpy(context, ant_contexts[0], 2);
+ } else if (edge.Arity() == 2) {
+ bool left_identity = IsIdentityTranslation(ant_contexts[0]);
+ int left_index = SourceIndex(ant_contexts[0]);
+ bool right_identity = IsIdentityTranslation(ant_contexts[1]);
+ int right_index = SourceIndex(ant_contexts[1]);
+ if ((left_identity && left_index == right_index && !right_identity) ||
+ (right_identity && left_index == right_index && !left_identity)) {
+ features->set_value(fid_, 1.0);
+ }
+ out_is_identity = right_identity;
+ out_src_index = right_index;
+ } else { assert("really really bad"); }
+}
+
+
InputIdentity::InputIdentity(const std::string& param) {}
void InputIdentity::FireFeature(WordID src,
diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h
index a1ffd9ca..8035000e 100644
--- a/decoder/ff_wordalign.h
+++ b/decoder/ff_wordalign.h
@@ -237,6 +237,23 @@ class WordPairFeatures : public FeatureFunction {
std::vector<std::map<WordID, SparseVector<float> > > values_; // fkeys_index -> e -> value
};
+// fires when a len(word) >= length_min_ is translated as itself and then a self-transition is made
+class IdentityCycleDetector : public FeatureFunction {
+ public:
+ IdentityCycleDetector(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ int length_min_;
+ int fid_;
+ mutable std::map<WordID, bool> big_enough_;
+};
+
class InputIdentity : public FeatureFunction {
public:
InputIdentity(const std::string& param);