summaryrefslogtreecommitdiff
path: root/klm/lm/state.hh
diff options
context:
space:
mode:
authorPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-05-31 13:57:24 +0200
committerPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-05-31 13:57:24 +0200
commitf1ba05780db1705493d9afb562332498b93d26f1 (patch)
treefb429a657ba97f33e8140742de9bc74d9fc88e75 /klm/lm/state.hh
parentaadabfdf37dfd451485277cb77fad02f77b361c6 (diff)
parent317d650f6cb1e24ac6f3be6f7bf9d4246a59e0e5 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/lm/state.hh')
-rw-r--r--klm/lm/state.hh123
1 files changed, 123 insertions, 0 deletions
diff --git a/klm/lm/state.hh b/klm/lm/state.hh
new file mode 100644
index 00000000..c7438414
--- /dev/null
+++ b/klm/lm/state.hh
@@ -0,0 +1,123 @@
+#ifndef LM_STATE__
+#define LM_STATE__
+
+#include "lm/max_order.hh"
+#include "lm/word_index.hh"
+#include "util/murmur_hash.hh"
+
+#include <string.h>
+
+namespace lm {
+namespace ngram {
+
+// This is a POD but if you want memcmp to return the same as operator==, call
+// ZeroRemaining first.
+class State {
+ public:
+ bool operator==(const State &other) const {
+ if (length != other.length) return false;
+ return !memcmp(words, other.words, length * sizeof(WordIndex));
+ }
+
+ // Three way comparison function.
+ int Compare(const State &other) const {
+ if (length != other.length) return length < other.length ? -1 : 1;
+ return memcmp(words, other.words, length * sizeof(WordIndex));
+ }
+
+ bool operator<(const State &other) const {
+ if (length != other.length) return length < other.length;
+ return memcmp(words, other.words, length * sizeof(WordIndex)) < 0;
+ }
+
+ // Call this before using raw memcmp.
+ void ZeroRemaining() {
+ for (unsigned char i = length; i < kMaxOrder - 1; ++i) {
+ words[i] = 0;
+ backoff[i] = 0.0;
+ }
+ }
+
+ unsigned char Length() const { return length; }
+
+ // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
+ // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
+ WordIndex words[kMaxOrder - 1];
+ float backoff[kMaxOrder - 1];
+ unsigned char length;
+};
+
+inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
+ return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed);
+}
+
+struct Left {
+ bool operator==(const Left &other) const {
+ return
+ (length == other.length) &&
+ pointers[length - 1] == other.pointers[length - 1] &&
+ full == other.full;
+ }
+
+ int Compare(const Left &other) const {
+ if (length < other.length) return -1;
+ if (length > other.length) return 1;
+ if (pointers[length - 1] > other.pointers[length - 1]) return 1;
+ if (pointers[length - 1] < other.pointers[length - 1]) return -1;
+ return (int)full - (int)other.full;
+ }
+
+ bool operator<(const Left &other) const {
+ return Compare(other) == -1;
+ }
+
+ void ZeroRemaining() {
+ for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)
+ *i = 0;
+ }
+
+ uint64_t pointers[kMaxOrder - 1];
+ unsigned char length;
+ bool full;
+};
+
+inline uint64_t hash_value(const Left &left) {
+ unsigned char add[2];
+ add[0] = left.length;
+ add[1] = left.full;
+ return util::MurmurHashNative(add, 2, left.length ? left.pointers[left.length - 1] : 0);
+}
+
+struct ChartState {
+ bool operator==(const ChartState &other) {
+ return (right == other.right) && (left == other.left);
+ }
+
+ int Compare(const ChartState &other) const {
+ int lres = left.Compare(other.left);
+ if (lres) return lres;
+ return right.Compare(other.right);
+ }
+
+ bool operator<(const ChartState &other) const {
+ return Compare(other) == -1;
+ }
+
+ void ZeroRemaining() {
+ left.ZeroRemaining();
+ right.ZeroRemaining();
+ }
+
+ Left left;
+ State right;
+};
+
+inline uint64_t hash_value(const ChartState &state) {
+ return hash_value(state.right, hash_value(state.left));
+}
+
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_STATE__