diff options
-rw-r--r-- | decoder/cdec_ff.cc | 1 | ||||
-rw-r--r-- | decoder/ff_wordalign.cc | 62 | ||||
-rw-r--r-- | decoder/ff_wordalign.h | 17 |
3 files changed, 80 insertions, 0 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index e87ab5ab..729d1214 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -60,6 +60,7 @@ void register_feature_functions() { ff_registry.Register("Tagger_BigramIdentity", new FFFactory<Tagger_BigramIdentity>); ff_registry.Register("LexicalPairIdentity", new FFFactory<LexicalPairIdentity>); ff_registry.Register("OutputIdentity", new FFFactory<OutputIdentity>); + ff_registry.Register("IdentityCycleDetector", new FFFactory<IdentityCycleDetector>); ff_registry.Register("InputIdentity", new FFFactory<InputIdentity>); ff_registry.Register("LexicalTranslationTrigger", new FFFactory<LexicalTranslationTrigger>); ff_registry.Register("WordPairFeatures", new FFFactory<WordPairFeatures>); diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc index 338f1a72..ef3310b4 100644 --- a/decoder/ff_wordalign.cc +++ b/decoder/ff_wordalign.cc @@ -556,6 +556,68 @@ void BlunsomSynchronousParseHack::TraversalFeaturesImpl(const SentenceMetadata& SetStateMask(it->second, it->second + yield.size(), state); } +IdentityCycleDetector::IdentityCycleDetector(const std::string& param) : FeatureFunction(2) { + length_min_ = 3; + if (!param.empty()) + length_min_ = atoi(param.c_str()); + assert(length_min_ >= 0); + ostringstream os; + os << "IdentityCycle_LenGT" << length_min_; + fid_ = FD::Convert(os.str()); +} + +inline bool IsIdentityTranslation(const void* state) { + return static_cast<const unsigned char*>(state)[0]; +} + +inline int SourceIndex(const void* state) { + unsigned char i = static_cast<const unsigned char*>(state)[1]; + if (i == 255) return -1; + return i; +} + +void IdentityCycleDetector::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + unsigned char* out_state = static_cast<unsigned char*>(context); + unsigned char& out_is_identity = out_state[0]; + unsigned char& out_src_index = out_state[1]; + + if (edge.Arity() == 0) { + assert(edge.rule_->EWords() == 1); + assert(edge.rule_->FWords() == 1); + out_src_index = edge.i_; + out_is_identity = false; + if (edge.rule_->e_[0] == edge.rule_->f_[0]) { + const WordID word = edge.rule_->e_[0]; + static map<WordID, bool> big_enough; + map<WordID,bool>::iterator it = big_enough_.find(word); + if (it == big_enough_.end()) { + out_is_identity = big_enough_[word] = strlen(TD::Convert(word)) >= length_min_; + } else { + out_is_identity = it->second; + } + } + } else if (edge.Arity() == 1) { + memcpy(context, ant_contexts[0], 2); + } else if (edge.Arity() == 2) { + bool left_identity = IsIdentityTranslation(ant_contexts[0]); + int left_index = SourceIndex(ant_contexts[0]); + bool right_identity = IsIdentityTranslation(ant_contexts[1]); + int right_index = SourceIndex(ant_contexts[1]); + if ((left_identity && left_index == right_index && !right_identity) || + (right_identity && left_index == right_index && !left_identity)) { + features->set_value(fid_, 1.0); + } + out_is_identity = right_identity; + out_src_index = right_index; + } else { assert("really really bad"); } +} + + InputIdentity::InputIdentity(const std::string& param) {} void InputIdentity::FireFeature(WordID src, diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h index a1ffd9ca..8035000e 100644 --- a/decoder/ff_wordalign.h +++ b/decoder/ff_wordalign.h @@ -237,6 +237,23 @@ class WordPairFeatures : public FeatureFunction { std::vector<std::map<WordID, SparseVector<float> > > values_; // fkeys_index -> e -> value }; +// fires when a len(word) >= length_min_ is translated as itself and then a self-transition is made +class IdentityCycleDetector : public FeatureFunction { + public: + IdentityCycleDetector(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + private: + int length_min_; + int fid_; + mutable std::map<WordID, bool> big_enough_; +}; + class InputIdentity : public FeatureFunction { public: InputIdentity(const std::string& param); |