diff options
Diffstat (limited to 'klm/lm/left.hh')
-rw-r--r-- | klm/lm/left.hh | 110 |
1 files changed, 30 insertions, 80 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh index a07f9803..c00af88a 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -39,7 +39,7 @@ #define LM_LEFT__ #include "lm/max_order.hh" -#include "lm/model.hh" +#include "lm/state.hh" #include "lm/return.hh" #include "util/murmur_hash.hh" @@ -49,72 +49,6 @@ namespace lm { namespace ngram { -struct Left { - bool operator==(const Left &other) const { - return - (length == other.length) && - pointers[length - 1] == other.pointers[length - 1]; - } - - int Compare(const Left &other) const { - if (length != other.length) return length < other.length ? -1 : 1; - if (pointers[length - 1] > other.pointers[length - 1]) return 1; - if (pointers[length - 1] < other.pointers[length - 1]) return -1; - return 0; - } - - bool operator<(const Left &other) const { - if (length != other.length) return length < other.length; - return pointers[length - 1] < other.pointers[length - 1]; - } - - void ZeroRemaining() { - for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) - *i = 0; - } - - unsigned char length; - uint64_t pointers[kMaxOrder - 1]; -}; - -inline size_t hash_value(const Left &left) { - return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]); -} - -struct ChartState { - bool operator==(const ChartState &other) { - return (left == other.left) && (right == other.right) && (full == other.full); - } - - int Compare(const ChartState &other) const { - int lres = left.Compare(other.left); - if (lres) return lres; - int rres = right.Compare(other.right); - if (rres) return rres; - return (int)full - (int)other.full; - } - - bool operator<(const ChartState &other) const { - return Compare(other) == -1; - } - - void ZeroRemaining() { - left.ZeroRemaining(); - right.ZeroRemaining(); - } - - Left left; - bool full; - State right; -}; - -inline size_t hash_value(const ChartState &state) { - size_t hashes[2]; - hashes[0] = hash_value(state.left); - hashes[1] = hash_value(state.right); - return util::MurmurHashNative(hashes, sizeof(size_t) * 2, state.full); -} - template <class M> class RuleScore { public: explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { @@ -131,29 +65,30 @@ template <class M> class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); - prob_ += ret.prob; - if (left_done_) return; + if (left_done_) { prob_ += ret.prob; return; } if (ret.independent_left) { + prob_ += ret.prob; left_done_ = true; return; } out_.left.pointers[out_.left.length++] = ret.extend_left; + prob_ += ret.rest; if (out_.right.length != copy.length + 1) left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. - void BeginNonTerminal(const ChartState &in, float prob) { + void BeginNonTerminal(const ChartState &in, float prob = 0.0) { prob_ = prob; out_ = in; - left_done_ = in.full; + left_done_ = in.left.full; } - void NonTerminal(const ChartState &in, float prob) { + void NonTerminal(const ChartState &in, float prob = 0.0) { prob_ += prob; if (!in.left.length) { - if (in.full) { + if (in.left.full) { for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; left_done_ = true; out_.right = in.right; @@ -163,12 +98,15 @@ template <class M> class RuleScore { if (!out_.right.length) { out_.right = in.right; - if (left_done_) return; + if (left_done_) { + prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1); + return; + } if (out_.left.length) { left_done_ = true; } else { out_.left = in.left; - left_done_ = in.full; + left_done_ = in.left.full; } return; } @@ -186,7 +124,7 @@ template <class M> class RuleScore { std::swap(back, back2); } - if (in.full) { + if (in.left.full) { for (const float *i = back; i != back + next_use; ++i) prob_ += *i; left_done_ = true; out_.right = in.right; @@ -213,10 +151,17 @@ template <class M> class RuleScore { float Finish() { // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. - out_.full = left_done_ || (out_.left.length == model_.Order() - 1); + out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1); return prob_; } + void Reset() { + prob_ = 0.0; + left_done_ = false; + out_.left.length = 0; + out_.right.length = 0; + } + private: bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { ProcessRet(model_.ExtendLeft( @@ -228,8 +173,9 @@ template <class M> class RuleScore { if (next_use != out_.right.length) { left_done_ = true; if (!next_use) { - out_.right = in.right; // Early exit. + out_.right = in.right; + prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1); return true; } } @@ -238,13 +184,17 @@ template <class M> class RuleScore { } void ProcessRet(const FullScoreReturn &ret) { - prob_ += ret.prob; - if (left_done_) return; + if (left_done_) { + prob_ += ret.prob; + return; + } if (ret.independent_left) { + prob_ += ret.prob; left_done_ = true; return; } out_.left.pointers[out_.left.length++] = ret.extend_left; + prob_ += ret.rest; } const M &model_; |