summaryrefslogtreecommitdiff
path: root/decoder/ff_klm.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/ff_klm.cc')
-rw-r--r--decoder/ff_klm.cc44
1 files changed, 27 insertions, 17 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 092c07b0..5049f156 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -4,12 +4,12 @@
#include "hg.h"
#include "tdict.h"
-#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"
using namespace std;
-string KLanguageModel::usage(bool param,bool verbose) {
+template <class Model>
+string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) {
return "KLanguageModel";
}
@@ -25,6 +25,7 @@ struct VMapper : public lm::ngram::EnumerateVocab {
const lm::WordIndex kLM_UNKNOWN_TOKEN;
};
+template <class Model>
class KLanguageModelImpl {
// returns the number of unscored words at the left edge of a span
@@ -36,11 +37,11 @@ class KLanguageModelImpl {
*(static_cast<char*>(state) + unscored_size_offset_) = size;
}
- static inline const lm::ngram::Model::State& RemnantLMState(const void* state) {
- return *static_cast<const lm::ngram::Model::State*>(state);
+ static inline const lm::ngram::State& RemnantLMState(const void* state) {
+ return *static_cast<const lm::ngram::State*>(state);
}
- inline void SetRemnantLMState(const lm::ngram::Model::State& lmstate, void* state) const {
+ inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const {
// if we were clever, we could use the memory pointed to by state to do all
// the work, avoiding this copy
memcpy(state, &lmstate, ngram_->StateSize());
@@ -68,10 +69,9 @@ class KLanguageModelImpl {
double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) {
double sum = 0.0;
double est_sum = 0.0;
- int len = rule.ELength() - rule.Arity();
int num_scored = 0;
int num_estimated = 0;
- lm::ngram::Model::State state = ngram_->NullContextState();
+ lm::ngram::State state = ngram_->NullContextState();
const vector<WordID>& e = rule.e();
bool context_complete = false;
for (int j = 0; j < e.size(); ++j) {
@@ -80,7 +80,7 @@ class KLanguageModelImpl {
int unscored_ant_len = UnscoredSize(astate);
for (int k = 0; k < unscored_ant_len; ++k) {
const lm::WordIndex cur_word = IthUnscoredWord(k, astate);
- const lm::ngram::Model::State scopy(state);
+ const lm::ngram::State scopy(state);
const double p = ngram_->Score(scopy, cur_word, state);
++num_scored;
if (!context_complete) {
@@ -101,7 +101,7 @@ class KLanguageModelImpl {
}
} else {
const lm::WordIndex cur_word = MapWord(e[j]);
- const lm::ngram::Model::State scopy(state);
+ const lm::ngram::State scopy(state);
const double p = ngram_->Score(scopy, cur_word, state);
++num_scored;
if (!context_complete) {
@@ -149,7 +149,7 @@ class KLanguageModelImpl {
lm::ngram::Config conf;
VMapper vm(&map_);
conf.enumerate_vocab = &vm;
- ngram_ = new lm::ngram::Model(param.c_str(), conf);
+ ngram_ = new Model(param.c_str(), conf);
order_ = ngram_->Order();
cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n";
state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);
@@ -172,7 +172,7 @@ class KLanguageModelImpl {
int ReserveStateSize() const { return state_size_; }
private:
- lm::ngram::Model* ngram_;
+ Model* ngram_;
int order_;
int state_size_;
int unscored_size_offset_;
@@ -184,21 +184,25 @@ class KLanguageModelImpl {
TRulePtr dummy_rule_;
};
-KLanguageModel::KLanguageModel(const string& param) {
- pimpl_ = new KLanguageModelImpl(param);
+template <class Model>
+KLanguageModel<Model>::KLanguageModel(const string& param) {
+ pimpl_ = new KLanguageModelImpl<Model>(param);
fid_ = FD::Convert("LanguageModel");
SetStateSize(pimpl_->ReserveStateSize());
}
-Features KLanguageModel::features() const {
+template <class Model>
+Features KLanguageModel<Model>::features() const {
return single_feature(fid_);
}
-KLanguageModel::~KLanguageModel() {
+template <class Model>
+KLanguageModel<Model>::~KLanguageModel() {
delete pimpl_;
}
-void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+template <class Model>
+void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
const Hypergraph::Edge& edge,
const vector<const void*>& ant_states,
SparseVector<double>* features,
@@ -209,8 +213,14 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
estimated_features->set_value(fid_, est);
}
-void KLanguageModel::FinalTraversalFeatures(const void* ant_state,
+template <class Model>
+void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state,
SparseVector<double>* features) const {
features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
}
+// instantiate templates
+template class KLanguageModel<lm::ngram::ProbingModel>;
+template class KLanguageModel<lm::ngram::SortedModel>;
+template class KLanguageModel<lm::ngram::TrieModel>;
+