summaryrefslogtreecommitdiff
path: root/klm/search/rule.cc
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-09-12 12:01:26 +0100
committerKenneth Heafield <github@kheafield.com>2012-09-12 12:01:26 +0100
commitc26c35a9bcbb4d42ae50ad0a75c1b5fb59702bd1 (patch)
treeeced40cfee4bff7c4cd3fc644016e45f7903a55a /klm/search/rule.cc
parent2ca3db90bd0a2e9a8619d2ebec7c6ac723838aca (diff)
Refactor search so that it knows even less, but keeps track of edge pointers
Diffstat (limited to 'klm/search/rule.cc')
-rw-r--r--klm/search/rule.cc32
1 files changed, 10 insertions, 22 deletions
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
index a8b993eb..0a941527 100644
--- a/klm/search/rule.cc
+++ b/klm/search/rule.cc
@@ -9,47 +9,35 @@
namespace search {
-template <class Model> void Rule::FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos) {
+template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) {
additive_ = additive;
Score lm_score = 0.0;
lexical_.clear();
const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound();
- for (std::vector<Word>::const_iterator word = items_.begin(); ; ++word) {
+ for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
lexical_.resize(lexical_.size() + 1);
lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back());
// TODO: optimize
- if (prepend_bos && (word == items_.begin())) {
+ if (prepend_bos && (word == words.begin())) {
scorer.BeginSentence();
}
for (; ; ++word) {
- if (word == items_.end()) {
+ if (word == words.end()) {
lm_score += scorer.Finish();
bound_ = additive_ + context.GetWeights().LM() * lm_score;
- assert(lexical_.size() == arity_ + 1);
+ arity_ = lexical_.size() - 1;
return;
}
- if (!word->Terminal()) break;
- if (word->Index() == oov) additive_ += context.GetWeights().OOV();
- scorer.Terminal(word->Index());
+ if (*word == kNonTerminal) break;
+ if (*word == oov) additive_ += context.GetWeights().OOV();
+ scorer.Terminal(*word);
}
lm_score += scorer.Finish();
}
}
-template void Rule::FinishedAdding(const Context<lm::ngram::RestProbingModel> &context, Score additive, bool prepend_bos);
-template void Rule::FinishedAdding(const Context<lm::ngram::ProbingModel> &context, Score additive, bool prepend_bos);
-
-std::ostream &operator<<(std::ostream &o, const Rule &rule) {
- const Rule::ItemsRet &items = rule.Items();
- for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) {
- if (i->Terminal()) {
- o << i->String() << ' ';
- } else {
- o << "[] ";
- }
- }
- return o;
-}
+template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
+template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
} // namespace search