summaryrefslogtreecommitdiff
path: root/klm/lm/left.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/left.hh')
-rw-r--r--klm/lm/left.hh26
1 files changed, 15 insertions, 11 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index effa0560..bb3f5539 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -117,7 +117,7 @@ inline size_t hash_value(const ChartState &state) {
template <class M> class RuleScore {
public:
- explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) {
+ explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) {
out.left.length = 0;
out.right.length = 0;
}
@@ -130,15 +130,22 @@ template <class M> class RuleScore {
void Terminal(WordIndex word) {
State copy(out_.right);
- ProcessRet(model_.FullScore(copy, word, out_.right));
- if (out_.right.length != copy.length + 1) left_done_ = true;
+ FullScoreReturn ret(model_.FullScore(copy, word, out_.right));
+ prob_ += ret.prob;
+ if (left_done_) return;
+ if (ret.independent_left) {
+ left_done_ = true;
+ return;
+ }
+ out_.left.pointers[out_.left.length++] = ret.extend_left;
+ if (out_.right.length != copy.length + 1)
+ left_done_ = true;
}
// Faster version of NonTerminal for the case where the rule begins with a non-terminal.
void BeginNonTerminal(const ChartState &in, float prob) {
prob_ = prob;
out_ = in;
- left_write_ = out_.left.pointers + out_.left.length;
left_done_ = in.full;
}
@@ -157,11 +164,10 @@ template <class M> class RuleScore {
if (!out_.right.length) {
out_.right = in.right;
if (left_done_) return;
- if (left_write_ != out_.left.pointers) {
+ if (out_.left.length) {
left_done_ = true;
} else {
out_.left = in.left;
- left_write_ = out_.left.pointers + in.left.length;
left_done_ = in.full;
}
return;
@@ -214,8 +220,8 @@ template <class M> class RuleScore {
}
float Finish() {
- out_.left.length = left_write_ - out_.left.pointers;
- out_.full = left_done_;
+ // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
+ out_.full = left_done_ || (out_.left.length == model_.Order() - 1);
return prob_;
}
@@ -227,7 +233,7 @@ template <class M> class RuleScore {
left_done_ = true;
return;
}
- *(left_write_++) = ret.extend_left;
+ out_.left.pointers[out_.left.length++] = ret.extend_left;
}
const M &model_;
@@ -236,8 +242,6 @@ template <class M> class RuleScore {
bool left_done_;
- uint64_t *left_write_;
-
float prob_;
};