From d5d86d6513e34f0b26030814c9eb516271ce2aec Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 2 Mar 2011 18:45:51 -0500 Subject: better handling of SOS/EOS markers, part 1 --- decoder/ff_klm.cc | 43 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 9ba2cbaa..a55a0b41 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -8,6 +8,10 @@ using namespace std; +static const unsigned char HAS_FULL_CONTEXT = 1; +static const unsigned char HAS_EOS_ON_RIGHT = 2; +static const unsigned char MASK = 7; + template string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; @@ -57,12 +61,24 @@ class KLanguageModelImpl { mem[i] = index; } - bool HasFullContext(const void *state) const { - return *(static_cast(state) + is_complete_offset_); + inline bool GetFlag(const void *state, unsigned char flag) const { + return (*(static_cast(state) + is_complete_offset_) & flag); + } + + inline void SetFlag(bool on, unsigned char flag, void *state) const { + if (on) { + *(static_cast(state) + is_complete_offset_) |= flag; + } else { + *(static_cast(state) + is_complete_offset_) &= (MASK ^ flag); + } + } + + inline bool HasFullContext(const void *state) const { + return GetFlag(state, HAS_FULL_CONTEXT); } - void SetHasFullContext(bool flag, void *state) const { - *(static_cast(state) + is_complete_offset_) = flag; + inline void SetHasFullContext(bool flag, void *state) const { + SetFlag(flag, HAS_FULL_CONTEXT, state); } public: @@ -71,6 +87,8 @@ class KLanguageModelImpl { double est_sum = 0.0; int num_scored = 0; int num_estimated = 0; + bool saw_eos = false; + bool has_some_history = false; lm::ngram::State state = ngram_->NullContextState(); const vector& e = rule.e(); bool context_complete = false; @@ -82,13 +100,16 @@ class KLanguageModelImpl { const lm::WordIndex cur_word = IthUnscoredWord(k, astate); double p = 0; if (cur_word == kSOS_) { - if (state.ValidLength() > 0) { p = -100; } + if (has_some_history) { p = -100; } state = ngram_->BeginSentenceState(); - context_complete = true; + if (!context_complete && num_scored < (order_ - 2)) num_scored = order_ - 2; } else { const lm::ngram::State scopy(state); p = ngram_->Score(scopy, cur_word, state); + if (saw_eos) { p = -100; } + saw_eos = (cur_word == kEOS_); } + has_some_history = true; ++num_scored; if (!context_complete) { if (num_scored >= order_) context_complete = true; @@ -102,6 +123,7 @@ class KLanguageModelImpl { est_sum += p; } } + saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 state = RemnantLMState(astate); context_complete = true; @@ -110,13 +132,16 @@ class KLanguageModelImpl { const lm::WordIndex cur_word = MapWord(e[j]); double p = 0; if (cur_word == kSOS_) { - if (state.ValidLength() > 0) p = -100; + if (has_some_history) p = -100; state = ngram_->BeginSentenceState(); - context_complete = true; + if (!context_complete && num_scored < (order_ - 2)) num_scored = order_ - 2; } else { const lm::ngram::State scopy(state); p = ngram_->Score(scopy, cur_word, state); + if (saw_eos) { p = -100; } + saw_eos = (cur_word == kEOS_); } + has_some_history = true; ++num_scored; if (!context_complete) { if (num_scored >= order_) context_complete = true; @@ -134,6 +159,7 @@ class KLanguageModelImpl { if (pest_sum) *pest_sum = est_sum; if (remnant) { state.ZeroRemaining(); + SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); SetRemnantLMState(state, remnant); SetUnscoredSize(num_estimated, remnant); SetHasFullContext(context_complete || (num_scored >= order_), remnant); @@ -190,6 +216,7 @@ class KLanguageModelImpl { kSOS_ = MapWord(TD::Convert("")); assert(kSOS_ > 0); kEOS_ = MapWord(TD::Convert("")); + assert(kEOS_ > 0); } ~KLanguageModelImpl() { -- cgit v1.2.3