| author | Kenneth Heafield <github@kheafield.com> | 2012-10-14 10:46:34 +0100 |
|---|---|---|
| committer | Kenneth Heafield <github@kheafield.com> | 2012-10-14 10:46:34 +0100 |
| commit | f568b392f82fd94b788a1b38094855234d318205 (patch) | |
| tree | 37c09a6c2a006afd719e04009b0908117960354a /klm/search/rule.cc | |
| parent | b7005385f267596436dd07b1a9e798023ef1c30a (diff) | |
Update to faster but less cute search
Diffstat (limited to 'klm/search/rule.cc')
-rw-r--r-- | klm/search/rule.cc | 32 |
1 file changed, 16 insertions, 16 deletions
```diff
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
index 0a941527..5b00207e 100644
--- a/klm/search/rule.cc
+++ b/klm/search/rule.cc
@@ -9,35 +9,35 @@ namespace search {
 
-template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) {
-  additive_ = additive;
-  Score lm_score = 0.0;
-  lexical_.clear();
-  const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound();
-
+template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) {
+  unsigned int oov_count = 0;
+  float prob = 0.0;
+  const Model &model = context.LanguageModel();
+  const lm::WordIndex oov = model.GetVocabulary().NotFound();
   for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
-    lexical_.resize(lexical_.size() + 1);
-    lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back());
+    lm::ngram::RuleScore<Model> scorer(model, *(writing++));
     // TODO: optimize
     if (prepend_bos && (word == words.begin())) {
       scorer.BeginSentence();
     }
     for (; ; ++word) {
       if (word == words.end()) {
-        lm_score += scorer.Finish();
-        bound_ = additive_ + context.GetWeights().LM() * lm_score;
-        arity_ = lexical_.size() - 1;
-        return;
+        prob += scorer.Finish();
+        return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();
       }
       if (*word == kNonTerminal) break;
-      if (*word == oov) additive_ += context.GetWeights().OOV();
+      if (*word == oov) ++oov_count;
       scorer.Terminal(*word);
     }
-    lm_score += scorer.Finish();
+    prob += scorer.Finish();
   }
 }
 
-template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
-template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
+template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
 
 } // namespace search
```
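For readers following the API change: `Rule::Init` used to accumulate the score and the per-span `ChartState`s inside the `Rule` object, whereas the new free function `ScoreRule` writes one `lm::ngram::ChartState` per lexical span into a caller-supplied array and returns the weighted LM-plus-OOV score directly. The sketch below is a hypothetical caller, not code from this commit; the header paths, the `search::kNonTerminal` spelling, and the vocabulary `Index()` call are assumptions based on the surrounding kenlm sources.

```cpp
// Hypothetical caller of the new ScoreRule() interface -- an illustration only,
// not part of this commit. Header paths and search::kNonTerminal are assumptions.
#include <vector>

#include "lm/left.hh"          // lm::ngram::ChartState (assumed location)
#include "lm/word_index.hh"    // lm::WordIndex (assumed location)
#include "search/context.hh"   // search::Context (assumed location)
#include "search/rule.hh"      // search::ScoreRule, search::kNonTerminal (assumed)

template <class Model>
float ScoreExampleRule(const search::Context<Model> &context) {
  const Model &model = context.LanguageModel();

  // Rule "the <X> cat": two lexical spans separated by one nonterminal,
  // so the caller provides arity + 1 = 2 ChartStates for ScoreRule to fill.
  std::vector<lm::WordIndex> words;
  words.push_back(model.GetVocabulary().Index("the"));
  words.push_back(search::kNonTerminal);
  words.push_back(model.GetVocabulary().Index("cat"));

  std::vector<lm::ngram::ChartState> states(2);
  // prepend_bos = false: do not force a beginning-of-sentence context on this rule.
  return search::ScoreRule(context, words, /*prepend_bos=*/false, &states[0]);
}
```

As the return statement in the diff shows, the result already folds in both feature weights (the OOV count times `GetWeights().OOV()` plus the summed log probability times `GetWeights().LM()`), so a caller would not re-apply them.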