From 143ba7317dcaee3058d66f9e6558316f88f95212 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 12 Sep 2012 12:01:26 +0100 Subject: Refactor search so that it knows even less, but keeps track of edge pointers --- klm/search/rule.cc | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) (limited to 'klm/search/rule.cc') diff --git a/klm/search/rule.cc b/klm/search/rule.cc index a8b993eb..0a941527 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,47 +9,35 @@ namespace search { -template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos) { +template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos) { additive_ = additive; Score lm_score = 0.0; lexical_.clear(); const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - for (std::vector::const_iterator word = items_.begin(); ; ++word) { + for (std::vector::const_iterator word = words.begin(); ; ++word) { lexical_.resize(lexical_.size() + 1); lm::ngram::RuleScore scorer(context.LanguageModel(), lexical_.back()); // TODO: optimize - if (prepend_bos && (word == items_.begin())) { + if (prepend_bos && (word == words.begin())) { scorer.BeginSentence(); } for (; ; ++word) { - if (word == items_.end()) { + if (word == words.end()) { lm_score += scorer.Finish(); bound_ = additive_ + context.GetWeights().LM() * lm_score; - assert(lexical_.size() == arity_ + 1); + arity_ = lexical_.size() - 1; return; } - if (!word->Terminal()) break; - if (word->Index() == oov) additive_ += context.GetWeights().OOV(); - scorer.Terminal(word->Index()); + if (*word == kNonTerminal) break; + if (*word == oov) additive_ += context.GetWeights().OOV(); + scorer.Terminal(*word); } lm_score += scorer.Finish(); } } -template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos); -template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos); - -std::ostream &operator<<(std::ostream &o, const Rule &rule) { - const Rule::ItemsRet &items = rule.Items(); - for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) { - if (i->Terminal()) { - o << i->String() << ' '; - } else { - o << "[] "; - } - } - return o; -} +template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); +template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); } // namespace search -- cgit v1.2.3