From d4907ddee2012dce728bd1a6eb4e6cad452a54b2 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 23 Dec 2010 14:50:41 -0600 Subject: support different types in kenlm --- decoder/cdec_ff.cc | 5 ++++- decoder/ff_klm.cc | 44 +++++++++++++++++++++++++++----------------- decoder/ff_klm.h | 8 ++++++-- 3 files changed, 37 insertions(+), 20 deletions(-) (limited to 'decoder') diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 75591af8..686905ad 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -35,7 +35,6 @@ void register_feature_functions() { RegisterFsaDynToFF(); RegisterFF(); - RegisterFF(); RegisterFF(); RegisterFF(); @@ -48,6 +47,10 @@ void register_feature_functions() { #ifdef HAVE_RANDLM ff_registry.Register("RandLM", new FFFactory); #endif + ff_registry.Register("KLanguageModel", new FFFactory >()); + ff_registry.Register("KLanguageModel_Sorted", new FFFactory >()); + ff_registry.Register("KLanguageModel_Trie", new FFFactory >()); + ff_registry.Register("KLanguageModel_Probing", new FFFactory >()); ff_registry.Register("RuleShape", new FFFactory); ff_registry.Register("RelativeSentencePosition", new FFFactory); ff_registry.Register("LexNullJump", new FFFactory); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 092c07b0..5049f156 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -4,12 +4,12 @@ #include "hg.h" #include "tdict.h" -#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; -string KLanguageModel::usage(bool param,bool verbose) { +template +string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } @@ -25,6 +25,7 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; +template class KLanguageModelImpl { // returns the number of unscored words at the left edge of a span @@ -36,11 +37,11 @@ class KLanguageModelImpl { *(static_cast(state) + unscored_size_offset_) = size; } - static inline const lm::ngram::Model::State& RemnantLMState(const void* state) { - return *static_cast(state); + static inline const lm::ngram::State& RemnantLMState(const void* state) { + return *static_cast(state); } - inline void SetRemnantLMState(const lm::ngram::Model::State& lmstate, void* state) const { + inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { // if we were clever, we could use the memory pointed to by state to do all // the work, avoiding this copy memcpy(state, &lmstate, ngram_->StateSize()); @@ -68,10 +69,9 @@ class KLanguageModelImpl { double LookupWords(const TRule& rule, const vector& ant_states, double* pest_sum, void* remnant) { double sum = 0.0; double est_sum = 0.0; - int len = rule.ELength() - rule.Arity(); int num_scored = 0; int num_estimated = 0; - lm::ngram::Model::State state = ngram_->NullContextState(); + lm::ngram::State state = ngram_->NullContextState(); const vector& e = rule.e(); bool context_complete = false; for (int j = 0; j < e.size(); ++j) { @@ -80,7 +80,7 @@ class KLanguageModelImpl { int unscored_ant_len = UnscoredSize(astate); for (int k = 0; k < unscored_ant_len; ++k) { const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const lm::ngram::Model::State scopy(state); + const lm::ngram::State scopy(state); const double p = ngram_->Score(scopy, cur_word, state); ++num_scored; if (!context_complete) { @@ -101,7 +101,7 @@ class KLanguageModelImpl { } } else { const lm::WordIndex cur_word = MapWord(e[j]); - const lm::ngram::Model::State scopy(state); + const lm::ngram::State scopy(state); const double p = ngram_->Score(scopy, cur_word, state); ++num_scored; if (!context_complete) { @@ -149,7 +149,7 @@ class KLanguageModelImpl { lm::ngram::Config conf; VMapper vm(&map_); conf.enumerate_vocab = &vm; - ngram_ = new lm::ngram::Model(param.c_str(), conf); + ngram_ = new Model(param.c_str(), conf); order_ = ngram_->Order(); cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n"; state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); @@ -172,7 +172,7 @@ class KLanguageModelImpl { int ReserveStateSize() const { return state_size_; } private: - lm::ngram::Model* ngram_; + Model* ngram_; int order_; int state_size_; int unscored_size_offset_; @@ -184,21 +184,25 @@ class KLanguageModelImpl { TRulePtr dummy_rule_; }; -KLanguageModel::KLanguageModel(const string& param) { - pimpl_ = new KLanguageModelImpl(param); +template +KLanguageModel::KLanguageModel(const string& param) { + pimpl_ = new KLanguageModelImpl(param); fid_ = FD::Convert("LanguageModel"); SetStateSize(pimpl_->ReserveStateSize()); } -Features KLanguageModel::features() const { +template +Features KLanguageModel::features() const { return single_feature(fid_); } -KLanguageModel::~KLanguageModel() { +template +KLanguageModel::~KLanguageModel() { delete pimpl_; } -void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, +template +void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, const Hypergraph::Edge& edge, const vector& ant_states, SparseVector* features, @@ -209,8 +213,14 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, estimated_features->set_value(fid_, est); } -void KLanguageModel::FinalTraversalFeatures(const void* ant_state, +template +void KLanguageModel::FinalTraversalFeatures(const void* ant_state, SparseVector* features) const { features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); } +// instantiate templates +template class KLanguageModel; +template class KLanguageModel; +template class KLanguageModel; + diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 0569286f..95e1e897 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -5,9 +5,13 @@ #include #include "ff.h" +#include "lm/model.hh" -struct KLanguageModelImpl; +template struct KLanguageModelImpl; +// the supported template types are instantiated explicitly +// in ff_klm.cc. +template class KLanguageModel : public FeatureFunction { public: // param = "filename.lm [-o n]" @@ -26,7 +30,7 @@ class KLanguageModel : public FeatureFunction { void* out_context) const; private: int fid_; // conceptually const; mutable only to simplify constructor - KLanguageModelImpl* pimpl_; + KLanguageModelImpl* pimpl_; }; #endif -- cgit v1.2.3