summaryrefslogtreecommitdiff
path: root/decoder/ff_klm.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-03-02 18:45:51 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-03-02 18:45:51 -0500
commitd5d86d6513e34f0b26030814c9eb516271ce2aec (patch)
tree8101bbff0f150d48ed66db3699b8a935065ae4f7 /decoder/ff_klm.cc
parent88c224217307f40f5361150f5bd2e8b68f51b17b (diff)
better handling of SOS/EOS markers, part 1
Diffstat (limited to 'decoder/ff_klm.cc')
-rw-r--r--decoder/ff_klm.cc43
1 files changed, 35 insertions, 8 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 9ba2cbaa..a55a0b41 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -8,6 +8,10 @@
using namespace std;
+static const unsigned char HAS_FULL_CONTEXT = 1;
+static const unsigned char HAS_EOS_ON_RIGHT = 2;
+static const unsigned char MASK = 7;
+
template <class Model>
string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) {
return "KLanguageModel";
@@ -57,12 +61,24 @@ class KLanguageModelImpl {
mem[i] = index;
}
- bool HasFullContext(const void *state) const {
- return *(static_cast<const char*>(state) + is_complete_offset_);
+ inline bool GetFlag(const void *state, unsigned char flag) const {
+ return (*(static_cast<const char*>(state) + is_complete_offset_) & flag);
+ }
+
+ inline void SetFlag(bool on, unsigned char flag, void *state) const {
+ if (on) {
+ *(static_cast<char*>(state) + is_complete_offset_) |= flag;
+ } else {
+ *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag);
+ }
+ }
+
+ inline bool HasFullContext(const void *state) const {
+ return GetFlag(state, HAS_FULL_CONTEXT);
}
- void SetHasFullContext(bool flag, void *state) const {
- *(static_cast<char*>(state) + is_complete_offset_) = flag;
+ inline void SetHasFullContext(bool flag, void *state) const {
+ SetFlag(flag, HAS_FULL_CONTEXT, state);
}
public:
@@ -71,6 +87,8 @@ class KLanguageModelImpl {
double est_sum = 0.0;
int num_scored = 0;
int num_estimated = 0;
+ bool saw_eos = false;
+ bool has_some_history = false;
lm::ngram::State state = ngram_->NullContextState();
const vector<WordID>& e = rule.e();
bool context_complete = false;
@@ -82,13 +100,16 @@ class KLanguageModelImpl {
const lm::WordIndex cur_word = IthUnscoredWord(k, astate);
double p = 0;
if (cur_word == kSOS_) {
- if (state.ValidLength() > 0) { p = -100; }
+ if (has_some_history) { p = -100; }
state = ngram_->BeginSentenceState();
- context_complete = true;
+ if (!context_complete && num_scored < (order_ - 2)) num_scored = order_ - 2;
} else {
const lm::ngram::State scopy(state);
p = ngram_->Score(scopy, cur_word, state);
+ if (saw_eos) { p = -100; }
+ saw_eos = (cur_word == kEOS_);
}
+ has_some_history = true;
++num_scored;
if (!context_complete) {
if (num_scored >= order_) context_complete = true;
@@ -102,6 +123,7 @@ class KLanguageModelImpl {
est_sum += p;
}
}
+ saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT);
if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007
state = RemnantLMState(astate);
context_complete = true;
@@ -110,13 +132,16 @@ class KLanguageModelImpl {
const lm::WordIndex cur_word = MapWord(e[j]);
double p = 0;
if (cur_word == kSOS_) {
- if (state.ValidLength() > 0) p = -100;
+ if (has_some_history) p = -100;
state = ngram_->BeginSentenceState();
- context_complete = true;
+ if (!context_complete && num_scored < (order_ - 2)) num_scored = order_ - 2;
} else {
const lm::ngram::State scopy(state);
p = ngram_->Score(scopy, cur_word, state);
+ if (saw_eos) { p = -100; }
+ saw_eos = (cur_word == kEOS_);
}
+ has_some_history = true;
++num_scored;
if (!context_complete) {
if (num_scored >= order_) context_complete = true;
@@ -134,6 +159,7 @@ class KLanguageModelImpl {
if (pest_sum) *pest_sum = est_sum;
if (remnant) {
state.ZeroRemaining();
+ SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant);
SetRemnantLMState(state, remnant);
SetUnscoredSize(num_estimated, remnant);
SetHasFullContext(context_complete || (num_scored >= order_), remnant);
@@ -190,6 +216,7 @@ class KLanguageModelImpl {
kSOS_ = MapWord(TD::Convert("<s>"));
assert(kSOS_ > 0);
kEOS_ = MapWord(TD::Convert("</s>"));
+ assert(kEOS_ > 0);
}
~KLanguageModelImpl() {