summaryrefslogtreecommitdiff
path: root/klm/util/bit_packing.hh
blob: 73a5cb2268f5b6d3b19b28733f89efe52d49eeac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#ifndef UTIL_BIT_PACKING__
#define UTIL_BIT_PACKING__

/* Bit-level packing routines 
 *
 * WARNING WARNING WARNING:
 * The write functions assume that memory is zero initially.  This makes them
 * faster and is the appropriate case for mmapped language model construction.
 * These routines assume that unaligned access to uint64_t is fast.  This is
 * the case on x86_64.  I'm not sure how fast unaligned 64-bit access is on
 * x86 but my target audience is large language models for which 64-bit is
 * necessary.  
 *
 * Call the BitPackingSanity function to sanity check these routines.  Calling
 * it once suffices, but it may also be called multiple times when tracking
 * whether it has already run is inconvenient.
 *
 * ARM and MinGW ports contributed by Hideo Okuma and Tomoyuki Yoshimura at
 * NICT.
 */

#include <assert.h>
#ifdef __APPLE__
#include <architecture/byte_order.h>
#elif __linux__
#include <endian.h>
#elif !defined(_WIN32) && !defined(_WIN64)
#include <arpa/nameser_compat.h>
#endif 

#include <stdint.h>

#include <string.h>

namespace util {

// Number of bits to shift right so that a field starting `bit` bits into the
// first byte of a 64-bit word (loaded from that byte) lands at the word's low
// end.  `length` is the field width and only matters on big-endian hosts.
//
// Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct.
// NOTE(review): if neither BYTE_ORDER nor LITTLE_ENDIAN is defined (e.g. the
// _WIN32 path above includes no endian header), both sides of the first test
// evaluate to 0 and the little-endian branch is selected.
inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {
#if BYTE_ORDER == LITTLE_ENDIAN
  // The field's lowest bit sits `bit` positions above the word's lowest bit,
  // independent of the field width.
  (void)length;
  return bit;
#elif BYTE_ORDER == BIG_ENDIAN
  // The first byte holds the most significant bits, so the field must be moved
  // down past everything after it: 64 total minus the field minus its offset.
  return 64 - length - bit;
#else
#error "Bit packing code isn't written for your byte order."
#endif
}

// Load the 8 bytes starting at the byte that contains bit `bit_off` of `base`,
// as a native-endian uint64_t.
//
// The address is generally not 8-byte aligned, and dereferencing a
// reinterpret_cast'd uint64_t pointer there is undefined behavior in C++
// (misaligned access and a strict-aliasing violation).  memcpy is the portable
// way to express an unaligned load; optimizing compilers lower this fixed-size
// copy to a single load on targets that allow unaligned access.  The ARM port
// already relied on this form.
inline uint64_t ReadOff(const void *base, uint64_t bit_off) {
  const uint8_t *base_off = reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3);
  uint64_t value64;
  memcpy(&value64, base_off, sizeof(value64));
  return value64;
}

/* Read an unsigned integer of at most 57 bits, stored in its least significant
 * bits, that starts at bit offset bit_off from base.
 * Requires mask == (1ULL << length) - 1 with length <= 57: the value is taken
 * from a 64-bit load starting at byte bit_off >> 3 and shifted by up to 7 bits,
 * which leaves 57 usable bits.  The 8 bytes from byte bit_off >> 3 are read.
 */
inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, uint64_t mask) {
  const uint64_t word = ReadOff(base, bit_off);
  const uint8_t shift = BitPackShift(bit_off & 7, length);
  return (word >> shift) & mask;
}
/* OR `value` into the `length` bits starting at bit offset bit_off from base.
 * Assumes value < (1ULL << length) and length <= 57.
 * Assumes the memory is zero initially: the bits are ORed in, not overwritten.
 * Reads and rewrites the 8 bytes starting at byte bit_off >> 3.
 *
 * memcpy is used for the load and store because the address is generally not
 * 8-byte aligned; a reinterpret_cast'd uint64_t dereference there is undefined
 * behavior.  Optimizing compilers turn each copy into a single load/store where
 * unaligned access is supported.
 */
inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) {
  uint8_t *base_off = reinterpret_cast<uint8_t*>(base) + (bit_off >> 3);
  uint64_t value64;
  memcpy(&value64, base_off, sizeof(value64));
  value64 |= (value << BitPackShift(bit_off & 7, length));
  memcpy(base_off, &value64, sizeof(value64));
}

/* Read an unsigned integer of at most 25 bits starting at bit offset bit_off
 * from base, using a 32-bit load (25 bits + up to 7 bits of in-byte offset).
 * Requires mask == (1U << length) - 1 with length <= 25.
 * Reads the 4 bytes starting at byte bit_off >> 3.
 */
inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) {
  // memcpy instead of a reinterpret_cast dereference: the address is generally
  // misaligned, which makes a direct uint32_t access undefined behavior.
  const uint8_t *base_off = reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3);
  uint32_t value32;
  memcpy(&value32, base_off, sizeof(value32));
  // BitPackShift computes shifts for a 64-bit word.  On big-endian hosts that
  // is 64 - length - bit, which is >= 32 here and would be an out-of-range
  // shift of a uint32_t.  Passing length + 32 yields the 32-bit-word shift,
  // 32 - length - bit.  Little-endian hosts ignore the length, so the shift
  // there is still bit_off & 7.
  return (value32 >> BitPackShift(bit_off & 7, static_cast<uint8_t>(length + 32))) & mask;
}

/* OR `value` into the `length` bits starting at bit offset bit_off from base,
 * using a 32-bit read-modify-write.  Assumes value < (1U << length),
 * length <= 25, and that the target bits are zero initially (they are ORed in).
 * Reads and rewrites the 4 bytes starting at byte bit_off >> 3.
 */
inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) {
  // memcpy instead of a reinterpret_cast dereference: the address is generally
  // misaligned, which makes a direct uint32_t access undefined behavior.
  uint8_t *base_off = reinterpret_cast<uint8_t*>(base) + (bit_off >> 3);
  uint32_t value32;
  memcpy(&value32, base_off, sizeof(value32));
  // BitPackShift computes shifts for a 64-bit word; on big-endian hosts the
  // 64-bit shift (64 - length - bit) would be >= 32 and out of range for a
  // uint32_t.  Passing length + 32 gives the 32-bit-word shift
  // (32 - length - bit).  Little-endian hosts ignore the length argument.
  value32 |= (value << BitPackShift(bit_off & 7, static_cast<uint8_t>(length + 32)));
  memcpy(base_off, &value32, sizeof(value32));
}

// Views the same 32 bits either as a float or as an unsigned integer; used to
// move float bit patterns in and out of the packed storage.
union FloatEnc {
  float f;
  uint32_t i;
};

inline float ReadFloat32(const void *base, uint64_t bit_off) {
  FloatEnc encoded;
  encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32);
  return encoded.f;
}
inline void WriteFloat32(void *base, uint64_t bit_off, float value) {
  FloatEnc encoded;
  encoded.f = value;
  WriteInt57(base, bit_off, 32, encoded.i);
}

// Most significant bit of a 32-bit word: the sign bit of a 32-bit float.
const uint32_t kSignBit = static_cast<uint32_t>(1) << 31;

inline void SetSign(float &to) {
  FloatEnc enc;
  enc.f = to;
  enc.i |= kSignBit;
  to = enc.f;
}

inline void UnsetSign(float &to) {
  FloatEnc enc;
  enc.f = to;
  enc.i &= ~kSignBit;
  to = enc.f;
}

inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) {
  FloatEnc encoded;
  encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31);
  // Sign bit set means negative.  
  encoded.i |= kSignBit;
  return encoded.f;
}
inline void WriteNonPositiveFloat31(void *base, uint64_t bit_off, float value) {
  FloatEnc encoded;
  encoded.f = value;
  encoded.i &= ~kSignBit;
  WriteInt57(base, bit_off, 31, encoded.i);
}

// Sanity check for the routines in this header; defined out of line.  Calling
// it once suffices, but it may be called more than once.
// NOTE(review): how a failure is reported is decided by the definition, which
// is not visible here.
void BitPackingSanity();

// Return bits required to store integers up to max_value.  Not the most
// efficient implementation, but this is only called a few times to size tries.
uint8_t RequiredBits(uint64_t max_value);

// A field width in bits plus the matching low-order mask, in the
// (length, mask == (1ULL << length) - 1) form that ReadInt57 takes.
struct BitsMask {
  // Width from RequiredBits(max_value), with its mask.
  static BitsMask ByMax(uint64_t max_value) {
    BitsMask ret;
    ret.FromMax(max_value);
    return ret;
  }
  // Explicit width, with its mask.
  static BitsMask ByBits(uint8_t bits) {
    BitsMask ret;
    ret.bits = bits;
    ret.mask = MaskFor(bits);
    return ret;
  }
  void FromMax(uint64_t max_value) {
    bits = RequiredBits(max_value);
    mask = MaskFor(bits);
  }
  // Mask with the low `bits` bits set.  Shifting a 64-bit value by 64 or more
  // is undefined behavior, so widths of 64 and above saturate to all ones
  // instead of computing (1ULL << bits) - 1.
  static uint64_t MaskFor(uint8_t bits) {
    return bits >= 64 ? ~0ULL : (1ULL << bits) - 1;
  }
  uint8_t bits;
  uint64_t mask;
};

} // namespace util

#endif // UTIL_BIT_PACKING__