diff options
Diffstat (limited to 'klm/util/bit_packing.hh')
-rw-r--r-- | klm/util/bit_packing.hh | 48 |
1 files changed, 30 insertions, 18 deletions
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 422ed873..0fd39d7f 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -6,56 +6,68 @@ #include <assert.h> #ifdef __APPLE__ #include <architecture/byte_order.h> -#else +#elif __linux__ #include <endian.h> -#endif +#else +#include <arpa/nameser_compat.h> +#endif #include <inttypes.h> -#if __BYTE_ORDER != __LITTLE_ENDIAN -#error The bit aligned storage functions assume little endian architecture -#endif - namespace util { /* WARNING WARNING WARNING: * The write functions assume that memory is zero initially. This makes them * faster and is the appropriate case for mmapped language model construction. * These routines assume that unaligned access to uint64_t is fast and that - * storage is little endian. This is the case on x86_64. It may not be the - * case on 32-bit x86 but my target audience is large language models for which - * 64-bit is necessary. + * storage is little endian. This is the case on x86_64. I'm not sure how + * fast unaligned 64-bit access is on x86 but my target audience is large + * language models for which 64-bit is necessary. + * + * Call the BitPackingSanity function to sanity check. Calling once suffices, + * but it may be called multiple times when that's inconvenient. */ +inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { +// Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct. +#if BYTE_ORDER == LITTLE_ENDIAN + return bit; +#elif BYTE_ORDER == BIG_ENDIAN + return 64 - length - bit; +#else +#error "Bit packing code isn't written for your byte order." +#endif +} + /* Pack integers up to 57 bits using their least significant digits. * The length is specified using mask: * Assumes mask == (1 << length) - 1 where length <= 57. */ -inline uint64_t ReadInt57(const void *base, uint8_t bit, uint64_t mask) { - return (*reinterpret_cast<const uint64_t*>(base) >> bit) & mask; +inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) { + return (*reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, length)) & mask; } -/* Assumes value <= mask and mask == (1 << length) - 1 where length <= 57. +/* Assumes value < (1 << length) and length <= 57. * Assumes the memory is zero initially. */ -inline void WriteInt57(void *base, uint8_t bit, uint64_t value) { - *reinterpret_cast<uint64_t*>(base) |= (value << bit); +inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) { + *reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length)); } namespace detail { typedef union { float f; uint32_t i; } FloatEnc; } inline float ReadFloat32(const void *base, uint8_t bit) { detail::FloatEnc encoded; - encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit; + encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32); return encoded.f; } inline void WriteFloat32(void *base, uint8_t bit, float value) { detail::FloatEnc encoded; encoded.f = value; - WriteInt57(base, bit, encoded.i); + WriteInt57(base, bit, 32, encoded.i); } inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { detail::FloatEnc encoded; - encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit; + encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31); // Sign bit set means negative. encoded.i |= 0x80000000; return encoded.f; @@ -65,7 +77,7 @@ inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { detail::FloatEnc encoded; encoded.f = value; encoded.i &= ~0x80000000; - WriteInt57(base, bit, encoded.i); + WriteInt57(base, bit, 31, encoded.i); } void BitPackingSanity(); |