summaryrefslogtreecommitdiff
path: root/klm/util/probing_hash_table.hh
blob: f466cebc920dad83db44eea70f6adf44f61419bc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#ifndef UTIL_PROBING_HASH_TABLE__
#define UTIL_PROBING_HASH_TABLE__

#include "util/exception.hh"

#include <algorithm>
#include <cstddef>
#include <functional>

#include <assert.h>

namespace util {

/* Thrown when table grows too large */
class ProbingSizeException : public Exception {
  public:
    ProbingSizeException() throw() {}
    ~ProbingSizeException() throw() {}
};

// std::identity is an SGI extension :-(
struct IdentityHash {
  template <class T> T operator()(T arg) const { return arg; }
};

/* Non-standard hash table
 * Buckets must be set at the beginning and must be greater than maximum number
 * of elements, else it throws ProbingSizeException.
 * Memory management and initialization is externalized to make it easier to
 * serialize these to disk and load them quickly.
 * Uses linear probing to find value.
 * Only insert and lookup operations.  
 */

template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class ProbingHashTable {
  public:
    typedef EntryT Entry;
    typedef typename Entry::Key Key;
    typedef const Entry *ConstIterator;
    typedef Entry *MutableIterator;
    typedef HashT Hash;
    typedef EqualT Equal;

  public:
    static std::size_t Size(std::size_t entries, float multiplier) {
      std::size_t buckets = std::max(entries + 1, static_cast<std::size_t>(multiplier * static_cast<float>(entries)));
      return buckets * sizeof(Entry);
    }

    // Must be assigned to later.  
    ProbingHashTable() : entries_(0)
#ifdef DEBUG
      , initialized_(false)
#endif
    {}

    ProbingHashTable(void *start, std::size_t allocated, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal())
      : begin_(reinterpret_cast<MutableIterator>(start)),
        buckets_(allocated / sizeof(Entry)),
        end_(begin_ + buckets_),
        invalid_(invalid),
        hash_(hash_func),
        equal_(equal_func),
        entries_(0)
#ifdef DEBUG
        , initialized_(true)
#endif
    {}

    template <class T> MutableIterator Insert(const T &t) {
#ifdef DEBUG
      assert(initialized_);
#endif
      UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full.");
      for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
        if (equal_(i->GetKey(), invalid_)) { *i = t; return i; }
        if (++i == end_) { i = begin_; }
      }
    }

    void FinishedInserting() {}

    void LoadedBinary() {}

    // Don't change anything related to GetKey,  
    template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
      for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) {
        Key got(i->GetKey());
        if (equal_(got, key)) { out = i; return true; }
        if (equal_(got, invalid_)) return false;
        if (++i == end_) i = begin_;
      }   
    }

    template <class Key> bool Find(const Key key, ConstIterator &out) const {
#ifdef DEBUG
      assert(initialized_);
#endif
      for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) {
        Key got(i->GetKey());
        if (equal_(got, key)) { out = i; return true; }
        if (equal_(got, invalid_)) return false;
        if (++i == end_) i = begin_;
      }    
    }

  private:
    MutableIterator begin_;
    std::size_t buckets_;
    MutableIterator end_;
    Key invalid_;
    Hash hash_;
    Equal equal_;
    std::size_t entries_;
#ifdef DEBUG
    bool initialized_;
#endif
};

} // namespace util

#endif // UTIL_PROBING_HASH_TABLE__