diff options
Diffstat (limited to 'decoder/ff_wordalign.cc')
-rw-r--r-- | decoder/ff_wordalign.cc | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc index 338f1a72..ef3310b4 100644 --- a/decoder/ff_wordalign.cc +++ b/decoder/ff_wordalign.cc @@ -556,6 +556,68 @@ void BlunsomSynchronousParseHack::TraversalFeaturesImpl(const SentenceMetadata& SetStateMask(it->second, it->second + yield.size(), state); } +IdentityCycleDetector::IdentityCycleDetector(const std::string& param) : FeatureFunction(2) { + length_min_ = 3; + if (!param.empty()) + length_min_ = atoi(param.c_str()); + assert(length_min_ >= 0); + ostringstream os; + os << "IdentityCycle_LenGT" << length_min_; + fid_ = FD::Convert(os.str()); +} + +inline bool IsIdentityTranslation(const void* state) { + return static_cast<const unsigned char*>(state)[0]; +} + +inline int SourceIndex(const void* state) { + unsigned char i = static_cast<const unsigned char*>(state)[1]; + if (i == 255) return -1; + return i; +} + +void IdentityCycleDetector::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + unsigned char* out_state = static_cast<unsigned char*>(context); + unsigned char& out_is_identity = out_state[0]; + unsigned char& out_src_index = out_state[1]; + + if (edge.Arity() == 0) { + assert(edge.rule_->EWords() == 1); + assert(edge.rule_->FWords() == 1); + out_src_index = edge.i_; + out_is_identity = false; + if (edge.rule_->e_[0] == edge.rule_->f_[0]) { + const WordID word = edge.rule_->e_[0]; + static map<WordID, bool> big_enough; + map<WordID,bool>::iterator it = big_enough_.find(word); + if (it == big_enough_.end()) { + out_is_identity = big_enough_[word] = strlen(TD::Convert(word)) >= length_min_; + } else { + out_is_identity = it->second; + } + } + } else if (edge.Arity() == 1) { + memcpy(context, ant_contexts[0], 2); + } else if (edge.Arity() == 2) { + bool left_identity = IsIdentityTranslation(ant_contexts[0]); + int left_index = SourceIndex(ant_contexts[0]); + bool right_identity = IsIdentityTranslation(ant_contexts[1]); + int right_index = SourceIndex(ant_contexts[1]); + if ((left_identity && left_index == right_index && !right_identity) || + (right_identity && left_index == right_index && !left_identity)) { + features->set_value(fid_, 1.0); + } + out_is_identity = right_identity; + out_src_index = right_index; + } else { assert("really really bad"); } +} + + InputIdentity::InputIdentity(const std::string& param) {} void InputIdentity::FireFeature(WordID src, |