diff options
-rw-r--r-- | decoder/cdec_ff.cc | 1 | ||||
-rw-r--r-- | decoder/ff_tagger.cc | 34 | ||||
-rw-r--r-- | decoder/ff_tagger.h | 16 |
3 files changed, 51 insertions, 0 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 3240b6f2..91646253 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -54,6 +54,7 @@ void register_feature_functions() { ff_registry.Register("CSplit_ReverseCharLM", new FFFactory<ReverseCharLMCSplitFeature>); ff_registry.Register("Tagger_BigramIdentity", new FFFactory<Tagger_BigramIdentity>); ff_registry.Register("LexicalPairIdentity", new FFFactory<LexicalPairIdentity>); + ff_registry.Register("OutputIdentity", new FFFactory<OutputIdentity>); ff_registry.Register("LexicalTranslationTrigger", new FFFactory<LexicalTranslationTrigger>); } diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc index 7a9d1def..05de8ba3 100644 --- a/decoder/ff_tagger.cc +++ b/decoder/ff_tagger.cc @@ -93,4 +93,38 @@ void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } +OutputIdentity::OutputIdentity(const std::string& param) {} + +void OutputIdentity::FireFeature(WordID trg, + SparseVector<double>* features) const { + int& fid = fmap_[trg]; + if (!fid) { + static map<WordID, WordID> escape; + if (escape.empty()) { + escape[TD::Convert("=")] = TD::Convert("__EQ"); + escape[TD::Convert(";")] = TD::Convert("__SC"); + escape[TD::Convert(",")] = TD::Convert("__CO"); + } + if (escape.count(trg)) trg = escape[trg]; + ostringstream os; + os << "T:" << TD::Convert(trg); + fid = FD::Convert(os.str()); + } + features->set_value(fid, 1.0); +} + +void OutputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + const vector<WordID>& ew = edge.rule_->e_; + for (int i = 0; i < ew.size(); ++i) { + const WordID& e = ew[i]; + if (e > 0) FireFeature(e, features); + } +} + + diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h index 41c3ee5b..9e47854e 100644 --- a/decoder/ff_tagger.h +++ b/decoder/ff_tagger.h @@ -48,4 +48,20 @@ class LexicalPairIdentity : public FeatureFunction { }; +class OutputIdentity : public FeatureFunction { + public: + OutputIdentity(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + private: + void FireFeature(WordID trg, + SparseVector<double>* features) const; + mutable Class2FID fmap_; +}; + #endif |