blob: 1299f06ff05546ec80a1c5029b604919a76d284d (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
#include "ngram_base.h"
#include "lm/model.hh"
#include "tdict.h"
using namespace std;
namespace {

// lm::EnumerateVocab implementation that records, for every (index, string)
// pair handed to Add(), which LM word index corresponds to the cdec WordID of
// that string. The result is written into a caller-owned vector that is
// indexed by cdec WordID.
struct GICSVMapper : public lm::EnumerateVocab {
  // `out` is emptied immediately; it must remain valid for as long as this
  // mapper is in use because only the pointer is stored.
  GICSVMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) {
    out_->clear();
  }

  // Stores `index` in the slot belonging to the cdec id of `str`, growing the
  // table first if that slot does not exist yet. Slots created by the growth
  // that are not assigned here hold kLM_UNKNOWN_TOKEN.
  void Add(lm::WordIndex index, const StringPiece &str) {
    vector<lm::WordIndex>& table = *out_;
    const WordID cdec_id = TD::Convert(str.as_string());
    if (!(cdec_id < table.size()))
      table.resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
    table[cdec_id] = index;
  }

  vector<lm::WordIndex>* out_;             // destination table (not owned)
  const lm::WordIndex kLM_UNKNOWN_TOKEN;   // fill value for unassigned slots
};

}  // namespace
// Private implementation behind FixedNgramBase: owns an lm::ngram::ProbingModel
// and a table translating cdec word ids into that model's word indices.
//
// The model is held through an owning raw pointer that the destructor deletes,
// so copying an instance would lead to a double delete. Copy construction and
// copy assignment are therefore declared private and left undefined.
struct FixedNgramBaseImpl {
  // Loads the language model stored in the file named by `param`.
  // A GICSVMapper writing into cdec2klm_map_ is registered through
  // conf.enumerate_vocab before the model is constructed (presumably the
  // library invokes it for each vocabulary entry while loading).
  FixedNgramBaseImpl(const string& param) {
    GICSVMapper vm(&cdec2klm_map_);
    lm::ngram::Config conf;
    conf.enumerate_vocab = &vm;
    cerr << "Reading character LM from " << param << endl;
    model = new lm::ngram::ProbingModel(param.c_str(), conf);
    order = model->Order();
    kEOS = MapWord(TD::Convert("</s>"));
    // MapWord() yields 0 for ids without a table entry, so a zero here means
    // "</s>" was never mapped. Checked in builds with assertions enabled only.
    assert(kEOS > 0);
  }

  ~FixedNgramBaseImpl() { delete model; }

  // Translates a cdec word id into the LM's word index. Ids that fall outside
  // the table map to 0, the same value GICSVMapper uses to fill unassigned
  // slots.
  lm::WordIndex MapWord(const WordID w) const {
    if (w < cdec2klm_map_.size()) return cdec2klm_map_[w];
    return 0;
  }

  // Scores the sequence `s` starting from the model's begin-of-sentence state
  // and finishing with the end-of-sentence token, and returns the total as a
  // prob_t.
  prob_t StringProbability(const vector<WordID>& s) const {
    lm::ngram::State state = model->BeginSentenceState();
    double prob = 0;
    for (vector<WordID>::const_iterator it = s.begin(); it != s.end(); ++it) {
      // Score() reads the input state and writes the successor state; a copy
      // keeps the two arguments distinct objects.
      const lm::ngram::State scopy(state);
      prob += model->Score(scopy, MapWord(*it), state);
    }
    const lm::ngram::State scopy(state);
    prob += model->Score(scopy, kEOS, state);
    // The accumulated score is scaled by ln(10) before being handed to
    // logeq(), i.e. it is treated as a base-10 log converted to a natural log.
    prob_t p;
    p.logeq(prob * log(10.0));
    return p;
  }

  lm::ngram::ProbingModel* model;       // owned; released in the destructor
  unsigned order;                       // model->Order() captured at load time
  vector<lm::WordIndex> cdec2klm_map_;  // cdec WordID -> LM word index
  lm::WordIndex kEOS;                   // LM word index of "</s>"

 private:
  // Not copyable: `model` is an owning raw pointer (declared, never defined).
  FixedNgramBaseImpl(const FixedNgramBaseImpl&);
  FixedNgramBaseImpl& operator=(const FixedNgramBaseImpl&);
};
// Destroys the implementation object that the constructor allocated.
FixedNgramBase::~FixedNgramBase() {
  delete impl;
}
// Builds the implementation object from the language model file `lmfname`
// and keeps it in `impl`.
FixedNgramBase::FixedNgramBase(const string& lmfname) {
  FixedNgramBaseImpl* const loaded = new FixedNgramBaseImpl(lmfname);
  impl = loaded;
}
// Forwards to the implementation object and returns its probability for `s`.
prob_t FixedNgramBase::StringProbability(const vector<WordID>& s) const {
  const FixedNgramBaseImpl& backend = *impl;
  return backend.StringProbability(s);
}
|