diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-09-11 14:30:16 +0100 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-09-11 14:31:42 +0100 |
commit | 45c9efc7d558dbe160056f02e74df1fee5d2d0e5 (patch) | |
tree | 0c26a847b78b4c0d86507b21b7338beb98ff1e73 /klm/search/rule.cc | |
parent | 8882e9ebe158aef382bb5544559ef7f2a553db62 (diff) |
Add search library to cdec (not used yet)
Diffstat (limited to 'klm/search/rule.cc')
-rw-r--r-- | klm/search/rule.cc | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/klm/search/rule.cc b/klm/search/rule.cc new file mode 100644 index 00000000..a8b993eb --- /dev/null +++ b/klm/search/rule.cc @@ -0,0 +1,55 @@ +#include "search/rule.hh" + +#include "search/context.hh" +#include "search/final.hh" + +#include <ostream> + +#include <math.h> + +namespace search { + +template <class Model> void Rule::FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos) { + additive_ = additive; + Score lm_score = 0.0; + lexical_.clear(); + const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); + + for (std::vector<Word>::const_iterator word = items_.begin(); ; ++word) { + lexical_.resize(lexical_.size() + 1); + lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back()); + // TODO: optimize + if (prepend_bos && (word == items_.begin())) { + scorer.BeginSentence(); + } + for (; ; ++word) { + if (word == items_.end()) { + lm_score += scorer.Finish(); + bound_ = additive_ + context.GetWeights().LM() * lm_score; + assert(lexical_.size() == arity_ + 1); + return; + } + if (!word->Terminal()) break; + if (word->Index() == oov) additive_ += context.GetWeights().OOV(); + scorer.Terminal(word->Index()); + } + lm_score += scorer.Finish(); + } +} + +template void Rule::FinishedAdding(const Context<lm::ngram::RestProbingModel> &context, Score additive, bool prepend_bos); +template void Rule::FinishedAdding(const Context<lm::ngram::ProbingModel> &context, Score additive, bool prepend_bos); + +std::ostream &operator<<(std::ostream &o, const Rule &rule) { + const Rule::ItemsRet &items = rule.Items(); + for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) { + if (i->Terminal()) { + o << i->String() << ' '; + } else { + o << "[] "; + } + } + return o; +} + +} // namespace search |