From 1ffcac39647bdc13e6f6ef73ade6b88d59a08101 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 23 Dec 2010 14:50:41 -0600 Subject: support different types in kenlm --- decoder/ff_klm.cc | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 092c07b0..5049f156 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -4,12 +4,12 @@ #include "hg.h" #include "tdict.h" -#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; -string KLanguageModel::usage(bool param,bool verbose) { +template +string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } @@ -25,6 +25,7 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; +template class KLanguageModelImpl { // returns the number of unscored words at the left edge of a span @@ -36,11 +37,11 @@ class KLanguageModelImpl { *(static_cast(state) + unscored_size_offset_) = size; } - static inline const lm::ngram::Model::State& RemnantLMState(const void* state) { - return *static_cast(state); + static inline const lm::ngram::State& RemnantLMState(const void* state) { + return *static_cast(state); } - inline void SetRemnantLMState(const lm::ngram::Model::State& lmstate, void* state) const { + inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { // if we were clever, we could use the memory pointed to by state to do all // the work, avoiding this copy memcpy(state, &lmstate, ngram_->StateSize()); @@ -68,10 +69,9 @@ class KLanguageModelImpl { double LookupWords(const TRule& rule, const vector& ant_states, double* pest_sum, void* remnant) { double sum = 0.0; double est_sum = 0.0; - int len = rule.ELength() - rule.Arity(); int num_scored = 0; int num_estimated = 0; - lm::ngram::Model::State state = ngram_->NullContextState(); + lm::ngram::State state = ngram_->NullContextState(); const vector& e = rule.e(); bool context_complete = false; for (int j = 0; j < e.size(); ++j) { @@ -80,7 +80,7 @@ class KLanguageModelImpl { int unscored_ant_len = UnscoredSize(astate); for (int k = 0; k < unscored_ant_len; ++k) { const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const lm::ngram::Model::State scopy(state); + const lm::ngram::State scopy(state); const double p = ngram_->Score(scopy, cur_word, state); ++num_scored; if (!context_complete) { @@ -101,7 +101,7 @@ class KLanguageModelImpl { } } else { const lm::WordIndex cur_word = MapWord(e[j]); - const lm::ngram::Model::State scopy(state); + const lm::ngram::State scopy(state); const double p = ngram_->Score(scopy, cur_word, state); ++num_scored; if (!context_complete) { @@ -149,7 +149,7 @@ class KLanguageModelImpl { lm::ngram::Config conf; VMapper vm(&map_); conf.enumerate_vocab = &vm; - ngram_ = new lm::ngram::Model(param.c_str(), conf); + ngram_ = new Model(param.c_str(), conf); order_ = ngram_->Order(); cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n"; state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); @@ -172,7 +172,7 @@ class KLanguageModelImpl { int ReserveStateSize() const { return state_size_; } private: - lm::ngram::Model* ngram_; + Model* ngram_; int order_; int state_size_; int unscored_size_offset_; @@ -184,21 +184,25 @@ class KLanguageModelImpl { TRulePtr dummy_rule_; }; -KLanguageModel::KLanguageModel(const string& param) { - pimpl_ = new KLanguageModelImpl(param); +template +KLanguageModel::KLanguageModel(const string& param) { + pimpl_ = new KLanguageModelImpl(param); fid_ = FD::Convert("LanguageModel"); SetStateSize(pimpl_->ReserveStateSize()); } -Features KLanguageModel::features() const { +template +Features KLanguageModel::features() const { return single_feature(fid_); } -KLanguageModel::~KLanguageModel() { +template +KLanguageModel::~KLanguageModel() { delete pimpl_; } -void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, +template +void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, const Hypergraph::Edge& edge, const vector& ant_states, SparseVector* features, @@ -209,8 +213,14 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, estimated_features->set_value(fid_, est); } -void KLanguageModel::FinalTraversalFeatures(const void* ant_state, +template +void KLanguageModel::FinalTraversalFeatures(const void* ant_state, SparseVector* features) const { features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); } +// instantiate templates +template class KLanguageModel; +template class KLanguageModel; +template class KLanguageModel; + -- cgit v1.2.3