Diffstat (limited to 'klm/util')
-rw-r--r--  klm/util/bit_packing.cc       |   4
-rw-r--r--  klm/util/bit_packing.hh       |  39
-rw-r--r--  klm/util/bit_packing_test.cc  |  25
-rw-r--r--  klm/util/sorted_uniform.hh    | 120
4 files changed, 133 insertions(+), 55 deletions(-)
diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc
index 681da5f2..41999b72 100644
--- a/klm/util/bit_packing.cc
+++ b/klm/util/bit_packing.cc
@@ -28,10 +28,10 @@ void BitPackingSanity() {
memset(mem, 0, sizeof(mem));
const uint64_t test57 = 0x123456789abcdefULL;
for (uint64_t b = 0; b < 57 * 8; b += 57) {
- WriteInt57(mem + b / 8, b % 8, 57, test57);
+ WriteInt57(mem, b, 57, test57);
}
for (uint64_t b = 0; b < 57 * 8; b += 57) {
- if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1))
+ if (test57 != ReadInt57(mem, b, 57, (1ULL << 57) - 1))
UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler.");
}
// TODO: more checks.
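
The change above is a calling-convention cleanup: callers used to split a bit index into a byte pointer (mem + b / 8) and a residual bit (b % 8) by hand; the new entry points take the base pointer plus an absolute bit offset and do that split internally. A minimal sketch of the split, using only the arithmetic visible in the patch (the helper name is mine):

    #include <stdint.h>

    // Illustrative only, not part of the patch: how an absolute bit offset
    // decomposes into the byte to address and the bit within that byte.
    inline void SplitBitOffset(uint64_t bit_off, uint64_t &byte, uint8_t &bit) {
      byte = bit_off >> 3;                       // same as bit_off / 8
      bit  = static_cast<uint8_t>(bit_off & 7);  // same as bit_off % 8
    }
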
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh
index 5c71c792..b35d80c8 100644
--- a/klm/util/bit_packing.hh
+++ b/klm/util/bit_packing.hh
@@ -42,47 +42,62 @@ inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {
#error "Bit packing code isn't written for your byte order."
#endif
+inline uint64_t ReadOff(const void *base, uint64_t bit_off) {
+ return *reinterpret_cast<const uint64_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3));
+}
+
/* Pack integers up to 57 bits using their least significant digits.
* The length is specified using mask:
* Assumes mask == (1 << length) - 1 where length <= 57.
*/
-inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) {
- return (*reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, length)) & mask;
+inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, uint64_t mask) {
+ return (ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, length)) & mask;
}
/* Assumes value < (1 << length) and length <= 57.
* Assumes the memory is zero initially.
*/
-inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) {
- *reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length));
+inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) {
+ *reinterpret_cast<uint64_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |=
+ (value << BitPackShift(bit_off & 7, length));
+}
+
+/* Same caveats as above, but for a 25 bit limit. */
+inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) {
+ return (*reinterpret_cast<const uint32_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask;
+}
+
+inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) {
+ *reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |=
+ (value << BitPackShift(bit_off & 7, length));
}
typedef union { float f; uint32_t i; } FloatEnc;
-inline float ReadFloat32(const void *base, uint8_t bit) {
+inline float ReadFloat32(const void *base, uint64_t bit_off) {
FloatEnc encoded;
- encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32);
+ encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32);
return encoded.f;
}
-inline void WriteFloat32(void *base, uint8_t bit, float value) {
+inline void WriteFloat32(void *base, uint64_t bit_off, float value) {
FloatEnc encoded;
encoded.f = value;
- WriteInt57(base, bit, 32, encoded.i);
+ WriteInt57(base, bit_off, 32, encoded.i);
}
const uint32_t kSignBit = 0x80000000;
-inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) {
+inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) {
FloatEnc encoded;
- encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31);
+ encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31);
// Sign bit set means negative.
encoded.i |= kSignBit;
return encoded.f;
}
-inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) {
+inline void WriteNonPositiveFloat31(void *base, uint64_t bit_off, float value) {
FloatEnc encoded;
encoded.f = value;
encoded.i &= ~kSignBit;
- WriteInt57(base, bit, 31, encoded.i);
+ WriteInt57(base, bit_off, 31, encoded.i);
}
void BitPackingSanity();
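
The new 25-bit routines mirror the 57-bit ones but issue 32-bit loads and stores, so a value of up to 25 bits shifted by up to 7 still fits in a single access. A hedged usage sketch (function name, buffer sizing, and include path are assumptions, not from the patch):

    #include <cassert>
    #include <cstring>
    #include <stdint.h>
    #include "util/bit_packing.hh"  // assumed include path

    void PackFourCounts(const uint32_t *counts) {
      // ceil(4 * 25 / 8) = 13 payload bytes, plus slack so the unaligned
      // 32-bit accesses made by {Read,Write}Int25 stay inside the buffer.
      char mem[13 + 4];
      std::memset(mem, 0, sizeof(mem));  // WriteInt25 ORs into zeroed memory
      const uint32_t mask = (1U << 25) - 1;
      for (uint64_t i = 0; i < 4; ++i)
        util::WriteInt25(mem, i * 25, 25, counts[i] & mask);
      for (uint64_t i = 0; i < 4; ++i)
        assert((counts[i] & mask) == util::ReadInt25(mem, i * 25, 25, mask));
    }
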
diff --git a/klm/util/bit_packing_test.cc b/klm/util/bit_packing_test.cc
index c578ddd1..4edc2004 100644
--- a/klm/util/bit_packing_test.cc
+++ b/klm/util/bit_packing_test.cc
@@ -9,15 +9,16 @@ namespace util {
namespace {
const uint64_t test57 = 0x123456789abcdefULL;
+const uint32_t test25 = 0x1234567;
-BOOST_AUTO_TEST_CASE(ZeroBit) {
+BOOST_AUTO_TEST_CASE(ZeroBit57) {
char mem[16];
memset(mem, 0, sizeof(mem));
WriteInt57(mem, 0, 57, test57);
BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1));
}
-BOOST_AUTO_TEST_CASE(EachBit) {
+BOOST_AUTO_TEST_CASE(EachBit57) {
char mem[16];
for (uint8_t b = 0; b < 8; ++b) {
memset(mem, 0, sizeof(mem));
@@ -26,15 +27,27 @@ BOOST_AUTO_TEST_CASE(EachBit) {
}
}
-BOOST_AUTO_TEST_CASE(Consecutive) {
+BOOST_AUTO_TEST_CASE(Consecutive57) {
char mem[57+8];
memset(mem, 0, sizeof(mem));
for (uint64_t b = 0; b < 57 * 8; b += 57) {
- WriteInt57(mem + (b / 8), b % 8, 57, test57);
- BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1));
+ WriteInt57(mem, b, 57, test57);
+ BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1));
}
for (uint64_t b = 0; b < 57 * 8; b += 57) {
- BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1));
+ BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1));
+ }
+}
+
+BOOST_AUTO_TEST_CASE(Consecutive25) {
+ char mem[25+8];
+ memset(mem, 0, sizeof(mem));
+ for (uint64_t b = 0; b < 25 * 8; b += 25) {
+ WriteInt25(mem, b, 25, test25);
+ BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1));
+ }
+ for (uint64_t b = 0; b < 25 * 8; b += 25) {
+ BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1));
}
}
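
The buffer sizes in these tests follow one rule: eight values of B bits occupy exactly B bytes, and the extra 8 bytes of slack keep the widest underlying load or store from running past the allocation. A worked check of that arithmetic (the helper is mine, not the test's):

    #include <stdint.h>

    // Bytes needed to bit-pack `count` values of `width` bits each, plus
    // `slack` bytes so the widest load/store stays in bounds.
    inline uint64_t PackedBytes(uint64_t count, uint8_t width, uint64_t slack) {
      return (count * width + 7) / 8 + slack;  // round up to whole bytes
    }
    // PackedBytes(8, 57, 8) == 65 == sizeof(char[57 + 8])
    // PackedBytes(8, 25, 8) == 33 == sizeof(char[25 + 8])
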
diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh
index 05826b51..84d7aa02 100644
--- a/klm/util/sorted_uniform.hh
+++ b/klm/util/sorted_uniform.hh
@@ -9,52 +9,96 @@
namespace util {
-inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) {
- std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width));
- // Cap for floating point rounding
- return (ret < width) ? ret : width - 1;
-}
-/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) {
- return static_cast<std::size_t>(static_cast<uint64_t>(off) * static_cast<uint64_t>(width) / static_cast<uint64_t>(range));
+template <class T> class IdentityAccessor {
+ public:
+ typedef T Key;
+ T operator()(const uint64_t *in) const { return *in; }
+};
+
+struct Pivot64 {
+ static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) {
+ std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width));
+ // Cap for floating point rounding
+ return (ret < width) ? ret : width - 1;
+ }
+};
+
+// Use when off * width is <2^64. This is guaranteed when each of them is actually a 32-bit value.
+struct Pivot32 {
+ static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) {
+ return static_cast<std::size_t>((off * width) / (range + 1));
+ }
+};
+
+// Usage: PivotSelect<sizeof(DataType)>::T
+template <unsigned> struct PivotSelect;
+template <> struct PivotSelect<8> { typedef Pivot64 T; };
+template <> struct PivotSelect<4> { typedef Pivot32 T; };
+template <> struct PivotSelect<2> { typedef Pivot32 T; };
+
+/* Binary search. */
+template <class Iterator, class Accessor> bool BinaryFind(
+ const Accessor &accessor,
+ Iterator begin,
+ Iterator end,
+ const typename Accessor::Key key, Iterator &out) {
+ while (end > begin) {
+ Iterator pivot(begin + (end - begin) / 2);
+ typename Accessor::Key mid(accessor(pivot));
+ if (mid < key) {
+ begin = pivot + 1;
+ } else if (mid > key) {
+ end = pivot;
+ } else {
+ out = pivot;
+ return true;
+ }
+ }
+ return false;
}
-inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) {
- return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range));
+
+// Search the range [before_it + 1, after_it - 1] for key.
+// Preconditions:
+// before_v <= key <= after_v
+// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v
+// range is sorted.
+template <class Iterator, class Accessor, class Pivot> bool BoundedSortedUniformFind(
+ const Accessor &accessor,
+ Iterator before_it, typename Accessor::Key before_v,
+ Iterator after_it, typename Accessor::Key after_v,
+ const typename Accessor::Key key, Iterator &out) {
+ while (after_it - before_it > 1) {
+ Iterator pivot(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1)));
+ typename Accessor::Key mid(accessor(pivot));
+ if (mid < key) {
+ before_it = pivot;
+ before_v = mid;
+ } else if (mid > key) {
+ after_it = pivot;
+ after_v = mid;
+ } else {
+ out = pivot;
+ return true;
+ }
+ }
+ return false;
}
-inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) {
- return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range));
-}*/
-template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) {
+template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) {
if (begin == end) return false;
- Key below(begin->GetKey());
+ typename Accessor::Key below(accessor(begin));
if (key <= below) {
if (key == below) { out = begin; return true; }
return false;
}
// Make the range [begin, end].
--end;
- Key above(end->GetKey());
+ typename Accessor::Key above(accessor(end));
if (key >= above) {
if (key == above) { out = end; return true; }
return false;
}
-
- // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above.
- while (end - begin > 1) {
- Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast<std::size_t>(end - begin - 1))));
- Key mid(pivot->GetKey());
- if (mid < key) {
- begin = pivot;
- below = mid;
- } else if (mid > key) {
- end = pivot;
- above = mid;
- } else {
- out = pivot;
- return true;
- }
- }
- return false;
+ return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);
}
// To use this template, you need to define a Pivot function to match Key.
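
BoundedSortedUniformFind above is an interpolation search: instead of always bisecting, each step guesses where the key should land on the assumption that keys are spread roughly uniformly between the known bounds, which is why the pivot rule is a pluggable policy. A worked instance of one Pivot32-style step (all the numbers are invented for illustration):

    #include <cstddef>
    #include <stdint.h>

    // One interpolation step: key 70 between bound values 10 and 110, with
    // 9 candidate slots strictly between the bounding iterators.
    std::size_t GuessSlot() {
      uint64_t before_v = 10, after_v = 110, key = 70, width = 9;
      // Pivot32::Calc is (off * width) / (range + 1); safe while off * width < 2^64.
      std::size_t p = static_cast<std::size_t>(
          (key - before_v) * width / (after_v - before_v + 1));
      return 1 + p;  // pivot = before_it + (1 + p); here 540 / 101 = 5, so slot 6
    }
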
@@ -64,7 +108,13 @@ template <class PackingT> class SortedUniformMap {
typedef typename Packing::ConstIterator ConstIterator;
typedef typename Packing::MutableIterator MutableIterator;
- public:
+ struct Accessor {
+ public:
+ typedef typename Packing::Key Key;
+ const Key &operator()(const ConstIterator &i) const { return i->GetKey(); }
+ Key &operator()(const MutableIterator &i) const { return i->GetKey(); }
+ };
+
// Offer consistent API with probing hash.
static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) {
return sizeof(uint64_t) + entries * Packing::kBytes;
@@ -120,7 +170,7 @@ template <class PackingT> class SortedUniformMap {
assert(initialized_);
assert(loaded_);
#endif
- return SortedUniformFind<MutableIterator, Key>(begin_, end_, key, out);
+ return SortedUniformFind<MutableIterator, Accessor, Pivot64>(Accessor(), begin_, end_, key, out);
}
// Do not call before FinishedInserting.
@@ -129,7 +179,7 @@ template <class PackingT> class SortedUniformMap {
assert(initialized_);
assert(loaded_);
#endif
- return SortedUniformFind<ConstIterator, Key>(ConstIterator(begin_), ConstIterator(end_), key, out);
+ return SortedUniformFind<ConstIterator, Accessor, Pivot64>(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out);
}
ConstIterator begin() const { return begin_; }
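
Taken together, the refactor replaces the old GetKey-based free function with a policy-driven one: an Accessor extracts keys, a Pivot picks probe points, and PivotSelect maps a key width to the right pivot rule. A hedged end-to-end sketch of calling the new interface on a plain sorted array (include path assumed):

    #include <stdint.h>
    #include "util/sorted_uniform.hh"  // assumed include path

    bool FindKey(const uint64_t *begin, const uint64_t *end, uint64_t key,
                 const uint64_t *&out) {
      // PivotSelect<sizeof(uint64_t)>::T resolves to Pivot64 per the patch.
      return util::SortedUniformFind<
          const uint64_t*, util::IdentityAccessor<uint64_t>,
          util::PivotSelect<sizeof(uint64_t)>::T>(
              util::IdentityAccessor<uint64_t>(), begin, end, key, out);
    }
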