diff options
Diffstat (limited to 'klm/lm/left.hh')
-rw-r--r-- | klm/lm/left.hh | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh index effa0560..bb3f5539 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -117,7 +117,7 @@ inline size_t hash_value(const ChartState &state) { template <class M> class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } @@ -130,15 +130,22 @@ template <class M> class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); - ProcessRet(model_.FullScore(copy, word, out_.right)); - if (out_.right.length != copy.length + 1) left_done_ = true; + FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + out_.left.pointers[out_.left.length++] = ret.extend_left; + if (out_.right.length != copy.length + 1) + left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob) { prob_ = prob; out_ = in; - left_write_ = out_.left.pointers + out_.left.length; left_done_ = in.full; } @@ -157,11 +164,10 @@ template <class M> class RuleScore { if (!out_.right.length) { out_.right = in.right; if (left_done_) return; - if (left_write_ != out_.left.pointers) { + if (out_.left.length) { left_done_ = true; } else { out_.left = in.left; - left_write_ = out_.left.pointers + in.left.length; left_done_ = in.full; } return; @@ -214,8 +220,8 @@ template <class M> class RuleScore { } float Finish() { - out_.left.length = left_write_ - out_.left.pointers; - out_.full = left_done_; + // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. + out_.full = left_done_ || (out_.left.length == model_.Order() - 1); return prob_; } @@ -227,7 +233,7 @@ template <class M> class RuleScore { left_done_ = true; return; } - *(left_write_++) = ret.extend_left; + out_.left.pointers[out_.left.length++] = ret.extend_left; } const M &model_; @@ -236,8 +242,6 @@ template <class M> class RuleScore { bool left_done_; - uint64_t *left_write_; - float prob_; }; |