diff options
Diffstat (limited to 'klm/util')
-rw-r--r-- | klm/util/bit_packing.cc | 4 | ||||
-rw-r--r-- | klm/util/bit_packing.hh | 39 | ||||
-rw-r--r-- | klm/util/bit_packing_test.cc | 25 | ||||
-rw-r--r-- | klm/util/sorted_uniform.hh | 120 |
4 files changed, 133 insertions, 55 deletions
diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc index 681da5f2..41999b72 100644 --- a/klm/util/bit_packing.cc +++ b/klm/util/bit_packing.cc @@ -28,10 +28,10 @@ void BitPackingSanity() { memset(mem, 0, sizeof(mem)); const uint64_t test57 = 0x123456789abcdefULL; for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + b / 8, b % 8, 57, test57); + WriteInt57(mem, b, 57, test57); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)) + if (test57 != ReadInt57(mem, b, 57, (1ULL << 57) - 1)) UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler."); } // TODO: more checks. diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 5c71c792..b35d80c8 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -42,47 +42,62 @@ inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { #error "Bit packing code isn't written for your byte order." #endif +inline uint64_t ReadOff(const void *base, uint64_t bit_off) { + return *reinterpret_cast<const uint64_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)); +} + /* Pack integers up to 57 bits using their least significant digits. * The length is specified using mask: * Assumes mask == (1 << length) - 1 where length <= 57. */ -inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) { - return (*reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, length)) & mask; +inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, uint64_t mask) { + return (ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, length)) & mask; } /* Assumes value < (1 << length) and length <= 57. * Assumes the memory is zero initially. */ -inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) { - *reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length)); +inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) { + *reinterpret_cast<uint64_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); +} + +/* Same caveats as above, but for a 25 bit limit. */ +inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) { + return (*reinterpret_cast<const uint32_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask; +} + +inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) { + *reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); } typedef union { float f; uint32_t i; } FloatEnc; -inline float ReadFloat32(const void *base, uint8_t bit) { +inline float ReadFloat32(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32); return encoded.f; } -inline void WriteFloat32(void *base, uint8_t bit, float value) { +inline void WriteFloat32(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; - WriteInt57(base, bit, 32, encoded.i); + WriteInt57(base, bit_off, 32, encoded.i); } const uint32_t kSignBit = 0x80000000; -inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { +inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); // Sign bit set means negative. encoded.i |= kSignBit; return encoded.f; } -inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { +inline void WriteNonPositiveFloat31(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; encoded.i &= ~kSignBit; - WriteInt57(base, bit, 31, encoded.i); + WriteInt57(base, bit_off, 31, encoded.i); } void BitPackingSanity(); diff --git a/klm/util/bit_packing_test.cc b/klm/util/bit_packing_test.cc index c578ddd1..4edc2004 100644 --- a/klm/util/bit_packing_test.cc +++ b/klm/util/bit_packing_test.cc @@ -9,15 +9,16 @@ namespace util { namespace { const uint64_t test57 = 0x123456789abcdefULL; +const uint32_t test25 = 0x1234567; -BOOST_AUTO_TEST_CASE(ZeroBit) { +BOOST_AUTO_TEST_CASE(ZeroBit57) { char mem[16]; memset(mem, 0, sizeof(mem)); WriteInt57(mem, 0, 57, test57); BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1)); } -BOOST_AUTO_TEST_CASE(EachBit) { +BOOST_AUTO_TEST_CASE(EachBit57) { char mem[16]; for (uint8_t b = 0; b < 8; ++b) { memset(mem, 0, sizeof(mem)); @@ -26,15 +27,27 @@ BOOST_AUTO_TEST_CASE(EachBit) { } } -BOOST_AUTO_TEST_CASE(Consecutive) { +BOOST_AUTO_TEST_CASE(Consecutive57) { char mem[57+8]; memset(mem, 0, sizeof(mem)); for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + (b / 8), b % 8, 57, test57); - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + WriteInt57(mem, b, 57, test57); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); + } +} + +BOOST_AUTO_TEST_CASE(Consecutive25) { + char mem[25+8]; + memset(mem, 0, sizeof(mem)); + for (uint64_t b = 0; b < 25 * 8; b += 25) { + WriteInt25(mem, b, 25, test25); + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); + } + for (uint64_t b = 0; b < 25 * 8; b += 25) { + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); } } diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 05826b51..84d7aa02 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -9,52 +9,96 @@ namespace util { -inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) { - std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width)); - // Cap for floating point rounding - return (ret < width) ? ret : width - 1; -} -/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) { - return static_cast<std::size_t>(static_cast<uint64_t>(off) * static_cast<uint64_t>(width) / static_cast<uint64_t>(range)); +template <class T> class IdentityAccessor { + public: + typedef T Key; + T operator()(const uint64_t *in) const { return *in; } +}; + +struct Pivot64 { + static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) { + std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width)); + // Cap for floating point rounding + return (ret < width) ? ret : width - 1; + } +}; + +// Use when off * width is <2^64. This is guaranteed when each of them is actually a 32-bit value. +struct Pivot32 { + static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) { + return static_cast<std::size_t>((off * width) / (range + 1)); + } +}; + +// Usage: PivotSelect<sizeof(DataType)>::T +template <unsigned> struct PivotSelect; +template <> struct PivotSelect<8> { typedef Pivot64 T; }; +template <> struct PivotSelect<4> { typedef Pivot32 T; }; +template <> struct PivotSelect<2> { typedef Pivot32 T; }; + +/* Binary search. */ +template <class Iterator, class Accessor> bool BinaryFind( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key, Iterator &out) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) { - return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range)); + +// Search the range [before_it + 1, after_it - 1] for key. +// Preconditions: +// before_v <= key <= after_v +// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v +// range is sorted. +template <class Iterator, class Accessor, class Pivot> bool BoundedSortedUniformFind( + const Accessor &accessor, + Iterator before_it, typename Accessor::Key before_v, + Iterator after_it, typename Accessor::Key after_v, + const typename Accessor::Key key, Iterator &out) { + while (after_it - before_it > 1) { + Iterator pivot(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1))); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + before_it = pivot; + before_v = mid; + } else if (mid > key) { + after_it = pivot; + after_v = mid; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) { - return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range)); -}*/ -template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) { +template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) { if (begin == end) return false; - Key below(begin->GetKey()); + typename Accessor::Key below(accessor(begin)); if (key <= below) { if (key == below) { out = begin; return true; } return false; } // Make the range [begin, end]. --end; - Key above(end->GetKey()); + typename Accessor::Key above(accessor(end)); if (key >= above) { if (key == above) { out = end; return true; } return false; } - - // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above. - while (end - begin > 1) { - Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast<std::size_t>(end - begin - 1)))); - Key mid(pivot->GetKey()); - if (mid < key) { - begin = pivot; - below = mid; - } else if (mid > key) { - end = pivot; - above = mid; - } else { - out = pivot; - return true; - } - } - return false; + return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out); } // To use this template, you need to define a Pivot function to match Key. @@ -64,7 +108,13 @@ template <class PackingT> class SortedUniformMap { typedef typename Packing::ConstIterator ConstIterator; typedef typename Packing::MutableIterator MutableIterator; - public: + struct Accessor { + public: + typedef typename Packing::Key Key; + const Key &operator()(const ConstIterator &i) const { return i->GetKey(); } + Key &operator()(const MutableIterator &i) const { return i->GetKey(); } + }; + // Offer consistent API with probing hash. static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) { return sizeof(uint64_t) + entries * Packing::kBytes; @@ -120,7 +170,7 @@ template <class PackingT> class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind<MutableIterator, Key>(begin_, end_, key, out); + return SortedUniformFind<MutableIterator, Accessor, Pivot64>(begin_, end_, key, out); } // Do not call before FinishedInserting. @@ -129,7 +179,7 @@ template <class PackingT> class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind<ConstIterator, Key>(ConstIterator(begin_), ConstIterator(end_), key, out); + return SortedUniformFind<ConstIterator, Accessor, Pivot64>(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out); } ConstIterator begin() const { return begin_; } |