summaryrefslogtreecommitdiff
path: root/klm/lm/state.hh
blob: c74384143657f578ba1e9f34c5e910091f726b2a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#ifndef LM_STATE__
#define LM_STATE__

#include "lm/max_order.hh"
#include "lm/word_index.hh"
#include "util/murmur_hash.hh"

#include <string.h>

namespace lm {
namespace ngram {

// Right-hand n-gram state: up to kMaxOrder - 1 word indices plus a parallel
// backoff array.  Only the first `length` words take part in comparison and
// hashing.  The type is a POD; slots past `length` are left untouched, so call
// ZeroRemaining() first if a raw memcmp over the whole object must agree with
// operator==.
class State {
  public:
    // Equal when both states hold the same number of words and those words
    // match byte for byte.  Backoffs are not compared.
    bool operator==(const State &other) const {
      return length == other.length &&
             memcmp(words, other.words, length * sizeof(WordIndex)) == 0;
    }

    // Three-way comparison.  A shorter state orders first (-1 / 1).  Between
    // equal-length states, the raw memcmp result of the words is returned
    // unnormalized, so callers must test the sign rather than compare to -1.
    int Compare(const State &other) const {
      if (length == other.length)
        return memcmp(words, other.words, length * sizeof(WordIndex));
      return (length < other.length) ? -1 : 1;
    }

    // Ordering consistent with Compare: by length, then by memcmp of words.
    bool operator<(const State &other) const {
      if (length == other.length)
        return memcmp(words, other.words, length * sizeof(WordIndex)) < 0;
      return length < other.length;
    }

    // Clears every word and backoff slot from index `length` to the end of the
    // arrays, making the unused tail deterministic for raw memcmp.
    void ZeroRemaining() {
      for (unsigned char slot = length; slot < kMaxOrder - 1; ++slot) {
        backoff[slot] = 0.0;
        words[slot] = 0;
      }
    }

    // Number of meaningful entries in `words`.
    unsigned char Length() const { return this->length; }

    // Members are public (rather than hidden behind accessors) so that
    // FullState still qualifies as a POD; normal callers should not touch them.
    // The declaration order is chosen to minimize struct size when WordIndex is
    // 64-bit, float is 32-bit, and 64-bit integers are 64-bit aligned.
    WordIndex words[kMaxOrder - 1];
    float backoff[kMaxOrder - 1];
    unsigned char length;
};

inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
  return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed);
}

// Left-hand chart state.  Equality, ordering, and hashing look only at
// `length`, `full`, and (when length > 0) the last used pointer,
// pointers[length - 1].  A length of 0 is a valid state, so none of the
// comparison members may index `pointers` in that case.
struct Left {
  // Equal when length and full match and, for non-empty states, the last
  // pointer matches.  `full` is always compared so that equal values also
  // produce equal hash_value() results (the hash includes `full`).
  bool operator==(const Left &other) const {
    return
      (length == other.length) &&
      (full == other.full) &&
      // Guard: with length == 0, pointers[length - 1] would be pointers[-1].
      (!length || pointers[length - 1] == other.pointers[length - 1]);
  }

  // Three-way comparison returning -1, 0, or 1.  The order is by length, then
  // by the last pointer (skipped when empty), then by the full flag.
  int Compare(const Left &other) const {
    if (length < other.length) return -1;
    if (length > other.length) return 1;
    if (length) {
      if (pointers[length - 1] > other.pointers[length - 1]) return 1;
      if (pointers[length - 1] < other.pointers[length - 1]) return -1;
    }
    return static_cast<int>(full) - static_cast<int>(other.full);
  }

  // Test the sign rather than a specific value, so this stays correct if
  // Compare ever returns magnitudes other than 1.
  bool operator<(const Left &other) const {
    return Compare(other) < 0;
  }

  // Zeroes the unused pointer slots past `length` so raw memory comparisons
  // are deterministic.
  void ZeroRemaining() {
    for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)
      *i = 0;
  }

  uint64_t pointers[kMaxOrder - 1];
  unsigned char length;
  bool full;
};

// Hashes the two bytes {length, full}, seeded with the last used pointer.
// An empty state (length == 0) uses seed 0 instead of indexing `pointers`.
inline uint64_t hash_value(const Left &left) {
  const uint64_t seed = left.length ? left.pointers[left.length - 1] : 0;
  unsigned char summary[2];
  summary[0] = left.length;
  summary[1] = left.full;
  return util::MurmurHashNative(summary, sizeof(summary), seed);
}

// Chart state pairing a left-hand state with a right-hand n-gram State.
struct ChartState {
  // Equal when both halves are equal.  Declared const so that const
  // ChartState objects (e.g. container keys) can be compared.
  bool operator==(const ChartState &other) const {
    return (right == other.right) && (left == other.left);
  }

  // Three-way comparison: `left` first, then `right`.  The result's sign is
  // meaningful, but its magnitude is not, because State::Compare can return a
  // raw memcmp value.
  int Compare(const ChartState &other) const {
    int lres = left.Compare(other.left);
    if (lres) return lres;
    return right.Compare(other.right);
  }

  // Must test the sign: State::Compare may return any negative value (from
  // memcmp), so comparing against -1 would misorder states.
  bool operator<(const ChartState &other) const {
    return Compare(other) < 0;
  }

  // Zeroes the unused tails of both halves.
  void ZeroRemaining() {
    left.ZeroRemaining();
    right.ZeroRemaining();
  }

  Left left;
  State right;
};

// Chains the two component hashes: the hash of `left` becomes the seed for
// hashing `right`.
inline uint64_t hash_value(const ChartState &state) {
  const uint64_t left_hash = hash_value(state.left);
  return hash_value(state.right, left_hash);
}


} // namespace ngram
} // namespace lm

#endif // LM_STATE__