summaryrefslogtreecommitdiff
path: root/decoder/ff_wordalign.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/ff_wordalign.cc')
-rw-r--r--decoder/ff_wordalign.cc62
1 files changed, 62 insertions, 0 deletions
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index 338f1a72..ef3310b4 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -556,6 +556,68 @@ void BlunsomSynchronousParseHack::TraversalFeaturesImpl(const SentenceMetadata&
SetStateMask(it->second, it->second + yield.size(), state);
}
+IdentityCycleDetector::IdentityCycleDetector(const std::string& param) : FeatureFunction(2) {
+ length_min_ = 3;
+ if (!param.empty())
+ length_min_ = atoi(param.c_str());
+ assert(length_min_ >= 0);
+ ostringstream os;
+ os << "IdentityCycle_LenGT" << length_min_;
+ fid_ = FD::Convert(os.str());
+}
+
+inline bool IsIdentityTranslation(const void* state) {
+ return static_cast<const unsigned char*>(state)[0];
+}
+
+inline int SourceIndex(const void* state) {
+ unsigned char i = static_cast<const unsigned char*>(state)[1];
+ if (i == 255) return -1;
+ return i;
+}
+
+void IdentityCycleDetector::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ unsigned char* out_state = static_cast<unsigned char*>(context);
+ unsigned char& out_is_identity = out_state[0];
+ unsigned char& out_src_index = out_state[1];
+
+ if (edge.Arity() == 0) {
+ assert(edge.rule_->EWords() == 1);
+ assert(edge.rule_->FWords() == 1);
+ out_src_index = edge.i_;
+ out_is_identity = false;
+ if (edge.rule_->e_[0] == edge.rule_->f_[0]) {
+ const WordID word = edge.rule_->e_[0];
+ static map<WordID, bool> big_enough;
+ map<WordID,bool>::iterator it = big_enough_.find(word);
+ if (it == big_enough_.end()) {
+ out_is_identity = big_enough_[word] = strlen(TD::Convert(word)) >= length_min_;
+ } else {
+ out_is_identity = it->second;
+ }
+ }
+ } else if (edge.Arity() == 1) {
+ memcpy(context, ant_contexts[0], 2);
+ } else if (edge.Arity() == 2) {
+ bool left_identity = IsIdentityTranslation(ant_contexts[0]);
+ int left_index = SourceIndex(ant_contexts[0]);
+ bool right_identity = IsIdentityTranslation(ant_contexts[1]);
+ int right_index = SourceIndex(ant_contexts[1]);
+ if ((left_identity && left_index == right_index && !right_identity) ||
+ (right_identity && left_index == right_index && !left_identity)) {
+ features->set_value(fid_, 1.0);
+ }
+ out_is_identity = right_identity;
+ out_src_index = right_index;
+ } else { assert("really really bad"); }
+}
+
+
InputIdentity::InputIdentity(const std::string& param) {}
void InputIdentity::FireFeature(WordID src,