diff options
| author | Kenneth Heafield <github@kheafield.com> | 2011-10-18 10:25:56 +0100 | 
|---|---|---|
| committer | Kenneth Heafield <github@kheafield.com> | 2011-10-18 10:25:56 +0100 | 
| commit | 3d1ed02a4e5d81aace80b0e004e96351d116630f (patch) | |
| tree | 194d61e38362a90544e6349366957b632b1b3f5c | |
| parent | 957d90991b4ec80b9877126c736bd60768b094aa (diff) | |
Revised <s> and </s> handling
| -rw-r--r-- | decoder/ff_klm.cc | 84 | 
1 file changed, 58 insertions, 26 deletions
| diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 658aef80..3c941fbf 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,8 +12,8 @@  #include "lm/model.hh"  #include "lm/enumerate_vocab.hh" +#define NEW_KENLM  #undef NEW_KENLM -#ifdef NEW_KENLM  #include "lm/left.hh" @@ -95,14 +95,58 @@ struct BoundaryAnnotatedState {  #pragma pack(pop) -void BoundaryCheck(bool &annotated, bool sub, double &ret) { -  if (!sub) return; -  if (annotated) { -    ret -= 100.0; -  } else { -    annotated = true; -  } -} +template <class Model> class BoundaryRuleScore { +  public: +    BoundaryRuleScore(const Model &m, BoundaryAnnotatedState &state) :  +        back_(m, state.state), +        bos_(state.seen_bos), +        eos_(state.seen_eos), +        penalty_(0.0), +        end_sentence_(m.GetVocabulary().EndSentence()) { +      bos_ = false; +      eos_ = false; +    } + +    void BeginSentence() { +      back_.BeginSentence(); +      bos_ = true; +    } + +    void BeginNonTerminal(const BoundaryAnnotatedState &sub) { +      back_.BeginNonTerminal(sub.state, 0.0f); +      bos_ = sub.seen_bos; +      eos_ = sub.seen_eos; +    } + +    void NonTerminal(const BoundaryAnnotatedState &sub) { +      back_.NonTerminal(sub.state, 0.0f); +      // cdec only calls this if there's content.   
+      if (sub.seen_bos) { +        bos_ = true; +        penalty_ -= 100.0f; +      } +      if (eos_) penalty_ -= 100.0f; +      eos_ |= sub.seen_eos; +    } + +    void Terminal(lm::WordIndex word) { +      back_.Terminal(word); +      if (eos_) penalty_ -= 100.0f; +      if (word == end_sentence_) eos_ = true; +    } + +    float Finish() { +      return penalty_ + back_.Finish(); +    } + +  private: +    lm::ngram::RuleScore<Model> back_; +    bool &bos_, &eos_; + +    float penalty_; + +    lm::WordIndex end_sentence_; +};  } // namespace @@ -112,42 +156,30 @@ class KLanguageModelImpl {    double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) {      *oovs = 0;      const vector<WordID>& e = rule.e(); -    BoundaryAnnotatedState &annotated = *static_cast<BoundaryAnnotatedState*>(remnant); -    lm::ngram::RuleScore<Model> ruleScore(*ngram_, annotated.state); -    annotated.seen_bos = false; -    annotated.seen_eos = false; +    BoundaryRuleScore<Model> ruleScore(*ngram_, *static_cast<BoundaryAnnotatedState*>(remnant));      unsigned i = 0; -    double ret = 0.0;      if (e.size()) {        if (e[i] == kCDEC_SOS) {          ++i;          ruleScore.BeginSentence(); -        annotated.seen_bos = true;        } else if (e[i] <= 0) {  // special case for left-edge NT -        const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]]); -        ruleScore.BeginNonTerminal(sub.state, 0.0f); -        annotated.seen_bos = sub.seen_bos; -        annotated.seen_eos = sub.seen_eos; +        ruleScore.BeginNonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]]));          ++i;        }      }      for (; i < e.size(); ++i) {        if (e[i] <= 0) { -        const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]); -        ruleScore.NonTerminal(sub.state, 0.0f); -        BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); 
-        BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); +        ruleScore.NonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]));        } else {          const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]);  // in future,                                                                            // maybe handle emission          const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id          if (cur_word == 0) (*oovs) += 1.0; -        BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret);          ruleScore.Terminal(cur_word);        }      } -    ret += ruleScore.Finish(); -    annotated.state.ZeroRemaining(); +    double ret = ruleScore.Finish(); +    static_cast<BoundaryAnnotatedState*>(remnant)->state.ZeroRemaining();      return ret;    } | 
