diff options
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/Makefile.am | 2 | ||||
-rw-r--r-- | decoder/cdec_ff.cc | 9 | ||||
-rw-r--r-- | decoder/ff_klm.cc | 38 | ||||
-rw-r--r-- | decoder/ff_klm.h | 7 | ||||
-rw-r--r-- | decoder/ff_ngrams.cc | 341 | ||||
-rw-r--r-- | decoder/ff_ngrams.h | 29 | ||||
-rw-r--r-- | decoder/ff_rules.cc | 107 | ||||
-rw-r--r-- | decoder/ff_rules.h | 40 | ||||
-rw-r--r-- | decoder/ff_spans.cc | 74 | ||||
-rw-r--r-- | decoder/ff_spans.h | 15 |
10 files changed, 587 insertions, 75 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 244da2de..e5f7505f 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -61,10 +61,12 @@ libcdec_a_SOURCES = \ phrasetable_fst.cc \ trule.cc \ ff.cc \ + ff_rules.cc \ ff_wordset.cc \ ff_charset.cc \ ff_lm.cc \ ff_klm.cc \ + ff_ngrams.cc \ ff_spans.cc \ ff_ruleshape.cc \ ff_wordalign.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 31f88a4f..588842f1 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -4,10 +4,12 @@ #include "ff_spans.h" #include "ff_lm.h" #include "ff_klm.h" +#include "ff_ngrams.h" #include "ff_csplit.h" #include "ff_wordalign.h" #include "ff_tagger.h" #include "ff_factory.h" +#include "ff_rules.h" #include "ff_ruleshape.h" #include "ff_bleu.h" #include "ff_lm_fsa.h" @@ -51,12 +53,11 @@ void register_feature_functions() { ff_registry.Register("RandLM", new FFFactory<LanguageModelRandLM>); #endif ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>()); + ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>()); + ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>()); ff_registry.Register("RuleNgramFeatures", new FFFactory<RuleNgramFeatures>()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory<CMR2008ReorderingFeatures>()); - ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >()); - ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >()); - ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory<KLanguageModel<lm::ngram::QuantTrieModel> >()); - ff_registry.Register("KLanguageModel_Probing", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >()); + ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>); ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>); ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 9b7fe2d3..24dcb9c3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -9,6 +9,7 @@ #include "stringlib.h" #include "hg.h" #include "tdict.h" +#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; @@ -434,8 +435,37 @@ void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, features->set_value(oov_fid_, oovs); } -// instantiate templates -template class KLanguageModel<lm::ngram::ProbingModel>; -template class KLanguageModel<lm::ngram::TrieModel>; -template class KLanguageModel<lm::ngram::QuantTrieModel>; +template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string ¶m) { + KLanguageModel<Model> *ret = new KLanguageModel<Model>(param); + ret->Init(); + return boost::shared_ptr<FeatureFunction>(ret); +} +boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel<ProbingModel>(param); + case TRIE_SORTED: + return CreateModel<TrieModel>(param); + case ARRAY_TRIE_SORTED: + return CreateModel<ArrayTrieModel>(param); + case QUANT_TRIE_SORTED: + return CreateModel<QuantTrieModel>(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel<QuantArrayTrieModel>(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel<lm::ngram::Model>::usage(params, verbose); +} diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 5eafe8be..6efe50f6 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -4,8 +4,8 @@ #include <vector> #include <string> +#include "ff_factory.h" #include "ff.h" -#include "lm/model.hh" template <class Model> struct KLanguageModelImpl; @@ -34,4 +34,9 @@ class KLanguageModel : public FeatureFunction { KLanguageModelImpl<Model>* pimpl_; }; +struct KLanguageModelFactory : public FactoryBase<FeatureFunction> { + FP Create(std::string param) const; + std::string usage(bool params,bool verbose) const; +}; + #endif diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc new file mode 100644 index 00000000..04dd1906 --- /dev/null +++ b/decoder/ff_ngrams.cc @@ -0,0 +1,341 @@ +#include "ff_ngrams.h" + +#include <cstring> +#include <iostream> + +#include <boost/scoped_ptr.hpp> + +#include "filelib.h" +#include "stringlib.h" +#include "hg.h" +#include "tdict.h" + +using namespace std; + +static const unsigned char HAS_FULL_CONTEXT = 1; +static const unsigned char HAS_EOS_ON_RIGHT = 2; +static const unsigned char MASK = 7; + +namespace { +template <unsigned MAX_ORDER = 5> +struct State { + explicit State() { + memset(state, 0, sizeof(state)); + } + explicit State(int order) { + memset(state, 0, (order - 1) * sizeof(WordID)); + } + State<MAX_ORDER>(char order, const WordID* mem) { + memcpy(state, mem, (order - 1) * sizeof(WordID)); + } + State(const State<MAX_ORDER>& other) { + memcpy(state, other.state, sizeof(state)); + } + const State& operator=(const State<MAX_ORDER>& other) { + memcpy(state, other.state, sizeof(state)); + } + explicit State(const State<MAX_ORDER>& other, unsigned order, WordID extend) { + char om1 = order - 1; + assert(om1 > 0); + for (char i = 1; i < om1; ++i) state[i - 1]= other.state[i]; + state[om1 - 1] = extend; + } + const WordID& operator[](size_t i) const { return state[i]; } + WordID& operator[](size_t i) { return state[i]; } + WordID state[MAX_ORDER]; +}; +} + +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + +class NgramDetectorImpl { + + // returns the number of unscored words at the left edge of a span + inline int UnscoredSize(const void* state) const { + return *(static_cast<const char*>(state) + unscored_size_offset_); + } + + inline void SetUnscoredSize(int size, void* state) const { + *(static_cast<char*>(state) + unscored_size_offset_) = size; + } + + inline State<5> RemnantLMState(const void* cstate) const { + return State<5>(order_, static_cast<const WordID*>(cstate)); + } + + inline const State<5> BeginSentenceState() const { + State<5> state(order_); + state.state[0] = kSOS_; + return state; + } + + inline void SetRemnantLMState(const State<5>& lmstate, void* state) const { + // if we were clever, we could use the memory pointed to by state to do all + // the work, avoiding this copy + memcpy(state, lmstate.state, (order_-1) * sizeof(WordID)); + } + + WordID IthUnscoredWord(int i, const void* state) const { + const WordID* const mem = reinterpret_cast<const WordID*>(static_cast<const char*>(state) + unscored_words_offset_); + return mem[i]; + } + + void SetIthUnscoredWord(int i, const WordID index, void *state) const { + WordID* mem = reinterpret_cast<WordID*>(static_cast<char*>(state) + unscored_words_offset_); + mem[i] = index; + } + + inline bool GetFlag(const void *state, unsigned char flag) const { + return (*(static_cast<const char*>(state) + is_complete_offset_) & flag); + } + + inline void SetFlag(bool on, unsigned char flag, void *state) const { + if (on) { + *(static_cast<char*>(state) + is_complete_offset_) |= flag; + } else { + *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag); + } + } + + inline bool HasFullContext(const void *state) const { + return GetFlag(state, HAS_FULL_CONTEXT); + } + + inline void SetHasFullContext(bool flag, void *state) const { + SetFlag(flag, HAS_FULL_CONTEXT, state); + } + + void FireFeatures(const State<5>& state, WordID cur, SparseVector<double>* feats) { + FidTree* ft = &fidroot_; + int n = 0; + WordID buf[10]; + int ci = order_ - 1; + WordID curword = cur; + while(curword) { + buf[n] = curword; + int& fid = ft->fids[curword]; + ++n; + if (!fid) { + const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.) + ostringstream os; + os << code[n] << ':'; + for (int i = n-1; i >= 0; --i) { + os << (i != n-1 ? "_" : ""); + const string& tok = TD::Convert(buf[i]); + if (tok.find('=') == string::npos) + os << tok; + else + os << Escape(tok); + } + fid = FD::Convert(os.str()); + } + feats->set_value(fid, 1); + ft = &ft->levels[curword]; + --ci; + if (ci < 0) break; + curword = state[ci]; + } + } + + public: + void LookupWords(const TRule& rule, const vector<const void*>& ant_states, SparseVector<double>* feats, SparseVector<double>* est_feats, void* remnant) { + double sum = 0.0; + double est_sum = 0.0; + int num_scored = 0; + int num_estimated = 0; + bool saw_eos = false; + bool has_some_history = false; + State<5> state; + const vector<WordID>& e = rule.e(); + bool context_complete = false; + for (int j = 0; j < e.size(); ++j) { + if (e[j] < 1) { // handle non-terminal substitution + const void* astate = (ant_states[-e[j]]); + int unscored_ant_len = UnscoredSize(astate); + for (int k = 0; k < unscored_ant_len; ++k) { + const WordID cur_word = IthUnscoredWord(k, astate); + const bool is_oov = (cur_word == 0); + SparseVector<double> p; + if (cur_word == kSOS_) { + state = BeginSentenceState(); + if (has_some_history) { // this is immediately fully scored, and bad + p.set_value(FD::Convert("Malformed"), 1.0); + context_complete = true; + } else { // this might be a real <s> + num_scored = max(0, order_ - 2); + } + } else { + FireFeatures(state, cur_word, &p); + const State<5> scopy = State<5>(state, order_, cur_word); + state = scopy; + if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); } + saw_eos = (cur_word == kEOS_); + } + has_some_history = true; + ++num_scored; + if (!context_complete) { + if (num_scored >= order_) context_complete = true; + } + if (context_complete) { + (*feats) += p; + } else { + if (remnant) + SetIthUnscoredWord(num_estimated, cur_word, remnant); + ++num_estimated; + (*est_feats) += p; + } + } + saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); + if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 + state = RemnantLMState(astate); + context_complete = true; + } + } else { // handle terminal + const WordID cur_word = e[j]; + SparseVector<double> p; + if (cur_word == kSOS_) { + state = BeginSentenceState(); + if (has_some_history) { // this is immediately fully scored, and bad + p.set_value(FD::Convert("Malformed"), -100); + context_complete = true; + } else { // this might be a real <s> + num_scored = max(0, order_ - 2); + } + } else { + FireFeatures(state, cur_word, &p); + const State<5> scopy = State<5>(state, order_, cur_word); + state = scopy; + if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); } + saw_eos = (cur_word == kEOS_); + } + has_some_history = true; + ++num_scored; + if (!context_complete) { + if (num_scored >= order_) context_complete = true; + } + if (context_complete) { + (*feats) += p; + } else { + if (remnant) + SetIthUnscoredWord(num_estimated, cur_word, remnant); + ++num_estimated; + (*est_feats) += p; + } + } + } + if (remnant) { + SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); + SetRemnantLMState(state, remnant); + SetUnscoredSize(num_estimated, remnant); + SetHasFullContext(context_complete || (num_scored >= order_), remnant); + } + } + + // this assumes no target words on final unary -> goal rule. is that ok? + // for <s> (n-1 left words) and (n-1 right words) </s> + void FinalTraversal(const void* state, SparseVector<double>* feats) { + if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here + SetRemnantLMState(BeginSentenceState(), dummy_state_); + SetHasFullContext(1, dummy_state_); + SetUnscoredSize(0, dummy_state_); + dummy_ants_[1] = state; + LookupWords(*dummy_rule_, dummy_ants_, feats, NULL, NULL); + } else { // rules DO produce <s> ... </s> +#if 0 + double p = 0; + if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } + if (UnscoredSize(state) > 0) { // are there unscored words + if (kSOS_ != IthUnscoredWord(0, state)) { + p -= 100 * UnscoredSize(state); + } + } + return p; +#endif + } + } + + public: + explicit NgramDetectorImpl(bool explicit_markers) : + kCDEC_UNK(TD::Convert("<unk>")) , + add_sos_eos_(!explicit_markers) { + order_ = 3; + state_size_ = (order_ - 1) * sizeof(WordID) + 2 + (order_ - 1) * sizeof(WordID); + unscored_size_offset_ = (order_ - 1) * sizeof(WordID); + is_complete_offset_ = unscored_size_offset_ + 1; + unscored_words_offset_ = is_complete_offset_ + 1; + + // special handling of beginning / ending sentence markers + dummy_state_ = new char[state_size_]; + memset(dummy_state_, 0, state_size_); + dummy_ants_.push_back(dummy_state_); + dummy_ants_.push_back(NULL); + dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0")); + kSOS_ = TD::Convert("<s>"); + kEOS_ = TD::Convert("</s>"); + } + + ~NgramDetectorImpl() { + delete[] dummy_state_; + } + + int ReserveStateSize() const { return state_size_; } + + private: + const WordID kCDEC_UNK; + WordID kSOS_; // <s> - requires special handling. + WordID kEOS_; // </s> + const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s> + // if this is true, FinalTransitionFeatures will "add" <s> and </s> + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + + int order_; + int state_size_; + int unscored_size_offset_; + int is_complete_offset_; + int unscored_words_offset_; + char* dummy_state_; + vector<const void*> dummy_ants_; + TRulePtr dummy_rule_; + struct FidTree { + map<WordID, int> fids; + map<WordID, FidTree> levels; + }; + mutable FidTree fidroot_; +}; + +NgramDetector::NgramDetector(const string& param) { + string filename, mapfile, featname; + bool explicit_markers = (param == "-x"); + pimpl_ = new NgramDetectorImpl(explicit_markers); + SetStateSize(pimpl_->ReserveStateSize()); +} + +NgramDetector::~NgramDetector() { + delete pimpl_; +} + +void NgramDetector::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector<const void*>& ant_states, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* state) const { + pimpl_->LookupWords(*edge.rule_, ant_states, features, estimated_features, state); +} + +void NgramDetector::FinalTraversalFeatures(const void* ant_state, + SparseVector<double>* features) const { + pimpl_->FinalTraversal(ant_state, features); +} + diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h new file mode 100644 index 00000000..82f61b33 --- /dev/null +++ b/decoder/ff_ngrams.h @@ -0,0 +1,29 @@ +#ifndef _NGRAMS_FF_H_ +#define _NGRAMS_FF_H_ + +#include <vector> +#include <map> +#include <string> + +#include "ff.h" + +struct NgramDetectorImpl; +class NgramDetector : public FeatureFunction { + public: + // param = "filename.lm [-o n]" + NgramDetector(const std::string& param); + ~NgramDetector(); + virtual void FinalTraversalFeatures(const void* context, + SparseVector<double>* features) const; + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* out_context) const; + private: + NgramDetectorImpl* pimpl_; +}; + +#endif diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc new file mode 100644 index 00000000..bd4c4cc0 --- /dev/null +++ b/decoder/ff_rules.cc @@ -0,0 +1,107 @@ +#include "ff_rules.h" + +#include <sstream> +#include <cassert> +#include <cmath> + +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" + +using namespace std; + +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + +RuleIdentityFeatures::RuleIdentityFeatures(const std::string& param) { +} + +void RuleIdentityFeatures::PrepareForInput(const SentenceMetadata& smeta) { +// std::map<const TRule*, SparseVector<double> > + rule2_fid_.clear(); +} + +void RuleIdentityFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + map<const TRule*, int>::iterator it = rule2_fid_.find(edge.rule_.get()); + if (it == rule2_fid_.end()) { + const TRule& rule = *edge.rule_; + ostringstream os; + os << "R:"; + if (rule.lhs_ < 0) os << TD::Convert(-rule.lhs_) << ':'; + for (unsigned i = 0; i < rule.f_.size(); ++i) { + if (i > 0) os << '_'; + WordID w = rule.f_[i]; + if (w < 0) { os << 'N'; w = -w; } + assert(w > 0); + os << TD::Convert(w); + } + os << ':'; + for (unsigned i = 0; i < rule.e_.size(); ++i) { + if (i > 0) os << '_'; + WordID w = rule.e_[i]; + if (w <= 0) { + os << 'N' << (1-w); + } else { + os << TD::Convert(w); + } + } + it = rule2_fid_.insert(make_pair(&rule, FD::Convert(Escape(os.str())))).first; + } + features->add_value(it->second, 1); +} + +RuleNgramFeatures::RuleNgramFeatures(const std::string& param) { +} + +void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) { +// std::map<const TRule*, SparseVector<double> > + rule2_feats_.clear(); +} + +void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + map<const TRule*, SparseVector<double> >::iterator it = rule2_feats_.find(edge.rule_.get()); + if (it == rule2_feats_.end()) { + const TRule& rule = *edge.rule_; + it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first; + SparseVector<double>& f = it->second; + string prev = "<r>"; + for (int i = 0; i < rule.f_.size(); ++i) { + WordID w = rule.f_[i]; + if (w < 0) w = -w; + assert(w > 0); + const string& cur = TD::Convert(w); + ostringstream os; + os << "RB:" << prev << '_' << cur; + const int fid = FD::Convert(Escape(os.str())); + if (fid <= 0) return; + f.add_value(fid, 1.0); + prev = cur; + } + ostringstream os; + os << "RB:" << prev << '_' << "</r>"; + f.set_value(FD::Convert(Escape(os.str())), 1.0); + } + (*features) += it->second; +} + diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h new file mode 100644 index 00000000..48d8bd05 --- /dev/null +++ b/decoder/ff_rules.h @@ -0,0 +1,40 @@ +#ifndef _FF_RULES_H_ +#define _FF_RULES_H_ + +#include <vector> +#include <map> +#include "ff.h" +#include "array2d.h" +#include "wordid.h" + +class RuleIdentityFeatures : public FeatureFunction { + public: + RuleIdentityFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + mutable std::map<const TRule*, int> rule2_fid_; +}; + +class RuleNgramFeatures : public FeatureFunction { + public: + RuleNgramFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + mutable std::map<const TRule*, SparseVector<double> > rule2_feats_; +}; + +#endif diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index e1da088d..0483517b 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -13,6 +13,17 @@ using namespace std; +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + // log transform to make long spans cluster together // but preserve differences int SpanSizeTransform(unsigned span_size) { @@ -140,19 +151,19 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { word = MapIfNecessary(word); ostringstream sfid; sfid << "ES:" << TD::Convert(word); - end_span_ids_[i] = FD::Convert(sfid.str()); + end_span_ids_[i] = FD::Convert(Escape(sfid.str())); ostringstream esbiid; esbiid << "EBI:" << TD::Convert(bword) << "_" << TD::Convert(word); - end_bigram_ids_[i] = FD::Convert(esbiid.str()); + end_bigram_ids_[i] = FD::Convert(Escape(esbiid.str())); ostringstream bsbiid; bsbiid << "BBI:" << TD::Convert(bword) << "_" << TD::Convert(word); - beg_bigram_ids_[i] = FD::Convert(bsbiid.str()); + beg_bigram_ids_[i] = FD::Convert(Escape(bsbiid.str())); ostringstream bfid; bfid << "BS:" << TD::Convert(bword); - beg_span_ids_[i] = FD::Convert(bfid.str()); + beg_span_ids_[i] = FD::Convert(Escape(bfid.str())); if (use_collapsed_features_) { - end_span_vals_[i] = feat2val_[sfid.str()] + feat2val_[esbiid.str()]; - beg_span_vals_[i] = feat2val_[bfid.str()] + feat2val_[bsbiid.str()]; + end_span_vals_[i] = feat2val_[Escape(sfid.str())] + feat2val_[Escape(esbiid.str())]; + beg_span_vals_[i] = feat2val_[Escape(bfid.str())] + feat2val_[Escape(bsbiid.str())]; } } for (int i = 0; i <= lattice.size(); ++i) { @@ -167,60 +178,21 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { word = MapIfNecessary(word); ostringstream pf; pf << "S:" << TD::Convert(bword) << "_" << TD::Convert(word); - span_feats_(i,j).first = FD::Convert(pf.str()); - span_feats_(i,j).second = FD::Convert("S_" + pf.str()); + span_feats_(i,j).first = FD::Convert(Escape(pf.str())); + span_feats_(i,j).second = FD::Convert(Escape("S_" + pf.str())); ostringstream lf; const unsigned span_size = (i < j ? j - i : i - j); lf << "LS:" << SpanSizeTransform(span_size) << "_" << TD::Convert(bword) << "_" << TD::Convert(word); - len_span_feats_(i,j).first = FD::Convert(lf.str()); - len_span_feats_(i,j).second = FD::Convert("S_" + lf.str()); + len_span_feats_(i,j).first = FD::Convert(Escape(lf.str())); + len_span_feats_(i,j).second = FD::Convert(Escape("S_" + lf.str())); if (use_collapsed_features_) { - span_vals_(i,j).first = feat2val_[pf.str()] + feat2val_[lf.str()]; - span_vals_(i,j).second = feat2val_["S_" + pf.str()] + feat2val_["S_" + lf.str()]; + span_vals_(i,j).first = feat2val_[Escape(pf.str())] + feat2val_[Escape(lf.str())]; + span_vals_(i,j).second = feat2val_[Escape("S_" + pf.str())] + feat2val_[Escape("S_" + lf.str())]; } } } } -RuleNgramFeatures::RuleNgramFeatures(const std::string& param) { -} - -void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) { -// std::map<const TRule*, SparseVector<double> > - rule2_feats_.clear(); -} - -void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const vector<const void*>& ant_contexts, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* context) const { - map<const TRule*, SparseVector<double> >::iterator it = rule2_feats_.find(edge.rule_.get()); - if (it == rule2_feats_.end()) { - const TRule& rule = *edge.rule_; - it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first; - SparseVector<double>& f = it->second; - string prev = "<r>"; - for (int i = 0; i < rule.f_.size(); ++i) { - WordID w = rule.f_[i]; - if (w < 0) w = -w; - assert(w > 0); - const string& cur = TD::Convert(w); - ostringstream os; - os << "RB:" << prev << '_' << cur; - const int fid = FD::Convert(os.str()); - if (fid <= 0) return; - f.add_value(fid, 1.0); - prev = cur; - } - ostringstream os; - os << "RB:" << prev << '_' << "</r>"; - f.set_value(FD::Convert(os.str()), 1.0); - } - (*features) += it->second; -} - inline bool IsArity2RuleReordered(const TRule& rule) { const vector<WordID>& e = rule.e_; for (int i = 0; i < e.size(); ++i) { diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h index b22c4d03..24e0dede 100644 --- a/decoder/ff_spans.h +++ b/decoder/ff_spans.h @@ -44,21 +44,6 @@ class SpanFeatures : public FeatureFunction { WordID oov_; }; -class RuleNgramFeatures : public FeatureFunction { - public: - RuleNgramFeatures(const std::string& param); - protected: - virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const std::vector<const void*>& ant_contexts, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* context) const; - virtual void PrepareForInput(const SentenceMetadata& smeta); - private: - mutable std::map<const TRule*, SparseVector<double> > rule2_feats_; -}; - class CMR2008ReorderingFeatures : public FeatureFunction { public: CMR2008ReorderingFeatures(const std::string& param); |