Diffstat (limited to 'klm/util/bit_packing.hh')
-rw-r--r--  klm/util/bit_packing.hh  48
1 file changed, 30 insertions, 18 deletions
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh
index 422ed873..0fd39d7f 100644
--- a/klm/util/bit_packing.hh
+++ b/klm/util/bit_packing.hh
@@ -6,56 +6,68 @@
#include <assert.h>
#ifdef __APPLE__
#include <architecture/byte_order.h>
-#else
+#elif __linux__
#include <endian.h>
-#endif
+#else
+#include <arpa/nameser_compat.h>
+#endif
#include <inttypes.h>
-#if __BYTE_ORDER != __LITTLE_ENDIAN
-#error The bit aligned storage functions assume little endian architecture
-#endif
-
namespace util {
/* WARNING WARNING WARNING:
* The write functions assume that memory is zero initially. This makes them
* faster and is the appropriate case for mmapped language model construction.
* These routines assume that unaligned access to uint64_t is fast and that
- * storage is little endian. This is the case on x86_64. It may not be the
- * case on 32-bit x86 but my target audience is large language models for which
- * 64-bit is necessary.
+ * storage is little endian. This is the case on x86_64. I'm not sure how
+ * fast unaligned 64-bit access is on x86 but my target audience is large
+ * language models for which 64-bit is necessary.
+ *
+ * Call the BitPackingSanity function to sanity check. Calling once suffices,
+ * but it may be called multiple times when that's inconvenient.
*/
+inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {
+// Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct.
+#if BYTE_ORDER == LITTLE_ENDIAN
+ return bit;
+#elif BYTE_ORDER == BIG_ENDIAN
+ return 64 - length - bit;
+#else
+#error "Bit packing code isn't written for your byte order."
+#endif
+}
+
/* Pack integers up to 57 bits using their least significant digits.
* The length is specified using mask:
* Assumes mask == (1 << length) - 1 where length <= 57.
*/
-inline uint64_t ReadInt57(const void *base, uint8_t bit, uint64_t mask) {
- return (*reinterpret_cast<const uint64_t*>(base) >> bit) & mask;
+inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) {
+ return (*reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, length)) & mask;
}
-/* Assumes value <= mask and mask == (1 << length) - 1 where length <= 57.
+/* Assumes value < (1 << length) and length <= 57.
* Assumes the memory is zero initially.
*/
-inline void WriteInt57(void *base, uint8_t bit, uint64_t value) {
- *reinterpret_cast<uint64_t*>(base) |= (value << bit);
+inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) {
+ *reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length));
}
namespace detail { typedef union { float f; uint32_t i; } FloatEnc; }
inline float ReadFloat32(const void *base, uint8_t bit) {
detail::FloatEnc encoded;
- encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit;
+ encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32);
return encoded.f;
}
inline void WriteFloat32(void *base, uint8_t bit, float value) {
detail::FloatEnc encoded;
encoded.f = value;
- WriteInt57(base, bit, encoded.i);
+ WriteInt57(base, bit, 32, encoded.i);
}
inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) {
detail::FloatEnc encoded;
- encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit;
+ encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31);
// Sign bit set means negative.
encoded.i |= 0x80000000;
return encoded.f;
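
For illustration only (not part of the diff): a standalone sketch of the little-endian read/write path that the patched ReadInt57/WriteInt57 reduce to when BitPackShift(bit, length) simply returns bit. ReadSketch/WriteSketch are hypothetical names, and std::memcpy stands in for the header's unaligned uint64_t cast.

// Standalone sketch, not the library code.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>

// Read `length` bits starting `bit` bits into the bytes at `base`,
// where mask == (1 << length) - 1 and length <= 57.
uint64_t ReadSketch(const void *base, uint8_t bit, uint64_t mask) {
  uint64_t word;
  std::memcpy(&word, base, sizeof(word));
  return (word >> bit) & mask;
}

// Assumes value < (1 << length) and that the target bits are still zero,
// matching the zero-initialization requirement in the header's warning.
void WriteSketch(void *base, uint8_t bit, uint64_t value) {
  uint64_t word;
  std::memcpy(&word, base, sizeof(word));
  word |= value << bit;
  std::memcpy(base, &word, sizeof(word));
}

int main() {
  unsigned char buf[16] = {0};                 // zero-initialized, as assumed
  const uint64_t mask = (1ULL << 20) - 1;      // a 20-bit field
  WriteSketch(buf + 1, 5, 0xABCDEu);           // byte offset 1, bit offset 5
  assert(ReadSketch(buf + 1, 5, mask) == 0xABCDEu);
  std::cout << std::hex << ReadSketch(buf + 1, 5, mask) << "\n";  // prints abcde
  return 0;
}
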
@@ -65,7 +77,7 @@ inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) {
detail::FloatEnc encoded;
encoded.f = value;
encoded.i &= ~0x80000000;
- WriteInt57(base, bit, encoded.i);
+ WriteInt57(base, bit, 31, encoded.i);
}
void BitPackingSanity();
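
For illustration only (not part of the diff): the second hunk finishes the non-positive float pair, whose trick, per the comments above, is that a non-positive float always has its sign bit set, so the writer clears that bit (storing 31 bits) and the reader ORs it back in. A standalone sketch under that assumption, with hypothetical names StripSign/RestoreSign:

// Standalone sketch, not the library code.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>

uint32_t StripSign(float value) {          // what the writer stores
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits & ~0x80000000u;
}

float RestoreSign(uint32_t stored) {       // what the reader recovers
  stored |= 0x80000000u;                   // sign bit set means negative
  float value;
  std::memcpy(&value, &stored, sizeof(value));
  return value;
}

int main() {
  const float prob = -4.25f;               // e.g. a log probability, always <= 0
  assert(RestoreSign(StripSign(prob)) == prob);
  std::cout << RestoreSign(StripSign(prob)) << "\n";  // prints -4.25
  return 0;
}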