From 97520627383196a8ba8c4e0656573546f4a03fbc Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 18 Oct 2011 10:25:56 +0100 Subject: Revised and handling --- decoder/ff_klm.cc | 84 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 26 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 658aef80..3c941fbf 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,8 +12,8 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#define NEW_KENLM #undef NEW_KENLM -#ifdef NEW_KENLM #include "lm/left.hh" @@ -95,14 +95,58 @@ struct BoundaryAnnotatedState { #pragma pack(pop) -void BoundaryCheck(bool &annotated, bool sub, double &ret) { - if (!sub) return; - if (annotated) { - ret -= 100.0; - } else { - annotated = true; - } -} +template class BoundaryRuleScore { + public: + BoundaryRuleScore(const Model &m, BoundaryAnnotatedState &state) : + back_(m, state.state), + bos_(state.seen_bos), + eos_(state.seen_eos), + penalty_(0.0), + end_sentence_(m.GetVocabulary().EndSentence()) { + bos_ = false; + eos_ = false; + } + + void BeginSentence() { + back_.BeginSentence(); + bos_ = true; + } + + void BeginNonTerminal(const BoundaryAnnotatedState &sub) { + back_.BeginNonTerminal(sub.state, 0.0f); + bos_ = sub.seen_bos; + eos_ = sub.seen_eos; + } + + void NonTerminal(const BoundaryAnnotatedState &sub) { + back_.NonTerminal(sub.state, 0.0f); + // cdec only calls this if there's content. + if (sub.seen_bos) { + bos_ = true; + penalty_ -= 100.0f; + } + if (eos_) penalty_ -= 100.0f; + eos_ |= sub.seen_eos; + } + + void Terminal(lm::WordIndex word) { + back_.Terminal(word); + if (eos_) penalty_ -= 100.0f; + if (word == end_sentence_) eos_ = true; + } + + float Finish() { + return penalty_ + back_.Finish(); + } + + private: + lm::ngram::RuleScore back_; + bool &bos_, &eos_; + + float penalty_; + + lm::WordIndex end_sentence_; +}; } // namespace @@ -112,42 +156,30 @@ class KLanguageModelImpl { double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector& e = rule.e(); - BoundaryAnnotatedState &annotated = *static_cast(remnant); - lm::ngram::RuleScore ruleScore(*ngram_, annotated.state); - annotated.seen_bos = false; - annotated.seen_eos = false; + BoundaryRuleScore ruleScore(*ngram_, *static_cast(remnant)); unsigned i = 0; - double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); - annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(sub.state, 0.0f); - annotated.seen_bos = sub.seen_bos; - annotated.seen_eos = sub.seen_eos; + ruleScore.BeginNonTerminal(*static_cast(ant_states[-e[0]])); ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[i]]); - ruleScore.NonTerminal(sub.state, 0.0f); - BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); - BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); + ruleScore.NonTerminal(*static_cast(ant_states[-e[i]])); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; - BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - ret += ruleScore.Finish(); - annotated.state.ZeroRemaining(); + double ret = ruleScore.Finish(); + static_cast(remnant)->state.ZeroRemaining(); return ret; } -- cgit v1.2.3