diff options
-rw-r--r-- | decoder/cdec_ff.cc | 5 | ||||
-rw-r--r-- | decoder/ff_klm.cc | 44 | ||||
-rw-r--r-- | decoder/ff_klm.h | 8 |
3 files changed, 37 insertions, 20 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 75591af8..686905ad 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -35,7 +35,6 @@ void register_feature_functions() { RegisterFsaDynToFF<SameFirstLetter>(); RegisterFF<LanguageModel>(); - RegisterFF<KLanguageModel>(); RegisterFF<WordPenalty>(); RegisterFF<SourceWordPenalty>(); @@ -48,6 +47,10 @@ void register_feature_functions() { #ifdef HAVE_RANDLM ff_registry.Register("RandLM", new FFFactory<LanguageModelRandLM>); #endif + ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >()); + ff_registry.Register("KLanguageModel_Sorted", new FFFactory<KLanguageModel<lm::ngram::SortedModel> >()); + ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >()); + ff_registry.Register("KLanguageModel_Probing", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >()); ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>); ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>); ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 092c07b0..5049f156 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -4,12 +4,12 @@ #include "hg.h" #include "tdict.h" -#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; -string KLanguageModel::usage(bool param,bool verbose) { +template <class Model> +string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } @@ -25,6 +25,7 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; +template <class Model> class KLanguageModelImpl { // returns the number of unscored words at the left edge of a span @@ -36,11 +37,11 @@ class KLanguageModelImpl { *(static_cast<char*>(state) + unscored_size_offset_) = size; } - static inline const lm::ngram::Model::State& RemnantLMState(const void* state) { - return *static_cast<const lm::ngram::Model::State*>(state); + static inline const lm::ngram::State& RemnantLMState(const void* state) { + return *static_cast<const lm::ngram::State*>(state); } - inline void SetRemnantLMState(const lm::ngram::Model::State& lmstate, void* state) const { + inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { // if we were clever, we could use the memory pointed to by state to do all // the work, avoiding this copy memcpy(state, &lmstate, ngram_->StateSize()); @@ -68,10 +69,9 @@ class KLanguageModelImpl { double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) { double sum = 0.0; double est_sum = 0.0; - int len = rule.ELength() - rule.Arity(); int num_scored = 0; int num_estimated = 0; - lm::ngram::Model::State state = ngram_->NullContextState(); + lm::ngram::State state = ngram_->NullContextState(); const vector<WordID>& e = rule.e(); bool context_complete = false; for (int j = 0; j < e.size(); ++j) { @@ -80,7 +80,7 @@ class KLanguageModelImpl { int unscored_ant_len = UnscoredSize(astate); for (int k = 0; k < unscored_ant_len; ++k) { const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const lm::ngram::Model::State scopy(state); + const lm::ngram::State scopy(state); const double p = ngram_->Score(scopy, cur_word, state); ++num_scored; if (!context_complete) { @@ -101,7 +101,7 @@ class KLanguageModelImpl { } } else { const lm::WordIndex cur_word = MapWord(e[j]); - const lm::ngram::Model::State scopy(state); + const lm::ngram::State scopy(state); const double p = ngram_->Score(scopy, cur_word, state); ++num_scored; if (!context_complete) { @@ -149,7 +149,7 @@ class KLanguageModelImpl { lm::ngram::Config conf; VMapper vm(&map_); conf.enumerate_vocab = &vm; - ngram_ = new lm::ngram::Model(param.c_str(), conf); + ngram_ = new Model(param.c_str(), conf); order_ = ngram_->Order(); cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n"; state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); @@ -172,7 +172,7 @@ class KLanguageModelImpl { int ReserveStateSize() const { return state_size_; } private: - lm::ngram::Model* ngram_; + Model* ngram_; int order_; int state_size_; int unscored_size_offset_; @@ -184,21 +184,25 @@ class KLanguageModelImpl { TRulePtr dummy_rule_; }; -KLanguageModel::KLanguageModel(const string& param) { - pimpl_ = new KLanguageModelImpl(param); +template <class Model> +KLanguageModel<Model>::KLanguageModel(const string& param) { + pimpl_ = new KLanguageModelImpl<Model>(param); fid_ = FD::Convert("LanguageModel"); SetStateSize(pimpl_->ReserveStateSize()); } -Features KLanguageModel::features() const { +template <class Model> +Features KLanguageModel<Model>::features() const { return single_feature(fid_); } -KLanguageModel::~KLanguageModel() { +template <class Model> +KLanguageModel<Model>::~KLanguageModel() { delete pimpl_; } -void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, +template <class Model> +void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, const Hypergraph::Edge& edge, const vector<const void*>& ant_states, SparseVector<double>* features, @@ -209,8 +213,14 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, estimated_features->set_value(fid_, est); } -void KLanguageModel::FinalTraversalFeatures(const void* ant_state, +template <class Model> +void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, SparseVector<double>* features) const { features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); } +// instantiate templates +template class KLanguageModel<lm::ngram::ProbingModel>; +template class KLanguageModel<lm::ngram::SortedModel>; +template class KLanguageModel<lm::ngram::TrieModel>; + diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 0569286f..95e1e897 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -5,9 +5,13 @@ #include <string> #include "ff.h" +#include "lm/model.hh" -struct KLanguageModelImpl; +template <class Model> struct KLanguageModelImpl; +// the supported template types are instantiated explicitly +// in ff_klm.cc. +template <class Model> class KLanguageModel : public FeatureFunction { public: // param = "filename.lm [-o n]" @@ -26,7 +30,7 @@ class KLanguageModel : public FeatureFunction { void* out_context) const; private: int fid_; // conceptually const; mutable only to simplify constructor - KLanguageModelImpl* pimpl_; + KLanguageModelImpl<Model>* pimpl_; }; #endif |