diff options
Diffstat (limited to 'klm/util/bit_packing.hh')
| -rw-r--r-- | klm/util/bit_packing.hh | 66 | 
1 files changed, 50 insertions, 16 deletions
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 33266b94..73a5cb22 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -1,33 +1,37 @@  #ifndef UTIL_BIT_PACKING__  #define UTIL_BIT_PACKING__ -/* Bit-level packing routines */ +/* Bit-level packing routines  + * + * WARNING WARNING WARNING: + * The write functions assume that memory is zero initially.  This makes them + * faster and is the appropriate case for mmapped language model construction. + * These routines assume that unaligned access to uint64_t is fast.  This is + * the case on x86_64.  I'm not sure how fast unaligned 64-bit access is on + * x86 but my target audience is large language models for which 64-bit is + * necessary.   + * + * Call the BitPackingSanity function to sanity check.  Calling once suffices, + * but it may be called multiple times when that's inconvenient.   + * + * ARM and MinGW ports contributed by Hideo Okuma and Tomoyuki Yoshimura at + * NICT. + */  #include <assert.h>  #ifdef __APPLE__  #include <architecture/byte_order.h>  #elif __linux__  #include <endian.h> -#else +#elif !defined(_WIN32) && !defined(_WIN64)  #include <arpa/nameser_compat.h>  #endif  -#include <inttypes.h> - -namespace util { +#include <stdint.h> -/* WARNING WARNING WARNING: - * The write functions assume that memory is zero initially.  This makes them - * faster and is the appropriate case for mmapped language model construction. - * These routines assume that unaligned access to uint64_t is fast and that - * storage is little endian.  This is the case on x86_64.  I'm not sure how  - * fast unaligned 64-bit access is on x86 but my target audience is large - * language models for which 64-bit is necessary.   - * - * Call the BitPackingSanity function to sanity check.  Calling once suffices, - * but it may be called multiple times when that's inconvenient.   - */ +#include <string.h> +namespace util {  // Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct.    #if BYTE_ORDER == LITTLE_ENDIAN @@ -43,7 +47,14 @@ inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {  #endif  inline uint64_t ReadOff(const void *base, uint64_t bit_off) { +#if defined(__arm) || defined(__arm__) +  const uint8_t *base_off = reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3); +  uint64_t value64; +  memcpy(&value64, base_off, sizeof(value64)); +  return value64; +#else    return *reinterpret_cast<const uint64_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)); +#endif  }  /* Pack integers up to 57 bits using their least significant digits.  @@ -57,18 +68,41 @@ inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, ui   * Assumes the memory is zero initially.    */  inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) { +#if defined(__arm) || defined(__arm__) +  uint8_t *base_off = reinterpret_cast<uint8_t*>(base) + (bit_off >> 3); +  uint64_t value64; +  memcpy(&value64, base_off, sizeof(value64)); +  value64 |= (value << BitPackShift(bit_off & 7, length)); +  memcpy(base_off, &value64, sizeof(value64)); +#else    *reinterpret_cast<uint64_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |=       (value << BitPackShift(bit_off & 7, length)); +#endif  }  /* Same caveats as above, but for a 25 bit limit. */  inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) { +#if defined(__arm) || defined(__arm__) +  const uint8_t *base_off = reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3); +  uint32_t value32; +  memcpy(&value32, base_off, sizeof(value32)); +  return (value32 >> BitPackShift(bit_off & 7, length)) & mask; +#else    return (*reinterpret_cast<const uint32_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask; +#endif  }  inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) { +#if defined(__arm) || defined(__arm__) +  uint8_t *base_off = reinterpret_cast<uint8_t*>(base) + (bit_off >> 3); +  uint32_t value32; +  memcpy(&value32, base_off, sizeof(value32)); +  value32 |= (value << BitPackShift(bit_off & 7, length)); +  memcpy(base_off, &value32, sizeof(value32)); +#else    *reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |=       (value << BitPackShift(bit_off & 7, length)); +#endif  }  typedef union { float f; uint32_t i; } FloatEnc;  | 
