diff options
Diffstat (limited to 'klm/search/rule.cc')
-rw-r--r-- | klm/search/rule.cc | 32 |
1 files changed, 10 insertions, 22 deletions
diff --git a/klm/search/rule.cc b/klm/search/rule.cc index a8b993eb..0a941527 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,47 +9,35 @@ namespace search { -template <class Model> void Rule::FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos) { +template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) { additive_ = additive; Score lm_score = 0.0; lexical_.clear(); const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - for (std::vector<Word>::const_iterator word = items_.begin(); ; ++word) { + for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { lexical_.resize(lexical_.size() + 1); lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back()); // TODO: optimize - if (prepend_bos && (word == items_.begin())) { + if (prepend_bos && (word == words.begin())) { scorer.BeginSentence(); } for (; ; ++word) { - if (word == items_.end()) { + if (word == words.end()) { lm_score += scorer.Finish(); bound_ = additive_ + context.GetWeights().LM() * lm_score; - assert(lexical_.size() == arity_ + 1); + arity_ = lexical_.size() - 1; return; } - if (!word->Terminal()) break; - if (word->Index() == oov) additive_ += context.GetWeights().OOV(); - scorer.Terminal(word->Index()); + if (*word == kNonTerminal) break; + if (*word == oov) additive_ += context.GetWeights().OOV(); + scorer.Terminal(*word); } lm_score += scorer.Finish(); } } -template void Rule::FinishedAdding(const Context<lm::ngram::RestProbingModel> &context, Score additive, bool prepend_bos); -template void Rule::FinishedAdding(const Context<lm::ngram::ProbingModel> &context, Score additive, bool prepend_bos); - -std::ostream &operator<<(std::ostream &o, const Rule &rule) { - const Rule::ItemsRet &items = rule.Items(); - for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) { - if (i->Terminal()) { - o << i->String() << ' '; - } else { - o << "[] "; - } - } - return o; -} +template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); +template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); } // namespace search |