summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2010-12-23 14:50:41 -0600
committerChris Dyer <cdyer@cs.cmu.edu>2010-12-23 14:50:41 -0600
commitd4907ddee2012dce728bd1a6eb4e6cad452a54b2 (patch)
treed8730fb06078dff0a39b432f1761fb6f1631a509 /decoder
parent61116d4ce5f6a3ea9ae1c4d5a5b97a954d486597 (diff)
support different types in kenlm
Diffstat (limited to 'decoder')
-rw-r--r--decoder/cdec_ff.cc5
-rw-r--r--decoder/ff_klm.cc44
-rw-r--r--decoder/ff_klm.h8
3 files changed, 37 insertions, 20 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 75591af8..686905ad 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -35,7 +35,6 @@ void register_feature_functions() {
RegisterFsaDynToFF<SameFirstLetter>();
RegisterFF<LanguageModel>();
- RegisterFF<KLanguageModel>();
RegisterFF<WordPenalty>();
RegisterFF<SourceWordPenalty>();
@@ -48,6 +47,10 @@ void register_feature_functions() {
#ifdef HAVE_RANDLM
ff_registry.Register("RandLM", new FFFactory<LanguageModelRandLM>);
#endif
+ ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >());
+ ff_registry.Register("KLanguageModel_Sorted", new FFFactory<KLanguageModel<lm::ngram::SortedModel> >());
+ ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >());
+ ff_registry.Register("KLanguageModel_Probing", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >());
ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>);
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 092c07b0..5049f156 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -4,12 +4,12 @@
#include "hg.h"
#include "tdict.h"
-#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"
using namespace std;
-string KLanguageModel::usage(bool param,bool verbose) {
+template <class Model>
+string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) {
return "KLanguageModel";
}
@@ -25,6 +25,7 @@ struct VMapper : public lm::ngram::EnumerateVocab {
const lm::WordIndex kLM_UNKNOWN_TOKEN;
};
+template <class Model>
class KLanguageModelImpl {
// returns the number of unscored words at the left edge of a span
@@ -36,11 +37,11 @@ class KLanguageModelImpl {
*(static_cast<char*>(state) + unscored_size_offset_) = size;
}
- static inline const lm::ngram::Model::State& RemnantLMState(const void* state) {
- return *static_cast<const lm::ngram::Model::State*>(state);
+ static inline const lm::ngram::State& RemnantLMState(const void* state) {
+ return *static_cast<const lm::ngram::State*>(state);
}
- inline void SetRemnantLMState(const lm::ngram::Model::State& lmstate, void* state) const {
+ inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const {
// if we were clever, we could use the memory pointed to by state to do all
// the work, avoiding this copy
memcpy(state, &lmstate, ngram_->StateSize());
@@ -68,10 +69,9 @@ class KLanguageModelImpl {
double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) {
double sum = 0.0;
double est_sum = 0.0;
- int len = rule.ELength() - rule.Arity();
int num_scored = 0;
int num_estimated = 0;
- lm::ngram::Model::State state = ngram_->NullContextState();
+ lm::ngram::State state = ngram_->NullContextState();
const vector<WordID>& e = rule.e();
bool context_complete = false;
for (int j = 0; j < e.size(); ++j) {
@@ -80,7 +80,7 @@ class KLanguageModelImpl {
int unscored_ant_len = UnscoredSize(astate);
for (int k = 0; k < unscored_ant_len; ++k) {
const lm::WordIndex cur_word = IthUnscoredWord(k, astate);
- const lm::ngram::Model::State scopy(state);
+ const lm::ngram::State scopy(state);
const double p = ngram_->Score(scopy, cur_word, state);
++num_scored;
if (!context_complete) {
@@ -101,7 +101,7 @@ class KLanguageModelImpl {
}
} else {
const lm::WordIndex cur_word = MapWord(e[j]);
- const lm::ngram::Model::State scopy(state);
+ const lm::ngram::State scopy(state);
const double p = ngram_->Score(scopy, cur_word, state);
++num_scored;
if (!context_complete) {
@@ -149,7 +149,7 @@ class KLanguageModelImpl {
lm::ngram::Config conf;
VMapper vm(&map_);
conf.enumerate_vocab = &vm;
- ngram_ = new lm::ngram::Model(param.c_str(), conf);
+ ngram_ = new Model(param.c_str(), conf);
order_ = ngram_->Order();
cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n";
state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);
@@ -172,7 +172,7 @@ class KLanguageModelImpl {
int ReserveStateSize() const { return state_size_; }
private:
- lm::ngram::Model* ngram_;
+ Model* ngram_;
int order_;
int state_size_;
int unscored_size_offset_;
@@ -184,21 +184,25 @@ class KLanguageModelImpl {
TRulePtr dummy_rule_;
};
-KLanguageModel::KLanguageModel(const string& param) {
- pimpl_ = new KLanguageModelImpl(param);
+template <class Model>
+KLanguageModel<Model>::KLanguageModel(const string& param) {
+ pimpl_ = new KLanguageModelImpl<Model>(param);
fid_ = FD::Convert("LanguageModel");
SetStateSize(pimpl_->ReserveStateSize());
}
-Features KLanguageModel::features() const {
+template <class Model>
+Features KLanguageModel<Model>::features() const {
return single_feature(fid_);
}
-KLanguageModel::~KLanguageModel() {
+template <class Model>
+KLanguageModel<Model>::~KLanguageModel() {
delete pimpl_;
}
-void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+template <class Model>
+void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
const Hypergraph::Edge& edge,
const vector<const void*>& ant_states,
SparseVector<double>* features,
@@ -209,8 +213,14 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
estimated_features->set_value(fid_, est);
}
-void KLanguageModel::FinalTraversalFeatures(const void* ant_state,
+template <class Model>
+void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state,
SparseVector<double>* features) const {
features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
}
+// instantiate templates
+template class KLanguageModel<lm::ngram::ProbingModel>;
+template class KLanguageModel<lm::ngram::SortedModel>;
+template class KLanguageModel<lm::ngram::TrieModel>;
+
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
index 0569286f..95e1e897 100644
--- a/decoder/ff_klm.h
+++ b/decoder/ff_klm.h
@@ -5,9 +5,13 @@
#include <string>
#include "ff.h"
+#include "lm/model.hh"
-struct KLanguageModelImpl;
+template <class Model> struct KLanguageModelImpl;
+// the supported template types are instantiated explicitly
+// in ff_klm.cc.
+template <class Model>
class KLanguageModel : public FeatureFunction {
public:
// param = "filename.lm [-o n]"
@@ -26,7 +30,7 @@ class KLanguageModel : public FeatureFunction {
void* out_context) const;
private:
int fid_; // conceptually const; mutable only to simplify constructor
- KLanguageModelImpl* pimpl_;
+ KLanguageModelImpl<Model>* pimpl_;
};
#endif