diff options
-rw-r--r-- | gi/pf/ngram_base.cc | 69 | ||||
-rw-r--r-- | gi/pf/ngram_base.h | 25 |
2 files changed, 94 insertions, 0 deletions
diff --git a/gi/pf/ngram_base.cc b/gi/pf/ngram_base.cc new file mode 100644 index 00000000..1299f06f --- /dev/null +++ b/gi/pf/ngram_base.cc @@ -0,0 +1,69 @@ +#include "ngram_base.h" + +#include "lm/model.hh" +#include "tdict.h" + +using namespace std; + +namespace { +struct GICSVMapper : public lm::EnumerateVocab { + GICSVMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_->size()) + out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); + (*out_)[cdec_id] = index; + } + vector<lm::WordIndex>* out_; + const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; +} + +struct FixedNgramBaseImpl { + FixedNgramBaseImpl(const string& param) { + GICSVMapper vm(&cdec2klm_map_); + lm::ngram::Config conf; + conf.enumerate_vocab = &vm; + cerr << "Reading character LM from " << param << endl; + model = new lm::ngram::ProbingModel(param.c_str(), conf); + order = model->Order(); + kEOS = MapWord(TD::Convert("</s>")); + assert(kEOS > 0); + } + + lm::WordIndex MapWord(const WordID w) const { + if (w < cdec2klm_map_.size()) return cdec2klm_map_[w]; + return 0; + } + + ~FixedNgramBaseImpl() { delete model; } + + prob_t StringProbability(const vector<WordID>& s) const { + lm::ngram::State state = model->BeginSentenceState(); + double prob = 0; + for (unsigned i = 0; i < s.size(); ++i) { + const lm::ngram::State scopy(state); + prob += model->Score(scopy, MapWord(s[i]), state); + } + const lm::ngram::State scopy(state); + prob += model->Score(scopy, kEOS, state); + prob_t p; p.logeq(prob * log(10)); + return p; + } + + lm::ngram::ProbingModel* model; + unsigned order; + vector<lm::WordIndex> cdec2klm_map_; + lm::WordIndex kEOS; +}; + +FixedNgramBase::~FixedNgramBase() { delete impl; } + +FixedNgramBase::FixedNgramBase(const string& lmfname) { + impl = new FixedNgramBaseImpl(lmfname); +} + +prob_t FixedNgramBase::StringProbability(const vector<WordID>& s) const { + return impl->StringProbability(s); +} + diff --git a/gi/pf/ngram_base.h b/gi/pf/ngram_base.h new file mode 100644 index 00000000..4ea999f3 --- /dev/null +++ b/gi/pf/ngram_base.h @@ -0,0 +1,25 @@ +#ifndef _NGRAM_BASE_H_ +#define _NGRAM_BASE_H_ + +#include <string> +#include <vector> +#include "trule.h" +#include "wordid.h" +#include "prob.h" + +struct FixedNgramBaseImpl; +struct FixedNgramBase { + FixedNgramBase(const std::string& lmfname); + ~FixedNgramBase(); + prob_t StringProbability(const std::vector<WordID>& s) const; + + prob_t operator()(const TRule& rule) const { + return StringProbability(rule.e_); + } + + private: + FixedNgramBaseImpl* impl; + +}; + +#endif |