From de53e2e98acd0e2d07efb39bef430bd598908aa8 Mon Sep 17 00:00:00 2001
From: Kenneth Heafield
Date: Fri, 14 Dec 2012 12:39:04 -0800
Subject: Updated incremental, updated kenlm. Incremental assumes

---
 klm/search/rule.cc | 52 ++++++++++++++++++++++++++--------------------------
 1 file changed, 26 insertions(+), 26 deletions(-)

(limited to 'klm/search/rule.cc')

diff --git a/klm/search/rule.cc b/klm/search/rule.cc
index 5b00207e..0244a09f 100644
--- a/klm/search/rule.cc
+++ b/klm/search/rule.cc
@@ -1,7 +1,7 @@
 #include "search/rule.hh"
 
+#include "lm/model.hh"
 #include "search/context.hh"
-#include "search/final.hh"
 
 #include
 
@@ -9,35 +9,35 @@
 
 namespace search {
 
-template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) {
-  unsigned int oov_count = 0;
-  float prob = 0.0;
-  const Model &model = context.LanguageModel();
-  const lm::WordIndex oov = model.GetVocabulary().NotFound();
-  for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
-    lm::ngram::RuleScore<Model> scorer(model, *(writing++));
-    // TODO: optimize
-    if (prepend_bos && (word == words.begin())) {
-      scorer.BeginSentence();
-    }
-    for (; ; ++word) {
-      if (word == words.end()) {
-        prob += scorer.Finish();
-        return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();
-      }
-      if (*word == kNonTerminal) break;
-      if (*word == oov) ++oov_count;
+template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) {
+  ScoreRuleRet ret;
+  ret.prob = 0.0;
+  ret.oov = 0;
+  const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence();
+  lm::ngram::RuleScore<Model> scorer(model, *(writing++));
+  std::vector<lm::WordIndex>::const_iterator word = words.begin();
+  if (word != words.end() && *word == bos) {
+    scorer.BeginSentence();
+    ++word;
+  }
+  for (; word != words.end(); ++word) {
+    if (*word == kNonTerminal) {
+      ret.prob += scorer.Finish();
+      scorer.Reset(*(writing++));
+    } else {
+      if (*word == oov) ++ret.oov;
       scorer.Terminal(*word);
     }
-    prob += scorer.Finish();
   }
+  ret.prob += scorer.Finish();
+  return ret;
 }
 
-template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
 
 } // namespace search
-- 
cgit v1.2.3