author    | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-12-01 05:27:13 +0000
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-12-01 05:27:13 +0000
commit    | 5694fc704f0c7b040c28f88a034e67a1ed19d3ba (patch)
tree      | 5ee46a3429414b1c1cdf9712f27a645b7438eed6 /decoder
parent    | 083e28a2694df51a4631d81347c45f57a5182560 (diff)
alternative def of neighborhoods
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@739 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/cdec_ff.cc      |   3
-rw-r--r-- | decoder/ff_wordalign.cc | 256
-rw-r--r-- | decoder/ff_wordalign.h  |  55
-rw-r--r-- | decoder/lextrans.cc     |  32
-rw-r--r-- | decoder/trule.cc        |  20
5 files changed, 348 insertions, 18 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 3953118c..d6cf4572 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -51,6 +51,8 @@ void register_feature_functions() {
   ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
   ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
   ff_registry.Register("Model2BinaryFeatures", new FFFactory<Model2BinaryFeatures>);
+  ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>);
+  ff_registry.Register("NewJump", new FFFactory<NewJump>);
   ff_registry.Register("MarkovJump", new FFFactory<MarkovJump>);
   ff_registry.Register("MarkovJumpFClass", new FFFactory<MarkovJumpFClass>);
   ff_registry.Register("SourceBigram", new FFFactory<SourceBigram>);
@@ -64,6 +66,7 @@ void register_feature_functions() {
   ff_registry.Register("OutputIdentity", new FFFactory<OutputIdentity>);
   ff_registry.Register("InputIdentity", new FFFactory<InputIdentity>);
   ff_registry.Register("LexicalTranslationTrigger", new FFFactory<LexicalTranslationTrigger>);
+  ff_registry.Register("WordPairFeatures", new FFFactory<WordPairFeatures>);
   ff_registry.Register("WordSet", new FFFactory<WordSet>);
 #ifdef HAVE_GLC
   ff_registry.Register("ContextCRF", new FFFactory<Model1Features>);
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index 5f42b438..980c64ad 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -1,10 +1,13 @@
 #include "ff_wordalign.h"
 
+#include <algorithm>
+#include <iterator>
 #include <set>
 #include <sstream>
 #include <string>
 #include <cmath>
 
+#include "verbose.h"
 #include "alignment_pharaoh.h"
 #include "stringlib.h"
 #include "sentence_metadata.h"
@@ -20,6 +23,8 @@ static const int kNULL_i = 255;  // -1 as an unsigned char
 
 using namespace std;
 
+// TODO new feature: if a word is translated as itself and there is a transition back to the same word, fire a feature
+
 Model2BinaryFeatures::Model2BinaryFeatures(const string& ) :
     fids_(boost::extents[MAX_SENTENCE_SIZE][MAX_SENTENCE_SIZE][MAX_SENTENCE_SIZE]) {
   for (int i = 1; i < MAX_SENTENCE_SIZE; ++i) {
@@ -195,6 +200,45 @@ void MarkovJumpFClass::TraversalFeaturesImpl(const SentenceMetadata& smeta,
   }
 }
 
+LexNullJump::LexNullJump(const string& param) :
+    FeatureFunction(1),
+    fid_lex_null_(FD::Convert("JumpLexNull")),
+    fid_null_lex_(FD::Convert("JumpNullLex")),
+    fid_null_null_(FD::Convert("JumpNullNull")),
+    fid_lex_lex_(FD::Convert("JumpLexLex")) {}
+
+void LexNullJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                        const Hypergraph::Edge& edge,
+                                        const vector<const void*>& ant_states,
+                                        SparseVector<double>* features,
+                                        SparseVector<double>* /* estimated_features */,
+                                        void* state) const {
+  char& dpstate = *((char*)state);
+  if (edge.Arity() == 0) {
+    // dpstate is 'N' = null or 'L' = lex
+    if (edge.i_ < 0) { dpstate = 'N'; } else { dpstate = 'L'; }
+  } else if (edge.Arity() == 1) {
+    dpstate = *((unsigned char*)ant_states[0]);
+  } else if (edge.Arity() == 2) {
+    char left = *((char*)ant_states[0]);
+    char right = *((char*)ant_states[1]);
+    dpstate = right;
+    if (left == 'N') {
+      if (right == 'N')
+        features->set_value(fid_null_null_, 1.0);
+      else
+        features->set_value(fid_null_lex_, 1.0);
+    } else {  // left == 'L'
+      if (right == 'N')
+        features->set_value(fid_lex_null_, 1.0);
+      else
+        features->set_value(fid_lex_lex_, 1.0);
+    }
+  } else {
+    assert(!"something really unexpected is happening");
+  }
+}
+
 MarkovJump::MarkovJump(const string& param) :
     FeatureFunction(1),
     fid_(FD::Convert("MarkovJump")),
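The LexNullJump feature added above keeps one byte of dynamic-programming state per partial hypothesis: 'N' if the most recent target word was generated from NULL, 'L' if it came from a lexical rule. When two hypotheses are joined, exactly one of the four transition features (JumpLexLex, JumpLexNull, JumpNullLex, JumpNullNull) fires, and the right antecedent's state propagates. A minimal standalone sketch of that transition logic, using a plain map in place of cdec's FD/SparseVector machinery:

    #include <cassert>
    #include <iostream>
    #include <map>
    #include <string>

    typedef std::map<std::string, double> FireCount;

    // Combine a left and a right hypothesis; fires one of the four transition
    // features and returns the new state (the right state propagates, matching
    // "dpstate = right" in the patch).
    char Combine(char left, char right, FireCount* feats) {
      assert((left == 'N' || left == 'L') && (right == 'N' || right == 'L'));
      if (left == 'N')
        (*feats)[right == 'N' ? "JumpNullNull" : "JumpNullLex"] += 1.0;
      else
        (*feats)[right == 'N' ? "JumpLexNull" : "JumpLexLex"] += 1.0;
      return right;
    }

    int main() {
      FireCount feats;
      // target words generated from: lex, NULL, lex -> JumpLexNull, JumpNullLex
      char state = 'L';
      state = Combine(state, 'N', &feats);
      state = Combine(state, 'L', &feats);
      for (FireCount::const_iterator it = feats.begin(); it != feats.end(); ++it)
        std::cout << it->first << " = " << it->second << "\n";
      return 0;
    }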
@@ -287,6 +331,100 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
   }
 }
 
+NewJump::NewJump(const string& param) :
+    FeatureFunction(1) {
+  cerr << " NewJump";
+  vector<string> argv;
+  int argc = SplitOnWhitespace(param, &argv);
+  set<string> config;
+  for (int i = 0; i < argc; ++i) config.insert(argv[i]);
+  cerr << endl;
+  use_binned_log_lengths_ = config.count("use_binned_log_lengths") > 0;
+}
+
+// do a log transform on the length (of a sentence, a jump, etc)
+// this basically means that large distances that are close to each other
+// are put into the same bin
+int BinnedLogLength(int len) {
+  int res = static_cast<int>(log(len+1) / log(1.3));
+  if (res > 16) res = 16;
+  return res;
+}
+
+void NewJump::FireFeature(const SentenceMetadata& smeta,
+                          const int prev_src_index,
+                          const int cur_src_index,
+                          SparseVector<double>* features) const {
+  const int src_len = smeta.GetSourceLength();
+  const int raw_jump = cur_src_index - prev_src_index;
+  char jtype = 0;
+  int jump_magnitude = raw_jump;
+  if (raw_jump > 0) { jtype = 'R'; }  // Right
+  else if (raw_jump == 0) { jtype = 'S'; }  // Stay
+  else { jtype = 'L'; jump_magnitude = raw_jump * -1; }  // Left
+  int effective_length = src_len;
+  if (use_binned_log_lengths_) {
+    jump_magnitude = BinnedLogLength(jump_magnitude);
+    effective_length = BinnedLogLength(src_len);
+  }
+
+  if (true) {
+    static map<int, map<int, int> > len2jump2fid;
+    int& fid = len2jump2fid[src_len][raw_jump];
+    if (!fid) {
+      ostringstream os;
+      os << fid_str_ << ":FLen" << effective_length << ":" << jtype << jump_magnitude;
+      fid = FD::Convert(os.str());
+    }
+    features->set_value(fid, 1.0);
+  }
+}
+
+void NewJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                    const Hypergraph::Edge& edge,
+                                    const vector<const void*>& ant_states,
+                                    SparseVector<double>* features,
+                                    SparseVector<double>* /* estimated_features */,
+                                    void* state) const {
+  unsigned char& dpstate = *((unsigned char*)state);
+  const int flen = smeta.GetSourceLength();
+  if (edge.Arity() == 0) {
+    dpstate = static_cast<unsigned int>(edge.i_);
+    if (edge.prev_i_ == 0) {  // first target word in sentence
+      if (edge.i_ >= 0) {  // generated from non-Null token?
+        FireFeature(smeta,
+                    -1,       // previous src = beginning of sentence index
+                    edge.i_,  // current src
+                    features);
+      }
+    } else if (edge.prev_i_ == smeta.GetTargetLength() - 1) {  // last word
+      if (edge.i_ >= 0) {  // generated from non-Null token?
+        FireFeature(smeta,
+                    edge.i_,  // previous src = last word position
+                    flen,     // current src
+                    features);
+      }
+    }
+  } else if (edge.Arity() == 1) {
+    dpstate = *((unsigned char*)ant_states[0]);
+  } else if (edge.Arity() == 2) {
+    int left_index = *((unsigned char*)ant_states[0]);
+    int right_index = *((unsigned char*)ant_states[1]);
+    if (right_index == -1)
+      dpstate = static_cast<unsigned int>(left_index);
+    else
+      dpstate = static_cast<unsigned int>(right_index);
+    if (left_index != kNULL_i && right_index != kNULL_i) {
+      FireFeature(smeta,
+                  left_index,   // previous src index
+                  right_index,  // current src index
+                  features);
+    }
+  } else {
+    assert(!"something really unexpected is happening");
+  }
+}
+
 SourceBigram::SourceBigram(const std::string& param) :
   FeatureFunction(sizeof(WordID) + sizeof(int)) {
 }
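A few details of the NewJump code above are easy to miss. BinnedLogLength maps a distance d to floor(log(d+1)/log 1.3), capped at 16, so jumps of 20 and 21 share bin 11 and any source length of 66 or more lands in bin 16. FireFeature then builds feature names of the form <fid_str_>:FLen<length>:<R|S|L><magnitude>, but this revision never assigns fid_str_ in the constructor, so names come out with an empty prefix (e.g. ":FLen13:R6"). Also note that the right_index == -1 comparison can never be true as written, since the antecedent state is an unsigned char and NULL alignments are encoded as kNULL_i = 255. A self-contained sketch of the naming scheme; the "Jump" prefix in main() is illustrative, not what this revision produces:

    #include <cmath>
    #include <iostream>
    #include <sstream>
    #include <string>

    // Same transform as BinnedLogLength above: nearby large values share a
    // bin, and everything from length 66 up saturates at bin 16.
    int BinnedLogLength(int len) {
      int res = static_cast<int>(log(len + 1) / log(1.3));
      if (res > 16) res = 16;
      return res;
    }

    // Mirrors the name construction in NewJump::FireFeature.
    std::string JumpFeatureName(const std::string& prefix, int src_len,
                                int prev, int cur, bool binned) {
      const int raw_jump = cur - prev;
      char jtype = 'S';                                         // Stay
      int mag = raw_jump;
      if (raw_jump > 0) jtype = 'R';                            // Right
      else if (raw_jump < 0) { jtype = 'L'; mag = -raw_jump; }  // Left
      int len = src_len;
      if (binned) { mag = BinnedLogLength(mag); len = BinnedLogLength(len); }
      std::ostringstream os;
      os << prefix << ":FLen" << len << ":" << jtype << mag;
      return os.str();
    }

    int main() {
      std::cout << JumpFeatureName("Jump", 30, 3, 8, false) << "\n";  // Jump:FLen30:R5
      std::cout << JumpFeatureName("Jump", 30, 3, 8, true) << "\n";   // Jump:FLen13:R6
      std::cout << JumpFeatureName("Jump", 30, 8, 3, true) << "\n";   // Jump:FLen13:L6
      return 0;
    }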
@@ -626,6 +764,122 @@ void InputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
   }
 }
 
+WordPairFeatures::WordPairFeatures(const string& param) {
+  vector<string> argv;
+  int argc = SplitOnWhitespace(param, &argv);
+  if (argc != 1) {
+    cerr << "WordPairFeature /path/to/feature_values.table\n";
+    abort();
+  }
+  set<WordID> all_srcs;
+  {
+    ReadFile rf(argv[0]);
+    istream& in = *rf.stream();
+    string buf;
+    while (in) {
+      getline(in, buf);
+      if (buf.empty()) continue;
+      int start = 0;
+      while(start < buf.size() && buf[start] == ' ') ++start;
+      int end = start;
+      while(end < buf.size() && buf[end] != ' ') ++end;
+      const WordID src = TD::Convert(buf.substr(start, end - start));
+      all_srcs.insert(src);
+    }
+  }
+  if (all_srcs.empty()) {
+    cerr << "WordPairFeature " << param << " loaded empty file!\n";
+    return;
+  }
+  fkeys_.reserve(all_srcs.size());
+  copy(all_srcs.begin(), all_srcs.end(), back_inserter(fkeys_));
+  values_.resize(all_srcs.size());
+  if (!SILENT) { cerr << "WordPairFeature: " << all_srcs.size() << " sources\n"; }
+  ReadFile rf(argv[0]);
+  istream& in = *rf.stream();
+  string buf;
+  double val = 0;
+  WordID cur_src = 0;
+  map<WordID, SparseVector<float> >* pv = NULL;
+  const WordID kBARRIER = TD::Convert("|||");
+  while (in) {
+    getline(in, buf);
+    if (buf.size() == 0) continue;
+    int start = 0;
+    while(start < buf.size() && buf[start] == ' ') ++start;
+    int end = start;
+    while(end < buf.size() && buf[end] != ' ') ++end;
+    const WordID src = TD::Convert(buf.substr(start, end - start));
+    if (cur_src != src) {
+      cur_src = src;
+      size_t ind = distance(fkeys_.begin(), lower_bound(fkeys_.begin(), fkeys_.end(), cur_src));
+      pv = &values_[ind];
+    }
+    end += 1;
+    start = end;
+    while(end < buf.size() && buf[end] != ' ') ++end;
+    WordID x = TD::Convert(buf.substr(start, end - start));
+    if (x != kBARRIER) {
+      cerr << "1 Format error: " << buf << endl;
+      abort();
+    }
+    start = end + 1;
+    end = start + 1;
+    while(end < buf.size() && buf[end] != ' ') ++end;
+    WordID trg = TD::Convert(buf.substr(start, end - start));
+    if (trg == kBARRIER) {
+      cerr << "2 Format error: " << buf << endl;
+      abort();
+    }
+    start = end + 1;
+    end = start + 1;
+    while(end < buf.size() && buf[end] != ' ') ++end;
+    WordID x2 = TD::Convert(buf.substr(start, end - start));
+    if (x2 != kBARRIER) {
+      cerr << "3 Format error: " << buf << endl;
+      abort();
+    }
+    start = end + 1;
+
+    SparseVector<float>& v = (*pv)[trg];
+    while(start < buf.size()) {
+      end = start + 1;
+      while(end < buf.size() && buf[end] != '=' && buf[end] != ' ') ++end;
+      if (end == buf.size() || buf[end] != '=') { cerr << "4 Format error: " << buf << endl; abort(); }
+      const int fid = FD::Convert(buf.substr(start, end - start));
+      start = end + 1;
+      while(start < buf.size() && buf[start] == ' ') ++start;
+      end = start + 1;
+      while(end < buf.size() && buf[end] != ' ') ++end;
+      assert(end > start);
+      if (end < buf.size()) buf[end] = 0;
+      val = strtod(&buf.c_str()[start], NULL);
+      v.set_value(fid, val);
+      start = end + 1;
+    }
+  }
+}
-
+void WordPairFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                             const Hypergraph::Edge& edge,
+                                             const std::vector<const void*>& ant_contexts,
+                                             SparseVector<double>* features,
+                                             SparseVector<double>* estimated_features,
+                                             void* context) const {
+  if (edge.Arity() == 0) {
+    assert(edge.rule_->EWords() == 1);
+    assert(edge.rule_->FWords() == 1);
+    const WordID trg = edge.rule_->e()[0];
+    const WordID src = edge.rule_->f()[0];
+    size_t ind = distance(fkeys_.begin(), lower_bound(fkeys_.begin(), fkeys_.end(), src));
+    if (ind == fkeys_.size() || fkeys_[ind] != src) {
+      cerr << "WordPairFeatures no source entries for " << TD::Convert(src) << endl;
+      abort();
+    }
+    const map<WordID, SparseVector<float> >::const_iterator it = values_[ind].find(trg);
+    // TODO optional strict flag to make sure there are features for all pairs?
+    if (it != values_[ind].end())
+      (*features) += it->second;
+  }
+}
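Working backwards from the parser in the WordPairFeatures constructor, each line of the table it loads has the shape `f ||| e ||| FeatName=value [FeatName=value ...]`: the first pass over the file collects the distinct source words, and the second pass attaches each target word's sparse feature vector to its source row (found with lower_bound over the sorted fkeys_). A hypothetical three-line table, with made-up words, feature names, and values, that this parser would accept:

    haus ||| house ||| Dice=0.82 LogM1=-0.21
    haus ||| home ||| Dice=0.41 LogM1=-1.37
    hund ||| dog ||| Dice=0.90 LogM1=-0.11

At decode time, TraversalFeaturesImpl then adds the stored vector for (f, e) to every terminal edge that translates f as e, and aborts if the source word has no row at all.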
diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h
index 0714229c..418c8768 100644
--- a/decoder/ff_wordalign.h
+++ b/decoder/ff_wordalign.h
@@ -103,6 +103,43 @@ class SourceBigram : public FeatureFunction {
   mutable Class2Class2FID fmap_;
 };
 
+class LexNullJump : public FeatureFunction {
+ public:
+  LexNullJump(const std::string& param);
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const Hypergraph::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* out_context) const;
+ private:
+  const int fid_lex_null_;
+  const int fid_null_lex_;
+  const int fid_null_null_;
+  const int fid_lex_lex_;
+};
+
+class NewJump : public FeatureFunction {
+ public:
+  NewJump(const std::string& param);
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const Hypergraph::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* out_context) const;
+ private:
+  void FireFeature(const SentenceMetadata& smeta,
+                   const int prev_src_index,
+                   const int cur_src_index,
+                   SparseVector<double>* features) const;
+
+  bool use_binned_log_lengths_;
+  std::string fid_str_;  // identifies configuration uniquely
+};
+
 class SourcePOSBigram : public FeatureFunction {
  public:
   SourcePOSBigram(const std::string& param);
@@ -238,6 +275,24 @@ class BlunsomSynchronousParseHack : public FeatureFunction {
   mutable std::vector<std::vector<WordID> > refs_;
 };
 
+// association feature type: looks up a pair (e,f) in a table and returns a
+// vector of feature values
+class WordPairFeatures : public FeatureFunction {
+ public:
+  WordPairFeatures(const std::string& param);
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const Hypergraph::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* context) const;
+
+ private:
+  std::vector<WordID> fkeys_;  // parallel to values_
+  std::vector<std::map<WordID, SparseVector<float> > > values_;  // fkeys_index -> e -> value
+};
+
 class InputIdentity : public FeatureFunction {
  public:
   InputIdentity(const std::string& param);
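These header declarations are what the FFRegistry entries in cdec_ff.cc instantiate, with each constructor's string argument coming from the decoder configuration. A sketch of how the three new features might be enabled; the feature_function= key is the usual cdec .ini convention and is assumed here rather than shown by this patch, and the table path (taken from the usage message in the constructor) is a placeholder:

    # word-alignment features added by this commit (hypothetical config)
    feature_function=LexNullJump
    feature_function=NewJump use_binned_log_lengths
    feature_function=WordPairFeatures /path/to/feature_values.table

NewJump's parameter string is split on whitespace and matched against known flags, so use_binned_log_lengths is the only option this revision understands.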
diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc
index 4476fe63..35d2d15d 100644
--- a/decoder/lextrans.cc
+++ b/decoder/lextrans.cc
@@ -76,13 +76,13 @@ struct LexicalTransImpl {
     // hack to tell the feature function system how big the sentence pair is
     const int f_start = (use_null ? -1 : 0);
     int prev_node_id = -1;
-    set<WordID> target_vocab;  // only set for alignment_only mode
-    if (align_only_) {
-      const Lattice& ref = smeta.GetReference();
-      for (int i = 0; i < ref.size(); ++i) {
-        target_vocab.insert(ref[i][0].label);
-      }
+    set<WordID> target_vocab;
+    const Lattice& ref = smeta.GetReference();
+    for (int i = 0; i < ref.size(); ++i) {
+      target_vocab.insert(ref[i][0].label);
     }
+    bool all_sources_to_all_targets_ = true;
+    set<WordID> trgs_used;
     for (int i = 0; i < e_len; ++i) {  // for each word in the *target*
       Hypergraph::Node* node = forest->AddNode(kXCAT);
       const int new_node_id = node->id_;
@@ -101,10 +101,13 @@ struct LexicalTransImpl {
         assert(rb);
         for (int k = 0; k < rb->GetNumRules(); ++k) {
           TRulePtr rule = rb->GetIthRule(k);
+          const WordID trg_word = rule->e_[0];
           if (align_only_) {
-            if (target_vocab.count(rule->e_[0]) == 0)
+            if (target_vocab.count(trg_word) == 0)
               continue;
           }
+          if (all_sources_to_all_targets_ && (target_vocab.count(trg_word) > 0))
+            trgs_used.insert(trg_word);
           Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
           edge->i_ = j;
           edge->j_ = j+1;
@@ -113,6 +116,21 @@ struct LexicalTransImpl {
           edge->feature_values_ += edge->rule_->GetFeatureValues();
           forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
         }
+        if (all_sources_to_all_targets_) {
+          for (set<WordID>::iterator it = target_vocab.begin(); it != target_vocab.end(); ++it) {
+            if (trgs_used.count(*it)) continue;
+            const WordID ungenerated_trg_word = *it;
+            TRulePtr rule;
+            rule.reset(TRule::CreateLexicalRule(src_sym, ungenerated_trg_word));
+            Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
+            edge->i_ = j;
+            edge->j_ = j+1;
+            edge->prev_i_ = i;
+            edge->prev_j_ = i+1;
+            forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+          }
+          trgs_used.clear();
+        }
       }
       if (prev_node_id >= 0) {
         const int comb_node_id = forest->AddNode(kXCAT)->id_;
diff --git a/decoder/trule.cc b/decoder/trule.cc
index a40c4e14..eedf8f30 100644
--- a/decoder/trule.cc
+++ b/decoder/trule.cc
@@ -246,18 +246,18 @@ string TRule::AsString(bool verbose) const {
   int idx = 0;
   if (lhs_ && verbose) {
     os << '[' << TD::Convert(lhs_ * -1) << "] |||";
-    for (int i = 0; i < f_.size(); ++i) {
-      const WordID& w = f_[i];
-      if (w < 0) {
-        int wi = w * -1;
-        ++idx;
-        os << " [" << TD::Convert(wi) << ',' << idx << ']';
-      } else {
-        os << ' ' << TD::Convert(w);
-      }
+  }
+  for (int i = 0; i < f_.size(); ++i) {
+    const WordID& w = f_[i];
+    if (w < 0) {
+      int wi = w * -1;
+      ++idx;
+      os << " [" << TD::Convert(wi) << ',' << idx << ']';
+    } else {
+      os << ' ' << TD::Convert(w);
     }
-    os << " ||| ";
   }
+  os << " ||| ";
   if (idx > 9) {
     cerr << "Too many non-terminals!\n partial: " << os.str() << endl;
     exit(1);
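The lextrans.cc hunks implement the "alternative def of neighborhoods" of the commit message: target_vocab is now always collected from the reference lattice, and for every source word an extra zero-feature lexical edge is added to each reference target word the grammar did not already propose (switched on here by a hard-coded local, all_sources_to_all_targets_ = true), so the alignment neighborhood always connects every source word to every reference target word.

The trule.cc hunk changes TRule::AsString so that the source side and the ||| separator are emitted unconditionally; previously they were printed only inside the if (lhs_ && verbose) block, so a non-verbose call dropped the source side entirely. A toy rendition of the new control flow, using plain ints for WordIDs (negative = non-terminal) and a stand-in for TD::Convert:

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    typedef int WordID;  // negative values denote non-terminals, as in cdec

    std::string Convert(WordID w) {  // toy stand-in for TD::Convert
      std::ostringstream os;
      if (w < 0) os << "X";          // pretend every non-terminal is category X
      else os << "w" << w;           // and every terminal is w<id>
      return os.str();
    }

    // Mirrors the patched TRule::AsString: only the "[LHS] |||" prefix is
    // verbose-gated; the source side and trailing " ||| " are always printed.
    std::string AsStringPrefix(WordID lhs, const std::vector<WordID>& f, bool verbose) {
      std::ostringstream os;
      int idx = 0;
      if (lhs && verbose) os << '[' << Convert(lhs) << "] |||";
      for (size_t i = 0; i < f.size(); ++i) {
        if (f[i] < 0) os << " [" << Convert(f[i]) << ',' << ++idx << ']';
        else os << ' ' << Convert(f[i]);
      }
      os << " ||| ";
      return os.str();
    }

    int main() {
      std::vector<WordID> f;
      f.push_back(7);   // one terminal
      f.push_back(-1);  // one non-terminal
      std::cout << AsStringPrefix(-1, f, true) << "\n";   // "[X] ||| w7 [X,1] ||| "
      std::cout << AsStringPrefix(-1, f, false) << "\n";  // " w7 [X,1] ||| " (f side kept)
      return 0;
    }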