blob: 0a941527333989192de2c33ed15482605820ae66 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
#include "search/rule.hh"
#include "search/context.hh"
#include "search/final.hh"
#include <ostream>
#include <math.h>
namespace search {
// Initializes this rule's scoring state from its right-hand side |words|.
//
// |words| is the rule's RHS as vocabulary indices, where kNonTerminal entries
// mark non-terminal positions.  The terminals between non-terminals form
// "lexical spans"; each span is scored independently with the language model
// and its left/right LM state is stored in lexical_ (one entry per span, so
// lexical_.size() == arity + 1).
//
// Side effects on members (no return value):
//   additive_ : set to |additive|, plus the OOV weight once per OOV terminal.
//   lexical_  : one LM state per lexical span of the RHS.
//   bound_    : additive_ + LM weight * total LM score of all spans
//               (an upper-bound estimate used for search).
//   arity_    : number of non-terminals (= spans - 1).
//
// If |prepend_bos| is true, the very first span is scored in begin-of-sentence
// context (used when the rule can only apply at the start of a sentence).
template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) {
additive_ = additive;
Score lm_score = 0.0;
lexical_.clear();
const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound();
// NOTE: |word| is shared by BOTH loops.  The inner loop consumes terminals
// until it hits a non-terminal or the end; the outer loop's ++word then
// steps past that non-terminal before starting the next span.
for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
// Start a new lexical span: append an LM state and score into it.
lexical_.resize(lexical_.size() + 1);
lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back());
// TODO: optimize
if (prepend_bos && (word == words.begin())) {
scorer.BeginSentence();
}
for (; ; ++word) {
if (word == words.end()) {
// Last span finished: finalize all derived members and exit.
lm_score += scorer.Finish();
bound_ = additive_ + context.GetWeights().LM() * lm_score;
arity_ = lexical_.size() - 1;
return;
}
if (*word == kNonTerminal) break;
// Each out-of-vocabulary terminal incurs the OOV feature weight.
if (*word == oov) additive_ += context.GetWeights().OOV();
scorer.Terminal(*word);
}
// Span ended at a non-terminal: bank its LM score and continue.
lm_score += scorer.Finish();
}
}
// Explicit instantiations: Init is defined in this .cc file, so the two
// language-model types used by the decoder must be instantiated here.
template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
} // namespace search
|