diff options
Diffstat (limited to 'klm/util/probing_hash_table.hh')
-rw-r--r-- | klm/util/probing_hash_table.hh | 84 |
1 files changed, 79 insertions, 5 deletions
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 51a2944d..38524806 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -2,6 +2,7 @@ #define UTIL_PROBING_HASH_TABLE__ #include "util/exception.hh" +#include "util/scoped.hh" #include <algorithm> #include <cstddef> @@ -25,6 +26,8 @@ struct IdentityHash { template <class T> T operator()(T arg) const { return arg; } }; +template <class EntryT, class HashT, class EqualT> class AutoProbing; + /* Non-standard hash table * Buckets must be set at the beginning and must be greater than maximum number * of elements, else it throws ProbingSizeException. @@ -33,7 +36,6 @@ struct IdentityHash { * Uses linear probing to find value. * Only insert and lookup operations. */ - template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class ProbingHashTable { public: typedef EntryT Entry; @@ -43,7 +45,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry typedef HashT Hash; typedef EqualT Equal; - public: static uint64_t Size(uint64_t entries, float multiplier) { uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries))); return buckets * sizeof(Entry); @@ -69,6 +70,11 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry #endif {} + void Relocate(void *new_base) { + begin_ = reinterpret_cast<MutableIterator>(new_base); + end_ = begin_ + buckets_; + } + template <class T> MutableIterator Insert(const T &t) { #ifdef DEBUG assert(initialized_); @@ -82,7 +88,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry #ifdef DEBUG assert(initialized_); #endif - for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) { + for (MutableIterator i = Ideal(t);;) { Key got(i->GetKey()); if (equal_(got, t.GetKey())) { out = i; return true; } if (equal_(got, invalid_)) { @@ -97,8 +103,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry void FinishedInserting() {} - void LoadedBinary() {} - // Don't change anything related to GetKey, template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) { #ifdef DEBUG @@ -224,6 +228,8 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry } private: + friend class AutoProbing<Entry, Hash, Equal>; + template <class T> MutableIterator Ideal(const T &t) { return begin_ + (hash_(t.GetKey()) % buckets_); } @@ -247,6 +253,74 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry #endif }; +// Resizable linear probing hash table. This owns the memory. +template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class AutoProbing { + private: + typedef ProbingHashTable<EntryT, HashT, EqualT> Backend; + public: + typedef EntryT Entry; + typedef typename Entry::Key Key; + typedef const Entry *ConstIterator; + typedef Entry *MutableIterator; + typedef HashT Hash; + typedef EqualT Equal; + + AutoProbing(std::size_t initial_size = 10, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) : + allocated_(Backend::Size(initial_size, 1.5)), mem_(util::MallocOrThrow(allocated_)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) { + threshold_ = initial_size * 1.2; + } + + // Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup. + template <class T> MutableIterator Insert(const T &t) { + DoubleIfNeeded(); + return backend_.UncheckedInsert(t); + } + + template <class T> bool FindOrInsert(const T &t, MutableIterator &out) { + DoubleIfNeeded(); + return backend_.FindOrInsert(t, out); + } + + template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) { + return backend_.UnsafeMutableFind(key, out); + } + + template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) { + return backend_.UnsafeMutableMustFind(key); + } + + template <class Key> bool Find(const Key key, ConstIterator &out) const { + return backend_.Find(key, out); + } + + template <class Key> ConstIterator MustFind(const Key key) const { + return backend_.MustFind(key); + } + + std::size_t Size() const { + return backend_.SizeNoSerialization(); + } + + void Clear() { + backend_.Clear(); + } + + private: + void DoubleIfNeeded() { + if (Size() < threshold_) + return; + mem_.call_realloc(backend_.DoubleTo()); + allocated_ = backend_.DoubleTo(); + backend_.Double(mem_.get()); + threshold_ *= 2; + } + + std::size_t allocated_; + util::scoped_malloc mem_; + Backend backend_; + std::size_t threshold_; +}; + } // namespace util #endif // UTIL_PROBING_HASH_TABLE__ |