diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/cdec_ff.cc | 1 | ||||
| -rw-r--r-- | decoder/ff_wordalign.cc | 62 | ||||
| -rw-r--r-- | decoder/ff_wordalign.h | 17 | 
3 files changed, 80 insertions, 0 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index e87ab5ab..729d1214 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -60,6 +60,7 @@ void register_feature_functions() {    ff_registry.Register("Tagger_BigramIdentity", new FFFactory<Tagger_BigramIdentity>);    ff_registry.Register("LexicalPairIdentity", new FFFactory<LexicalPairIdentity>);    ff_registry.Register("OutputIdentity", new FFFactory<OutputIdentity>); +  ff_registry.Register("IdentityCycleDetector", new FFFactory<IdentityCycleDetector>);    ff_registry.Register("InputIdentity", new FFFactory<InputIdentity>);    ff_registry.Register("LexicalTranslationTrigger", new FFFactory<LexicalTranslationTrigger>);    ff_registry.Register("WordPairFeatures", new FFFactory<WordPairFeatures>); diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc index 338f1a72..ef3310b4 100644 --- a/decoder/ff_wordalign.cc +++ b/decoder/ff_wordalign.cc @@ -556,6 +556,68 @@ void BlunsomSynchronousParseHack::TraversalFeaturesImpl(const SentenceMetadata&    SetStateMask(it->second, it->second + yield.size(), state);  } +IdentityCycleDetector::IdentityCycleDetector(const std::string& param) : FeatureFunction(2) { +  length_min_ = 3; +  if (!param.empty()) +    length_min_ = atoi(param.c_str()); +  assert(length_min_ >= 0); +  ostringstream os; +  os << "IdentityCycle_LenGT" << length_min_; +  fid_ = FD::Convert(os.str()); +} + +inline bool IsIdentityTranslation(const void* state) { +  return static_cast<const unsigned char*>(state)[0]; +} + +inline int SourceIndex(const void* state) { +  unsigned char i = static_cast<const unsigned char*>(state)[1]; +  if (i == 255) return -1; +  return i; +} + +void IdentityCycleDetector::TraversalFeaturesImpl(const SentenceMetadata& smeta, +                                     const Hypergraph::Edge& edge, +                                     const std::vector<const void*>& ant_contexts, +                                     SparseVector<double>* features, +                                     SparseVector<double>* estimated_features, +                                     void* context) const { +  unsigned char* out_state = static_cast<unsigned char*>(context); +  unsigned char& out_is_identity = out_state[0]; +  unsigned char& out_src_index = out_state[1]; + +  if (edge.Arity() == 0) { +    assert(edge.rule_->EWords() == 1); +    assert(edge.rule_->FWords() == 1); +    out_src_index = edge.i_; +    out_is_identity = false; +    if (edge.rule_->e_[0] == edge.rule_->f_[0]) { +      const WordID word = edge.rule_->e_[0]; +      static map<WordID, bool> big_enough; +      map<WordID,bool>::iterator it = big_enough_.find(word); +      if (it == big_enough_.end()) { +        out_is_identity = big_enough_[word] = strlen(TD::Convert(word)) >= length_min_; +      } else { +        out_is_identity = it->second; +      } +    } +  } else if (edge.Arity() == 1) { +    memcpy(context, ant_contexts[0], 2); +  } else if (edge.Arity() == 2) { +    bool left_identity = IsIdentityTranslation(ant_contexts[0]); +    int left_index = SourceIndex(ant_contexts[0]); +    bool right_identity = IsIdentityTranslation(ant_contexts[1]); +    int right_index = SourceIndex(ant_contexts[1]); +    if ((left_identity && left_index == right_index && !right_identity) || +        (right_identity && left_index == right_index && !left_identity)) { +      features->set_value(fid_, 1.0); +    } +    out_is_identity = right_identity; +    out_src_index = right_index; +  } else { assert("really really bad"); } +} + +  InputIdentity::InputIdentity(const std::string& param) {}  void InputIdentity::FireFeature(WordID src, diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h index a1ffd9ca..8035000e 100644 --- a/decoder/ff_wordalign.h +++ b/decoder/ff_wordalign.h @@ -237,6 +237,23 @@ class WordPairFeatures : public FeatureFunction {    std::vector<std::map<WordID, SparseVector<float> > > values_;  // fkeys_index -> e -> value  }; +// fires when a len(word) >= length_min_ is translated as itself and then a self-transition is made +class IdentityCycleDetector : public FeatureFunction { + public: +  IdentityCycleDetector(const std::string& param); + protected: +  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, +                                     const Hypergraph::Edge& edge, +                                     const std::vector<const void*>& ant_contexts, +                                     SparseVector<double>* features, +                                     SparseVector<double>* estimated_features, +                                     void* context) const; + private: +  int length_min_; +  int fid_; +  mutable std::map<WordID, bool> big_enough_; +}; +  class InputIdentity : public FeatureFunction {   public:    InputIdentity(const std::string& param);  | 
