summaryrefslogtreecommitdiff
path: root/klm/util/sorted_uniform.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/util/sorted_uniform.hh')
-rw-r--r--klm/util/sorted_uniform.hh120
1 files changed, 85 insertions, 35 deletions
diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh
index 05826b51..84d7aa02 100644
--- a/klm/util/sorted_uniform.hh
+++ b/klm/util/sorted_uniform.hh
@@ -9,52 +9,96 @@
namespace util {
-inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) {
- std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width));
- // Cap for floating point rounding
- return (ret < width) ? ret : width - 1;
-}
-/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) {
- return static_cast<std::size_t>(static_cast<uint64_t>(off) * static_cast<uint64_t>(width) / static_cast<uint64_t>(range));
+template <class T> class IdentityAccessor {
+ public:
+ typedef T Key;
+ T operator()(const uint64_t *in) const { return *in; }
+};
+
+struct Pivot64 {
+ static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) {
+ std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width));
+ // Cap for floating point rounding
+ return (ret < width) ? ret : width - 1;
+ }
+};
+
+// Use when off * width is <2^64. This is guaranteed when each of them is actually a 32-bit value.
+struct Pivot32 {
+ static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) {
+ return static_cast<std::size_t>((off * width) / (range + 1));
+ }
+};
+
+// Usage: PivotSelect<sizeof(DataType)>::T
+template <unsigned> struct PivotSelect;
+template <> struct PivotSelect<8> { typedef Pivot64 T; };
+template <> struct PivotSelect<4> { typedef Pivot32 T; };
+template <> struct PivotSelect<2> { typedef Pivot32 T; };
+
+/* Binary search. */
+template <class Iterator, class Accessor> bool BinaryFind(
+ const Accessor &accessor,
+ Iterator begin,
+ Iterator end,
+ const typename Accessor::Key key, Iterator &out) {
+ while (end > begin) {
+ Iterator pivot(begin + (end - begin) / 2);
+ typename Accessor::Key mid(accessor(pivot));
+ if (mid < key) {
+ begin = pivot + 1;
+ } else if (mid > key) {
+ end = pivot;
+ } else {
+ out = pivot;
+ return true;
+ }
+ }
+ return false;
}
-inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) {
- return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range));
+
+// Search the range [before_it + 1, after_it - 1] for key.
+// Preconditions:
+// before_v <= key <= after_v
+// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v
+// range is sorted.
+template <class Iterator, class Accessor, class Pivot> bool BoundedSortedUniformFind(
+ const Accessor &accessor,
+ Iterator before_it, typename Accessor::Key before_v,
+ Iterator after_it, typename Accessor::Key after_v,
+ const typename Accessor::Key key, Iterator &out) {
+ while (after_it - before_it > 1) {
+ Iterator pivot(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1)));
+ typename Accessor::Key mid(accessor(pivot));
+ if (mid < key) {
+ before_it = pivot;
+ before_v = mid;
+ } else if (mid > key) {
+ after_it = pivot;
+ after_v = mid;
+ } else {
+ out = pivot;
+ return true;
+ }
+ }
+ return false;
}
-inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) {
- return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range));
-}*/
-template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) {
+template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) {
if (begin == end) return false;
- Key below(begin->GetKey());
+ typename Accessor::Key below(accessor(begin));
if (key <= below) {
if (key == below) { out = begin; return true; }
return false;
}
// Make the range [begin, end].
--end;
- Key above(end->GetKey());
+ typename Accessor::Key above(accessor(end));
if (key >= above) {
if (key == above) { out = end; return true; }
return false;
}
-
- // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above.
- while (end - begin > 1) {
- Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast<std::size_t>(end - begin - 1))));
- Key mid(pivot->GetKey());
- if (mid < key) {
- begin = pivot;
- below = mid;
- } else if (mid > key) {
- end = pivot;
- above = mid;
- } else {
- out = pivot;
- return true;
- }
- }
- return false;
+ return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);
}
// To use this template, you need to define a Pivot function to match Key.
@@ -64,7 +108,13 @@ template <class PackingT> class SortedUniformMap {
typedef typename Packing::ConstIterator ConstIterator;
typedef typename Packing::MutableIterator MutableIterator;
- public:
+ struct Accessor {
+ public:
+ typedef typename Packing::Key Key;
+ const Key &operator()(const ConstIterator &i) const { return i->GetKey(); }
+ Key &operator()(const MutableIterator &i) const { return i->GetKey(); }
+ };
+
// Offer consistent API with probing hash.
static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) {
return sizeof(uint64_t) + entries * Packing::kBytes;
@@ -120,7 +170,7 @@ template <class PackingT> class SortedUniformMap {
assert(initialized_);
assert(loaded_);
#endif
- return SortedUniformFind<MutableIterator, Key>(begin_, end_, key, out);
+ return SortedUniformFind<MutableIterator, Accessor, Pivot64>(begin_, end_, key, out);
}
// Do not call before FinishedInserting.
@@ -129,7 +179,7 @@ template <class PackingT> class SortedUniformMap {
assert(initialized_);
assert(loaded_);
#endif
- return SortedUniformFind<ConstIterator, Key>(ConstIterator(begin_), ConstIterator(end_), key, out);
+ return SortedUniformFind<ConstIterator, Accessor, Pivot64>(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out);
}
ConstIterator begin() const { return begin_; }