diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-02 18:45:51 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-02 18:45:51 -0500 |
commit | 9f8f30c5a8506272f9a3b74bc65c08e6cc62a4b2 (patch) | |
tree | 19923be8c752533daa8e8e3dee3b55de7dcd61de | |
parent | 965c26a8ce00a603e639a6609b5147256ab1a189 (diff) |
better handling of SOS/EOS markers, part 1
-rw-r--r-- | decoder/ff_klm.cc | 43 |
1 files changed, 35 insertions, 8 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 9ba2cbaa..a55a0b41 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -8,6 +8,10 @@ using namespace std; +static const unsigned char HAS_FULL_CONTEXT = 1; +static const unsigned char HAS_EOS_ON_RIGHT = 2; +static const unsigned char MASK = 7; + template <class Model> string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; @@ -57,12 +61,24 @@ class KLanguageModelImpl { mem[i] = index; } - bool HasFullContext(const void *state) const { - return *(static_cast<const char*>(state) + is_complete_offset_); + inline bool GetFlag(const void *state, unsigned char flag) const { + return (*(static_cast<const char*>(state) + is_complete_offset_) & flag); + } + + inline void SetFlag(bool on, unsigned char flag, void *state) const { + if (on) { + *(static_cast<char*>(state) + is_complete_offset_) |= flag; + } else { + *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag); + } + } + + inline bool HasFullContext(const void *state) const { + return GetFlag(state, HAS_FULL_CONTEXT); } - void SetHasFullContext(bool flag, void *state) const { - *(static_cast<char*>(state) + is_complete_offset_) = flag; + inline void SetHasFullContext(bool flag, void *state) const { + SetFlag(flag, HAS_FULL_CONTEXT, state); } public: @@ -71,6 +87,8 @@ class KLanguageModelImpl { double est_sum = 0.0; int num_scored = 0; int num_estimated = 0; + bool saw_eos = false; + bool has_some_history = false; lm::ngram::State state = ngram_->NullContextState(); const vector<WordID>& e = rule.e(); bool context_complete = false; @@ -82,13 +100,16 @@ class KLanguageModelImpl { const lm::WordIndex cur_word = IthUnscoredWord(k, astate); double p = 0; if (cur_word == kSOS_) { - if (state.ValidLength() > 0) { p = -100; } + if (has_some_history) { p = -100; } state = ngram_->BeginSentenceState(); - context_complete = true; + if (!context_complete && num_scored < (order_ - 2)) num_scored = order_ - 2; } else { const lm::ngram::State scopy(state); p = ngram_->Score(scopy, cur_word, state); + if (saw_eos) { p = -100; } + saw_eos = (cur_word == kEOS_); } + has_some_history = true; ++num_scored; if (!context_complete) { if (num_scored >= order_) context_complete = true; @@ -102,6 +123,7 @@ class KLanguageModelImpl { est_sum += p; } } + saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 state = RemnantLMState(astate); context_complete = true; @@ -110,13 +132,16 @@ class KLanguageModelImpl { const lm::WordIndex cur_word = MapWord(e[j]); double p = 0; if (cur_word == kSOS_) { - if (state.ValidLength() > 0) p = -100; + if (has_some_history) p = -100; state = ngram_->BeginSentenceState(); - context_complete = true; + if (!context_complete && num_scored < (order_ - 2)) num_scored = order_ - 2; } else { const lm::ngram::State scopy(state); p = ngram_->Score(scopy, cur_word, state); + if (saw_eos) { p = -100; } + saw_eos = (cur_word == kEOS_); } + has_some_history = true; ++num_scored; if (!context_complete) { if (num_scored >= order_) context_complete = true; @@ -134,6 +159,7 @@ class KLanguageModelImpl { if (pest_sum) *pest_sum = est_sum; if (remnant) { state.ZeroRemaining(); + SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); SetRemnantLMState(state, remnant); SetUnscoredSize(num_estimated, remnant); SetHasFullContext(context_complete || (num_scored >= order_), remnant); @@ -190,6 +216,7 @@ class KLanguageModelImpl { kSOS_ = MapWord(TD::Convert("<s>")); assert(kSOS_ > 0); kEOS_ = MapWord(TD::Convert("</s>")); + assert(kEOS_ > 0); } ~KLanguageModelImpl() { |