summaryrefslogtreecommitdiff
path: root/klm/util/probing_hash_table.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/util/probing_hash_table.hh')
-rw-r--r--klm/util/probing_hash_table.hh84
1 files changed, 79 insertions, 5 deletions
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 51a2944d..38524806 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -2,6 +2,7 @@
#define UTIL_PROBING_HASH_TABLE__
#include "util/exception.hh"
+#include "util/scoped.hh"
#include <algorithm>
#include <cstddef>
@@ -25,6 +26,8 @@ struct IdentityHash {
template <class T> T operator()(T arg) const { return arg; }
};
+template <class EntryT, class HashT, class EqualT> class AutoProbing;
+
/* Non-standard hash table
* Buckets must be set at the beginning and must be greater than maximum number
* of elements, else it throws ProbingSizeException.
@@ -33,7 +36,6 @@ struct IdentityHash {
* Uses linear probing to find value.
* Only insert and lookup operations.
*/
-
template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class ProbingHashTable {
public:
typedef EntryT Entry;
@@ -43,7 +45,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
typedef HashT Hash;
typedef EqualT Equal;
- public:
static uint64_t Size(uint64_t entries, float multiplier) {
uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries)));
return buckets * sizeof(Entry);
@@ -69,6 +70,11 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#endif
{}
+ void Relocate(void *new_base) {
+ begin_ = reinterpret_cast<MutableIterator>(new_base);
+ end_ = begin_ + buckets_;
+ }
+
template <class T> MutableIterator Insert(const T &t) {
#ifdef DEBUG
assert(initialized_);
@@ -82,7 +88,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#ifdef DEBUG
assert(initialized_);
#endif
- for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
+ for (MutableIterator i = Ideal(t);;) {
Key got(i->GetKey());
if (equal_(got, t.GetKey())) { out = i; return true; }
if (equal_(got, invalid_)) {
@@ -97,8 +103,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
void FinishedInserting() {}
- void LoadedBinary() {}
-
// Don't change anything related to GetKey,
template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
#ifdef DEBUG
@@ -224,6 +228,8 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
}
private:
+ friend class AutoProbing<Entry, Hash, Equal>;
+
template <class T> MutableIterator Ideal(const T &t) {
return begin_ + (hash_(t.GetKey()) % buckets_);
}
@@ -247,6 +253,74 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#endif
};
+// Resizable linear probing hash table. This owns the memory.
+template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class AutoProbing {
+ private:
+ typedef ProbingHashTable<EntryT, HashT, EqualT> Backend;
+ public:
+ typedef EntryT Entry;
+ typedef typename Entry::Key Key;
+ typedef const Entry *ConstIterator;
+ typedef Entry *MutableIterator;
+ typedef HashT Hash;
+ typedef EqualT Equal;
+
+ AutoProbing(std::size_t initial_size = 10, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) :
+ allocated_(Backend::Size(initial_size, 1.5)), mem_(util::MallocOrThrow(allocated_)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) {
+ threshold_ = initial_size * 1.2;
+ }
+
+ // Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup.
+ template <class T> MutableIterator Insert(const T &t) {
+ DoubleIfNeeded();
+ return backend_.UncheckedInsert(t);
+ }
+
+ template <class T> bool FindOrInsert(const T &t, MutableIterator &out) {
+ DoubleIfNeeded();
+ return backend_.FindOrInsert(t, out);
+ }
+
+ template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
+ return backend_.UnsafeMutableFind(key, out);
+ }
+
+ template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) {
+ return backend_.UnsafeMutableMustFind(key);
+ }
+
+ template <class Key> bool Find(const Key key, ConstIterator &out) const {
+ return backend_.Find(key, out);
+ }
+
+ template <class Key> ConstIterator MustFind(const Key key) const {
+ return backend_.MustFind(key);
+ }
+
+ std::size_t Size() const {
+ return backend_.SizeNoSerialization();
+ }
+
+ void Clear() {
+ backend_.Clear();
+ }
+
+ private:
+ void DoubleIfNeeded() {
+ if (Size() < threshold_)
+ return;
+ mem_.call_realloc(backend_.DoubleTo());
+ allocated_ = backend_.DoubleTo();
+ backend_.Double(mem_.get());
+ threshold_ *= 2;
+ }
+
+ std::size_t allocated_;
+ util::scoped_malloc mem_;
+ Backend backend_;
+ std::size_t threshold_;
+};
+
} // namespace util
#endif // UTIL_PROBING_HASH_TABLE__