summaryrefslogtreecommitdiff
path: root/klm/lm/left.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/left.hh')
-rw-r--r--klm/lm/left.hh110
1 files changed, 30 insertions, 80 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index a07f9803..c00af88a 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -39,7 +39,7 @@
#define LM_LEFT__
#include "lm/max_order.hh"
-#include "lm/model.hh"
+#include "lm/state.hh"
#include "lm/return.hh"
#include "util/murmur_hash.hh"
@@ -49,72 +49,6 @@
namespace lm {
namespace ngram {
-struct Left {
- bool operator==(const Left &other) const {
- return
- (length == other.length) &&
- pointers[length - 1] == other.pointers[length - 1];
- }
-
- int Compare(const Left &other) const {
- if (length != other.length) return length < other.length ? -1 : 1;
- if (pointers[length - 1] > other.pointers[length - 1]) return 1;
- if (pointers[length - 1] < other.pointers[length - 1]) return -1;
- return 0;
- }
-
- bool operator<(const Left &other) const {
- if (length != other.length) return length < other.length;
- return pointers[length - 1] < other.pointers[length - 1];
- }
-
- void ZeroRemaining() {
- for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)
- *i = 0;
- }
-
- unsigned char length;
- uint64_t pointers[kMaxOrder - 1];
-};
-
-inline size_t hash_value(const Left &left) {
- return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]);
-}
-
-struct ChartState {
- bool operator==(const ChartState &other) {
- return (left == other.left) && (right == other.right) && (full == other.full);
- }
-
- int Compare(const ChartState &other) const {
- int lres = left.Compare(other.left);
- if (lres) return lres;
- int rres = right.Compare(other.right);
- if (rres) return rres;
- return (int)full - (int)other.full;
- }
-
- bool operator<(const ChartState &other) const {
- return Compare(other) == -1;
- }
-
- void ZeroRemaining() {
- left.ZeroRemaining();
- right.ZeroRemaining();
- }
-
- Left left;
- bool full;
- State right;
-};
-
-inline size_t hash_value(const ChartState &state) {
- size_t hashes[2];
- hashes[0] = hash_value(state.left);
- hashes[1] = hash_value(state.right);
- return util::MurmurHashNative(hashes, sizeof(size_t) * 2, state.full);
-}
-
template <class M> class RuleScore {
public:
explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) {
@@ -131,29 +65,30 @@ template <class M> class RuleScore {
void Terminal(WordIndex word) {
State copy(out_.right);
FullScoreReturn ret(model_.FullScore(copy, word, out_.right));
- prob_ += ret.prob;
- if (left_done_) return;
+ if (left_done_) { prob_ += ret.prob; return; }
if (ret.independent_left) {
+ prob_ += ret.prob;
left_done_ = true;
return;
}
out_.left.pointers[out_.left.length++] = ret.extend_left;
+ prob_ += ret.rest;
if (out_.right.length != copy.length + 1)
left_done_ = true;
}
// Faster version of NonTerminal for the case where the rule begins with a non-terminal.
- void BeginNonTerminal(const ChartState &in, float prob) {
+ void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
prob_ = prob;
out_ = in;
- left_done_ = in.full;
+ left_done_ = in.left.full;
}
- void NonTerminal(const ChartState &in, float prob) {
+ void NonTerminal(const ChartState &in, float prob = 0.0) {
prob_ += prob;
if (!in.left.length) {
- if (in.full) {
+ if (in.left.full) {
for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i;
left_done_ = true;
out_.right = in.right;
@@ -163,12 +98,15 @@ template <class M> class RuleScore {
if (!out_.right.length) {
out_.right = in.right;
- if (left_done_) return;
+ if (left_done_) {
+ prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
+ return;
+ }
if (out_.left.length) {
left_done_ = true;
} else {
out_.left = in.left;
- left_done_ = in.full;
+ left_done_ = in.left.full;
}
return;
}
@@ -186,7 +124,7 @@ template <class M> class RuleScore {
std::swap(back, back2);
}
- if (in.full) {
+ if (in.left.full) {
for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
left_done_ = true;
out_.right = in.right;
@@ -213,10 +151,17 @@ template <class M> class RuleScore {
float Finish() {
// A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
- out_.full = left_done_ || (out_.left.length == model_.Order() - 1);
+ out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1);
return prob_;
}
+ void Reset() {
+ prob_ = 0.0;
+ left_done_ = false;
+ out_.left.length = 0;
+ out_.right.length = 0;
+ }
+
private:
bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
ProcessRet(model_.ExtendLeft(
@@ -228,8 +173,9 @@ template <class M> class RuleScore {
if (next_use != out_.right.length) {
left_done_ = true;
if (!next_use) {
- out_.right = in.right;
// Early exit.
+ out_.right = in.right;
+ prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
return true;
}
}
@@ -238,13 +184,17 @@ template <class M> class RuleScore {
}
void ProcessRet(const FullScoreReturn &ret) {
- prob_ += ret.prob;
- if (left_done_) return;
+ if (left_done_) {
+ prob_ += ret.prob;
+ return;
+ }
if (ret.independent_left) {
+ prob_ += ret.prob;
left_done_ = true;
return;
}
out_.left.pointers[out_.left.length++] = ret.extend_left;
+ prob_ += ret.rest;
}
const M &model_;