diff options
-rw-r--r-- | decoder/ff_klm.cc | 84 |
1 files changed, 58 insertions, 26 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 658aef80..3c941fbf 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,8 +12,8 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#define NEW_KENLM #undef NEW_KENLM -#ifdef NEW_KENLM #include "lm/left.hh" @@ -95,14 +95,58 @@ struct BoundaryAnnotatedState { #pragma pack(pop) -void BoundaryCheck(bool &annotated, bool sub, double &ret) { - if (!sub) return; - if (annotated) { - ret -= 100.0; - } else { - annotated = true; - } -} +template <class Model> class BoundaryRuleScore { + public: + BoundaryRuleScore(const Model &m, BoundaryAnnotatedState &state) : + back_(m, state.state), + bos_(state.seen_bos), + eos_(state.seen_eos), + penalty_(0.0), + end_sentence_(m.GetVocabulary().EndSentence()) { + bos_ = false; + eos_ = false; + } + + void BeginSentence() { + back_.BeginSentence(); + bos_ = true; + } + + void BeginNonTerminal(const BoundaryAnnotatedState &sub) { + back_.BeginNonTerminal(sub.state, 0.0f); + bos_ = sub.seen_bos; + eos_ = sub.seen_eos; + } + + void NonTerminal(const BoundaryAnnotatedState &sub) { + back_.NonTerminal(sub.state, 0.0f); + // cdec only calls this if there's content. + if (sub.seen_bos) { + bos_ = true; + penalty_ -= 100.0f; + } + if (eos_) penalty_ -= 100.0f; + eos_ |= sub.seen_eos; + } + + void Terminal(lm::WordIndex word) { + back_.Terminal(word); + if (eos_) penalty_ -= 100.0f; + if (word == end_sentence_) eos_ = true; + } + + float Finish() { + return penalty_ + back_.Finish(); + } + + private: + lm::ngram::RuleScore<Model> back_; + bool &bos_, &eos_; + + float penalty_; + + lm::WordIndex end_sentence_; +}; } // namespace @@ -112,42 +156,30 @@ class KLanguageModelImpl { double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector<WordID>& e = rule.e(); - BoundaryAnnotatedState &annotated = *static_cast<BoundaryAnnotatedState*>(remnant); - lm::ngram::RuleScore<Model> ruleScore(*ngram_, annotated.state); - annotated.seen_bos = false; - annotated.seen_eos = false; + BoundaryRuleScore<Model> ruleScore(*ngram_, *static_cast<BoundaryAnnotatedState*>(remnant)); unsigned i = 0; - double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); - annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(sub.state, 0.0f); - annotated.seen_bos = sub.seen_bos; - annotated.seen_eos = sub.seen_eos; + ruleScore.BeginNonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]])); ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]); - ruleScore.NonTerminal(sub.state, 0.0f); - BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); - BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); + ruleScore.NonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]])); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; - BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - ret += ruleScore.Finish(); - annotated.state.ZeroRemaining(); + double ret = ruleScore.Finish(); + static_cast<BoundaryAnnotatedState*>(remnant)->state.ZeroRemaining(); return ret; } |