From 508c20ab1589533c43c8e4378bfe12d31913705a Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 13 Oct 2010 21:47:08 +0000 Subject: target unigram feature git-svn-id: https://ws10smt.googlecode.com/svn/trunk@672 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/cdec_ff.cc | 1 + decoder/ff_tagger.cc | 34 ++++++++++++++++++++++++++++++++++ decoder/ff_tagger.h | 16 ++++++++++++++++ 3 files changed, 51 insertions(+) (limited to 'decoder') diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 3240b6f2..91646253 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -54,6 +54,7 @@ void register_feature_functions() { ff_registry.Register("CSplit_ReverseCharLM", new FFFactory); ff_registry.Register("Tagger_BigramIdentity", new FFFactory); ff_registry.Register("LexicalPairIdentity", new FFFactory); + ff_registry.Register("OutputIdentity", new FFFactory); ff_registry.Register("LexicalTranslationTrigger", new FFFactory); } diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc index 7a9d1def..05de8ba3 100644 --- a/decoder/ff_tagger.cc +++ b/decoder/ff_tagger.cc @@ -93,4 +93,38 @@ void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } +OutputIdentity::OutputIdentity(const std::string& param) {} + +void OutputIdentity::FireFeature(WordID trg, + SparseVector* features) const { + int& fid = fmap_[trg]; + if (!fid) { + static map escape; + if (escape.empty()) { + escape[TD::Convert("=")] = TD::Convert("__EQ"); + escape[TD::Convert(";")] = TD::Convert("__SC"); + escape[TD::Convert(",")] = TD::Convert("__CO"); + } + if (escape.count(trg)) trg = escape[trg]; + ostringstream os; + os << "T:" << TD::Convert(trg); + fid = FD::Convert(os.str()); + } + features->set_value(fid, 1.0); +} + +void OutputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + const vector& ew = edge.rule_->e_; + for (int i = 0; i < ew.size(); ++i) { + const WordID& e = ew[i]; + if (e > 0) FireFeature(e, features); + } +} + + diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h index 41c3ee5b..9e47854e 100644 --- a/decoder/ff_tagger.h +++ b/decoder/ff_tagger.h @@ -48,4 +48,20 @@ class LexicalPairIdentity : public FeatureFunction { }; +class OutputIdentity : public FeatureFunction { + public: + OutputIdentity(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + void FireFeature(WordID trg, + SparseVector* features) const; + mutable Class2FID fmap_; +}; + #endif -- cgit v1.2.3