summaryrefslogtreecommitdiff
path: root/klm/search/rule.cc
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-10-14 10:46:34 +0100
committerKenneth Heafield <github@kheafield.com>2012-10-14 10:46:34 +0100
commitf568b392f82fd94b788a1b38094855234d318205 (patch)
tree37c09a6c2a006afd719e04009b0908117960354a /klm/search/rule.cc
parentb7005385f267596436dd07b1a9e798023ef1c30a (diff)
Update to faster but less cute search
Diffstat (limited to 'klm/search/rule.cc')
-rw-r--r--klm/search/rule.cc32
1 files changed, 16 insertions, 16 deletions
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
index 0a941527..5b00207e 100644
--- a/klm/search/rule.cc
+++ b/klm/search/rule.cc
@@ -9,35 +9,35 @@
namespace search {
-template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) {
- additive_ = additive;
- Score lm_score = 0.0;
- lexical_.clear();
- const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound();
-
+template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) {
+ unsigned int oov_count = 0;
+ float prob = 0.0;
+ const Model &model = context.LanguageModel();
+ const lm::WordIndex oov = model.GetVocabulary().NotFound();
for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
- lexical_.resize(lexical_.size() + 1);
- lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back());
+ lm::ngram::RuleScore<Model> scorer(model, *(writing++));
// TODO: optimize
if (prepend_bos && (word == words.begin())) {
scorer.BeginSentence();
}
for (; ; ++word) {
if (word == words.end()) {
- lm_score += scorer.Finish();
- bound_ = additive_ + context.GetWeights().LM() * lm_score;
- arity_ = lexical_.size() - 1;
- return;
+ prob += scorer.Finish();
+ return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();
}
if (*word == kNonTerminal) break;
- if (*word == oov) additive_ += context.GetWeights().OOV();
+ if (*word == oov) ++oov_count;
scorer.Terminal(*word);
}
- lm_score += scorer.Finish();
+ prob += scorer.Finish();
}
}
-template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
-template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
+template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
} // namespace search