diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-05-16 13:24:08 -0700 |
---|---|---|
committer | Chris Dyer <cdyer@cab.ark.cs.cmu.edu> | 2012-05-26 22:59:54 -0400 |
commit | 149232c38eec558ddb1097698d1570aacb67b59f (patch) | |
tree | 5860b4d6f681eeb04a1020cbb2fe7e6ac394af99 /klm/lm/state.hh | |
parent | 01ecc09f8e3a82c32bf7dd2f90c12554becea71d (diff) |
Big kenlm change includes lower order models for probing only. And other stuff.
Diffstat (limited to 'klm/lm/state.hh')
-rw-r--r-- | klm/lm/state.hh | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/klm/lm/state.hh b/klm/lm/state.hh new file mode 100644 index 00000000..c7438414 --- /dev/null +++ b/klm/lm/state.hh @@ -0,0 +1,123 @@ +#ifndef LM_STATE__ +#define LM_STATE__ + +#include "lm/max_order.hh" +#include "lm/word_index.hh" +#include "util/murmur_hash.hh" + +#include <string.h> + +namespace lm { +namespace ngram { + +// This is a POD but if you want memcmp to return the same as operator==, call +// ZeroRemaining first. +class State { + public: + bool operator==(const State &other) const { + if (length != other.length) return false; + return !memcmp(words, other.words, length * sizeof(WordIndex)); + } + + // Three way comparison function. + int Compare(const State &other) const { + if (length != other.length) return length < other.length ? -1 : 1; + return memcmp(words, other.words, length * sizeof(WordIndex)); + } + + bool operator<(const State &other) const { + if (length != other.length) return length < other.length; + return memcmp(words, other.words, length * sizeof(WordIndex)) < 0; + } + + // Call this before using raw memcmp. + void ZeroRemaining() { + for (unsigned char i = length; i < kMaxOrder - 1; ++i) { + words[i] = 0; + backoff[i] = 0.0; + } + } + + unsigned char Length() const { return length; } + + // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. + // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. + WordIndex words[kMaxOrder - 1]; + float backoff[kMaxOrder - 1]; + unsigned char length; +}; + +inline uint64_t hash_value(const State &state, uint64_t seed = 0) { + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed); +} + +struct Left { + bool operator==(const Left &other) const { + return + (length == other.length) && + pointers[length - 1] == other.pointers[length - 1] && + full == other.full; + } + + int Compare(const Left &other) const { + if (length < other.length) return -1; + if (length > other.length) return 1; + if (pointers[length - 1] > other.pointers[length - 1]) return 1; + if (pointers[length - 1] < other.pointers[length - 1]) return -1; + return (int)full - (int)other.full; + } + + bool operator<(const Left &other) const { + return Compare(other) == -1; + } + + void ZeroRemaining() { + for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) + *i = 0; + } + + uint64_t pointers[kMaxOrder - 1]; + unsigned char length; + bool full; +}; + +inline uint64_t hash_value(const Left &left) { + unsigned char add[2]; + add[0] = left.length; + add[1] = left.full; + return util::MurmurHashNative(add, 2, left.length ? left.pointers[left.length - 1] : 0); +} + +struct ChartState { + bool operator==(const ChartState &other) { + return (right == other.right) && (left == other.left); + } + + int Compare(const ChartState &other) const { + int lres = left.Compare(other.left); + if (lres) return lres; + return right.Compare(other.right); + } + + bool operator<(const ChartState &other) const { + return Compare(other) == -1; + } + + void ZeroRemaining() { + left.ZeroRemaining(); + right.ZeroRemaining(); + } + + Left left; + State right; +}; + +inline uint64_t hash_value(const ChartState &state) { + return hash_value(state.right, hash_value(state.left)); +} + + +} // namespace ngram +} // namespace lm + +#endif // LM_STATE__ |