diff options
Diffstat (limited to 'klm/util')
 klm/util/bit_packing.cc      |   4 +-
 klm/util/bit_packing.hh      |  39 +++++--
 klm/util/bit_packing_test.cc |  25 +++--
 klm/util/sorted_uniform.hh   | 120 ++++++++++++-------
 4 files changed, 133 insertions(+), 55 deletions(-)
| diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc index 681da5f2..41999b72 100644 --- a/klm/util/bit_packing.cc +++ b/klm/util/bit_packing.cc @@ -28,10 +28,10 @@ void BitPackingSanity() {    memset(mem, 0, sizeof(mem));    const uint64_t test57 = 0x123456789abcdefULL;    for (uint64_t b = 0; b < 57 * 8; b += 57) { -    WriteInt57(mem + b / 8, b % 8, 57, test57); +    WriteInt57(mem, b, 57, test57);    }    for (uint64_t b = 0; b < 57 * 8; b += 57) { -    if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)) +    if (test57 != ReadInt57(mem, b, 57, (1ULL << 57) - 1))        UTIL_THROW(Exception, "The bit packing routines are failing for your architecture.  Please send a bug report with your architecture, operating system, and compiler.");    }    // TODO: more checks.   diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 5c71c792..b35d80c8 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -42,47 +42,62 @@ inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {  #error "Bit packing code isn't written for your byte order."  #endif +inline uint64_t ReadOff(const void *base, uint64_t bit_off) { +  return *reinterpret_cast<const uint64_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)); +} +  /* Pack integers up to 57 bits using their least significant digits.    * The length is specified using mask:   * Assumes mask == (1 << length) - 1 where length <= 57.      */ -inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) { -  return (*reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, length)) & mask; +inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, uint64_t mask) { +  return (ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, length)) & mask;  }  /* Assumes value < (1 << length) and length <= 57.   * Assumes the memory is zero initially.    
*/ -inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) { -  *reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length)); +inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) { +  *reinterpret_cast<uint64_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |=  +    (value << BitPackShift(bit_off & 7, length)); +} + +/* Same caveats as above, but for a 25 bit limit. */ +inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) { +  return (*reinterpret_cast<const uint32_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask; +} + +inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) { +  *reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |=  +    (value << BitPackShift(bit_off & 7, length));  }  typedef union { float f; uint32_t i; } FloatEnc; -inline float ReadFloat32(const void *base, uint8_t bit) { +inline float ReadFloat32(const void *base, uint64_t bit_off) {    FloatEnc encoded; -  encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32); +  encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32);    return encoded.f;  } -inline void WriteFloat32(void *base, uint8_t bit, float value) { +inline void WriteFloat32(void *base, uint64_t bit_off, float value) {    FloatEnc encoded;    encoded.f = value; -  WriteInt57(base, bit, 32, encoded.i); +  WriteInt57(base, bit_off, 32, encoded.i);  }  const uint32_t kSignBit = 0x80000000; -inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { +inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) {    FloatEnc encoded; -  encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31); +  encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31);    // Sign bit set means negative.      
encoded.i |= kSignBit;    return encoded.f;  } -inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { +inline void WriteNonPositiveFloat31(void *base, uint64_t bit_off, float value) {    FloatEnc encoded;    encoded.f = value;    encoded.i &= ~kSignBit; -  WriteInt57(base, bit, 31, encoded.i); +  WriteInt57(base, bit_off, 31, encoded.i);  }  void BitPackingSanity(); diff --git a/klm/util/bit_packing_test.cc b/klm/util/bit_packing_test.cc index c578ddd1..4edc2004 100644 --- a/klm/util/bit_packing_test.cc +++ b/klm/util/bit_packing_test.cc @@ -9,15 +9,16 @@ namespace util {  namespace {  const uint64_t test57 = 0x123456789abcdefULL; +const uint32_t test25 = 0x1234567; -BOOST_AUTO_TEST_CASE(ZeroBit) { +BOOST_AUTO_TEST_CASE(ZeroBit57) {    char mem[16];    memset(mem, 0, sizeof(mem));    WriteInt57(mem, 0, 57, test57);    BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1));  } -BOOST_AUTO_TEST_CASE(EachBit) { +BOOST_AUTO_TEST_CASE(EachBit57) {    char mem[16];    for (uint8_t b = 0; b < 8; ++b) {      memset(mem, 0, sizeof(mem)); @@ -26,15 +27,27 @@ BOOST_AUTO_TEST_CASE(EachBit) {    }  } -BOOST_AUTO_TEST_CASE(Consecutive) { +BOOST_AUTO_TEST_CASE(Consecutive57) {    char mem[57+8];    memset(mem, 0, sizeof(mem));    for (uint64_t b = 0; b < 57 * 8; b += 57) { -    WriteInt57(mem + (b / 8), b % 8, 57, test57); -    BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); +    WriteInt57(mem, b, 57, test57); +    BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1));    }    for (uint64_t b = 0; b < 57 * 8; b += 57) { -    BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); +    BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); +  } +} + +BOOST_AUTO_TEST_CASE(Consecutive25) { +  char mem[25+8]; +  memset(mem, 0, sizeof(mem)); +  for (uint64_t b = 0; b < 25 * 8; b += 25) { +    WriteInt25(mem, b, 25, test25); +    BOOST_CHECK_EQUAL(test25, ReadInt25(mem, 
b, 25, (1ULL << 25) - 1)); +  } +  for (uint64_t b = 0; b < 25 * 8; b += 25) { +    BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1));    }  } diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 05826b51..84d7aa02 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -9,52 +9,96 @@  namespace util { -inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) { -  std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width)); -  // Cap for floating point rounding -  return (ret < width) ? ret : width - 1; -} -/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) { -  return static_cast<std::size_t>(static_cast<uint64_t>(off) * static_cast<uint64_t>(width) / static_cast<uint64_t>(range)); +template <class T> class IdentityAccessor { +  public: +    typedef T Key; +    T operator()(const uint64_t *in) const { return *in; } +}; + +struct Pivot64 { +  static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) { +    std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width)); +    // Cap for floating point rounding +    return (ret < width) ? ret : width - 1; +  } +}; + +// Use when off * width is <2^64.  This is guaranteed when each of them is actually a 32-bit value.    +struct Pivot32 { +  static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) { +    return static_cast<std::size_t>((off * width) / (range + 1)); +  } +}; + +// Usage: PivotSelect<sizeof(DataType)>::T +template <unsigned> struct PivotSelect; +template <> struct PivotSelect<8> { typedef Pivot64 T; }; +template <> struct PivotSelect<4> { typedef Pivot32 T; }; +template <> struct PivotSelect<2> { typedef Pivot32 T; }; + +/* Binary search. 
*/ +template <class Iterator, class Accessor> bool BinaryFind( +    const Accessor &accessor, +    Iterator begin, +    Iterator end, +    const typename Accessor::Key key, Iterator &out) { +  while (end > begin) { +    Iterator pivot(begin + (end - begin) / 2); +    typename Accessor::Key mid(accessor(pivot)); +    if (mid < key) { +      begin = pivot + 1; +    } else if (mid > key) { +      end = pivot; +    } else { +      out = pivot; +      return true; +    } +  } +  return false;  } -inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) { -  return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range)); + +// Search the range [before_it + 1, after_it - 1] for key.   +// Preconditions: +// before_v <= key <= after_v +// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v +// range is sorted. +template <class Iterator, class Accessor, class Pivot> bool BoundedSortedUniformFind( +    const Accessor &accessor, +    Iterator before_it, typename Accessor::Key before_v, +    Iterator after_it, typename Accessor::Key after_v, +    const typename Accessor::Key key, Iterator &out) { +  while (after_it - before_it > 1) { +    Iterator pivot(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1))); +    typename Accessor::Key mid(accessor(pivot)); +    if (mid < key) { +      before_it = pivot; +      before_v = mid; +    } else if (mid > key) { +      after_it = pivot; +      after_v = mid; +    } else { +      out = pivot; +      return true; +    } +  } +  return false;  } -inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) { -  return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range)); -}*/ -template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) { +template <class Iterator, class Accessor, class 
Pivot> bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) {    if (begin == end) return false; -  Key below(begin->GetKey()); +  typename Accessor::Key below(accessor(begin));    if (key <= below) {      if (key == below) { out = begin; return true; }      return false;    }    // Make the range [begin, end].      --end; -  Key above(end->GetKey()); +  typename Accessor::Key above(accessor(end));    if (key >= above) {      if (key == above) { out = end; return true; }      return false;    } - -  // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above.   -  while (end - begin > 1) { -    Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast<std::size_t>(end - begin - 1)))); -    Key mid(pivot->GetKey()); -    if (mid < key) { -      begin = pivot; -      below = mid; -    } else if (mid > key) { -      end = pivot; -      above = mid; -    } else { -      out = pivot; -      return true; -    } -  } -  return false; +  return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);  }  // To use this template, you need to define a Pivot function to match Key.   @@ -64,7 +108,13 @@ template <class PackingT> class SortedUniformMap {      typedef typename Packing::ConstIterator ConstIterator;      typedef typename Packing::MutableIterator MutableIterator; -  public: +    struct Accessor { +      public: +        typedef typename Packing::Key Key; +        const Key &operator()(const ConstIterator &i) const { return i->GetKey(); } +        Key &operator()(const MutableIterator &i) const { return i->GetKey(); } +    }; +      // Offer consistent API with probing hash.      
static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) {        return sizeof(uint64_t) + entries * Packing::kBytes; @@ -120,7 +170,7 @@ template <class PackingT> class SortedUniformMap {        assert(initialized_);        assert(loaded_);  #endif -      return SortedUniformFind<MutableIterator, Key>(begin_, end_, key, out); +      return SortedUniformFind<MutableIterator, Accessor, Pivot64>(begin_, end_, key, out);      }      // Do not call before FinishedInserting.   @@ -129,7 +179,7 @@ template <class PackingT> class SortedUniformMap {        assert(initialized_);        assert(loaded_);  #endif -      return SortedUniformFind<ConstIterator, Key>(ConstIterator(begin_), ConstIterator(end_), key, out); +      return SortedUniformFind<ConstIterator, Accessor, Pivot64>(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out);      }      ConstIterator begin() const { return begin_; } | 
