From 2236c965384b968e010b757691db20d95c349a73 Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 13 Oct 2010 16:22:54 +0000 Subject: trigger ff, max iteration for online optimizer git-svn-id: https://ws10smt.googlecode.com/svn/trunk@671 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/cdec_ff.cc | 2 +- decoder/ff_wordalign.cc | 59 ++++++++++++++++++++++++++++++++++++++++++++++++- decoder/ff_wordalign.h | 21 ++++++++++++++++++ 3 files changed, 80 insertions(+), 2 deletions(-) (limited to 'decoder') diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index c0c595a5..3240b6f2 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -54,6 +54,6 @@ void register_feature_functions() { ff_registry.Register("CSplit_ReverseCharLM", new FFFactory); ff_registry.Register("Tagger_BigramIdentity", new FFFactory); ff_registry.Register("LexicalPairIdentity", new FFFactory); - + ff_registry.Register("LexicalTranslationTrigger", new FFFactory); } diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc index da86b714..b4981961 100644 --- a/decoder/ff_wordalign.cc +++ b/decoder/ff_wordalign.cc @@ -266,7 +266,6 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } -// state: src word used, number of trg words generated SourceBigram::SourceBigram(const std::string& param) : FeatureFunction(sizeof(WordID) + sizeof(int)) { } @@ -405,6 +404,64 @@ void SourcePOSBigram::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } +LexicalTranslationTrigger::LexicalTranslationTrigger(const std::string& param) : + FeatureFunction(0) { + if (param.empty()) { + cerr << "LexicalTranslationTrigger requires a parameter (file containing triggers)!\n"; + } else { + ReadFile rf(param); + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (!in) continue; + vector v; + TD::ConvertSentence(line, &v); + triggers_.push_back(v); + } + } +} + +void LexicalTranslationTrigger::FireFeature(WordID trigger, + WordID src, + WordID trg, + SparseVector* features) const { + int& fid = fmap_[trigger][src][trg]; + if (!fid) { + ostringstream os; + os << "T:" << TD::Convert(trigger) << ':' << TD::Convert(src) << '_' << TD::Convert(trg); + fid = FD::Convert(os.str()); + } + features->set_value(fid, 1.0); + + int &tfid = target_fmap_[trigger][trg]; + if (!tfid) { + ostringstream os; + os << "TT:" << TD::Convert(trigger) << ':' << TD::Convert(trg); + tfid = FD::Convert(os.str()); + } + features->set_value(tfid, 1.0); +} + +void LexicalTranslationTrigger::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* /* estimated_features */, + void* context) const { + if (edge.Arity() == 0) { + assert(edge.rule_->EWords() == 1); + assert(edge.rule_->FWords() == 1); + WordID trg = edge.rule_->e()[0]; + WordID src = edge.rule_->f()[0]; + const vector& triggers = triggers_[smeta.GetSentenceID()]; + for (int i = 0; i < triggers.size(); ++i) { + FireFeature(triggers[i], src, trg, features); + } + } +} + +// state: src word used, number of trg words generated AlignerResults::AlignerResults(const std::string& param) : cur_sent_(-1), cur_grid_(NULL) { diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h index ebbecfea..0754d70e 100644 --- a/decoder/ff_wordalign.h +++ b/decoder/ff_wordalign.h @@ -78,6 +78,7 @@ class MarkovJumpFClass : public FeatureFunction { typedef std::map Class2FID; typedef std::map Class2Class2FID; +typedef std::map Class2Class2Class2FID; class SourceBigram : public FeatureFunction { public: SourceBigram(const std::string& param); @@ -118,6 +119,26 @@ class SourcePOSBigram : public FeatureFunction { std::vector > pos_; }; +class LexicalTranslationTrigger : public FeatureFunction { + public: + LexicalTranslationTrigger(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + void FireFeature(WordID trigger, + WordID src, + WordID trg, + SparseVector* features) const; + mutable Class2Class2Class2FID fmap_; // trigger,src,trg + mutable Class2Class2FID target_fmap_; // trigger,src,trg + std::vector > triggers_; +}; + class AlignerResults : public FeatureFunction { public: AlignerResults(const std::string& param); -- cgit v1.2.3