diff options
Diffstat (limited to 'decoder/ff_klm.cc')
-rw-r--r-- | decoder/ff_klm.cc | 44 |
1 files changed, 27 insertions, 17 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 092c07b0..5049f156 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -4,12 +4,12 @@ #include "hg.h" #include "tdict.h" -#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; -string KLanguageModel::usage(bool param,bool verbose) { +template <class Model> +string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } @@ -25,6 +25,7 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; +template <class Model> class KLanguageModelImpl { // returns the number of unscored words at the left edge of a span @@ -36,11 +37,11 @@ class KLanguageModelImpl { *(static_cast<char*>(state) + unscored_size_offset_) = size; } - static inline const lm::ngram::Model::State& RemnantLMState(const void* state) { - return *static_cast<const lm::ngram::Model::State*>(state); + static inline const lm::ngram::State& RemnantLMState(const void* state) { + return *static_cast<const lm::ngram::State*>(state); } - inline void SetRemnantLMState(const lm::ngram::Model::State& lmstate, void* state) const { + inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { // if we were clever, we could use the memory pointed to by state to do all // the work, avoiding this copy memcpy(state, &lmstate, ngram_->StateSize()); @@ -68,10 +69,9 @@ class KLanguageModelImpl { double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) { double sum = 0.0; double est_sum = 0.0; - int len = rule.ELength() - rule.Arity(); int num_scored = 0; int num_estimated = 0; - lm::ngram::Model::State state = ngram_->NullContextState(); + lm::ngram::State state = ngram_->NullContextState(); const vector<WordID>& e = rule.e(); bool context_complete = false; for (int j = 0; j < e.size(); ++j) { @@ -80,7 +80,7 @@ class KLanguageModelImpl { int unscored_ant_len = UnscoredSize(astate); for (int k = 0; k < unscored_ant_len; ++k) { const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const lm::ngram::Model::State scopy(state); + const lm::ngram::State scopy(state); const double p = ngram_->Score(scopy, cur_word, state); ++num_scored; if (!context_complete) { @@ -101,7 +101,7 @@ class KLanguageModelImpl { } } else { const lm::WordIndex cur_word = MapWord(e[j]); - const lm::ngram::Model::State scopy(state); + const lm::ngram::State scopy(state); const double p = ngram_->Score(scopy, cur_word, state); ++num_scored; if (!context_complete) { @@ -149,7 +149,7 @@ class KLanguageModelImpl { lm::ngram::Config conf; VMapper vm(&map_); conf.enumerate_vocab = &vm; - ngram_ = new lm::ngram::Model(param.c_str(), conf); + ngram_ = new Model(param.c_str(), conf); order_ = ngram_->Order(); cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n"; state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); @@ -172,7 +172,7 @@ class KLanguageModelImpl { int ReserveStateSize() const { return state_size_; } private: - lm::ngram::Model* ngram_; + Model* ngram_; int order_; int state_size_; int unscored_size_offset_; @@ -184,21 +184,25 @@ class KLanguageModelImpl { TRulePtr dummy_rule_; }; -KLanguageModel::KLanguageModel(const string& param) { - pimpl_ = new KLanguageModelImpl(param); +template <class Model> +KLanguageModel<Model>::KLanguageModel(const string& param) { + pimpl_ = new KLanguageModelImpl<Model>(param); fid_ = FD::Convert("LanguageModel"); SetStateSize(pimpl_->ReserveStateSize()); } -Features KLanguageModel::features() const { +template <class Model> +Features KLanguageModel<Model>::features() const { return single_feature(fid_); } -KLanguageModel::~KLanguageModel() { +template <class Model> +KLanguageModel<Model>::~KLanguageModel() { delete pimpl_; } -void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, +template <class Model> +void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, const Hypergraph::Edge& edge, const vector<const void*>& ant_states, SparseVector<double>* features, @@ -209,8 +213,14 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, estimated_features->set_value(fid_, est); } -void KLanguageModel::FinalTraversalFeatures(const void* ant_state, +template <class Model> +void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, SparseVector<double>* features) const { features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); } +// instantiate templates +template class KLanguageModel<lm::ngram::ProbingModel>; +template class KLanguageModel<lm::ngram::SortedModel>; +template class KLanguageModel<lm::ngram::TrieModel>; + |