diff options
Diffstat (limited to 'utils/fast_sparse_vector.h')
-rw-r--r-- | utils/fast_sparse_vector.h | 108 |
1 files changed, 94 insertions, 14 deletions
diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h index e86cbdc1..433a5cc5 100644 --- a/utils/fast_sparse_vector.h +++ b/utils/fast_sparse_vector.h @@ -66,6 +66,60 @@ BOOST_STATIC_ASSERT(sizeof(PairIntT<float>) == sizeof(std::pair<unsigned,float>) template <typename T, unsigned LOCAL_MAX = (sizeof(T) == sizeof(float) ? 15u : 7u)> class FastSparseVector { public: + struct iterator { + iterator(FastSparseVector<T>& v, const bool is_end) : local_(!v.is_remote_) { + if (local_) { + local_it_ = &v.data_.local[is_end ? v.local_size_ : 0]; + } else { + if (is_end) + remote_it_ = v.data_.rbmap->end(); + else + remote_it_ = v.data_.rbmap->begin(); + } + } + iterator(FastSparseVector<T>& v, const bool, const unsigned k) : local_(!v.is_remote_) { + if (local_) { + unsigned i = 0; + while(i < v.local_size_ && v.data_.local[i].first() != k) { ++i; } + local_it_ = &v.data_.local[i]; + } else { + remote_it_ = v.data_.rbmap->find(k); + } + } + const bool local_; + PairIntT<T>* local_it_; + typename SPARSE_HASH_MAP<unsigned, T>::iterator remote_it_; + std::pair<const unsigned, T>& operator*() const { + if (local_) + return *reinterpret_cast<std::pair<const unsigned, T>*>(local_it_); + else + return *remote_it_; + } + + std::pair<const unsigned, T>* operator->() const { + if (local_) + return reinterpret_cast<std::pair<const unsigned, T>*>(local_it_); + else + return &*remote_it_; + } + + iterator& operator++() { + if (local_) ++local_it_; else ++remote_it_; + return *this; + } + + inline bool operator==(const iterator& o) const { + if (o.local_ != local_) return false; + if (local_) { + return local_it_ == o.local_it_; + } else { + return remote_it_ == o.remote_it_; + } + } + inline bool operator!=(const iterator& o) const { + return !(o == *this); + } + }; struct const_iterator { const_iterator(const FastSparseVector<T>& v, const bool is_end) : local_(!v.is_remote_) { if (local_) { @@ -77,12 +131,21 @@ class FastSparseVector { remote_it_ = v.data_.rbmap->begin(); } } + const_iterator(const FastSparseVector<T>& v, const bool, const unsigned k) : local_(!v.is_remote_) { + if (local_) { + unsigned i = 0; + while(i < v.local_size_ && v.data_.local[i].first() != k) { ++i; } + local_it_ = &v.data_.local[i]; + } else { + remote_it_ = v.data_.rbmap->find(k); + } + } const bool local_; const PairIntT<T>* local_it_; - typename std::map<unsigned, T>::const_iterator remote_it_; + typename SPARSE_HASH_MAP<unsigned, T>::const_iterator remote_it_; const std::pair<const unsigned, T>& operator*() const { if (local_) - return *reinterpret_cast<const std::pair<const unsigned, float>*>(local_it_); + return *reinterpret_cast<const std::pair<const unsigned, T>*>(local_it_); else return *remote_it_; } @@ -118,7 +181,7 @@ class FastSparseVector { } FastSparseVector(const FastSparseVector& other) { std::memcpy(this, &other, sizeof(FastSparseVector)); - if (is_remote_) data_.rbmap = new std::map<unsigned, T>(*data_.rbmap); + if (is_remote_) data_.rbmap = new SPARSE_HASH_MAP<unsigned, T>(*data_.rbmap); } FastSparseVector(std::pair<unsigned, T>* first, std::pair<unsigned, T>* last) { const ptrdiff_t n = last - first; @@ -128,7 +191,7 @@ class FastSparseVector { std::memcpy(data_.local, first, sizeof(std::pair<unsigned, T>) * n); } else { is_remote_ = true; - data_.rbmap = new std::map<unsigned, T>(first, last); + data_.rbmap = new SPARSE_HASH_MAP<unsigned, T>(first, last); } } void erase(int k) { @@ -150,7 +213,7 @@ class FastSparseVector { clear(); std::memcpy(this, &other, sizeof(FastSparseVector)); if (is_remote_) - data_.rbmap = new std::map<unsigned, T>(*data_.rbmap); + data_.rbmap = new SPARSE_HASH_MAP<unsigned, T>(*data_.rbmap); return *this; } T const& get_singleton() const { @@ -160,6 +223,9 @@ class FastSparseVector { bool nonzero(unsigned k) const { return static_cast<bool>(value(k)); } + inline T& operator[](unsigned k) { + return get_or_create_bin(k); + } inline void set_value(unsigned k, const T& v) { get_or_create_bin(k) = v; } @@ -171,7 +237,7 @@ class FastSparseVector { } inline T value(unsigned k) const { if (is_remote_) { - typename std::map<unsigned, T>::const_iterator it = data_.rbmap->find(k); + typename SPARSE_HASH_MAP<unsigned, T>::const_iterator it = data_.rbmap->find(k); if (it != data_.rbmap->end()) return it->second; } else { for (unsigned i = 0; i < local_size_; ++i) { @@ -256,8 +322,8 @@ class FastSparseVector { } inline FastSparseVector& operator*=(const T& scalar) { if (is_remote_) { - const typename std::map<unsigned, T>::iterator end = data_.rbmap->end(); - for (typename std::map<unsigned, T>::iterator it = data_.rbmap->begin(); it != end; ++it) + const typename SPARSE_HASH_MAP<unsigned, T>::iterator end = data_.rbmap->end(); + for (typename SPARSE_HASH_MAP<unsigned, T>::iterator it = data_.rbmap->begin(); it != end; ++it) it->second *= scalar; } else { for (int i = 0; i < local_size_; ++i) @@ -267,8 +333,8 @@ class FastSparseVector { } inline FastSparseVector& operator/=(const T& scalar) { if (is_remote_) { - const typename std::map<unsigned, T>::iterator end = data_.rbmap->end(); - for (typename std::map<unsigned, T>::iterator it = data_.rbmap->begin(); it != end; ++it) + const typename SPARSE_HASH_MAP<unsigned, T>::iterator end = data_.rbmap->end(); + for (typename SPARSE_HASH_MAP<unsigned, T>::iterator it = data_.rbmap->begin(); it != end; ++it) it->second /= scalar; } else { for (int i = 0; i < local_size_; ++i) @@ -283,6 +349,18 @@ class FastSparseVector { } return o; } + iterator find(unsigned k) { + return iterator(*this, false, k); + } + iterator begin() { + return iterator(*this, false); + } + iterator end() { + return iterator(*this, true); + } + const_iterator find(unsigned k) const { + return const_iterator(*this, false, k); + } const_iterator begin() const { return const_iterator(*this, false); } @@ -353,17 +431,19 @@ class FastSparseVector { void swap_local_rbmap() { if (is_remote_) { // data is in rbmap, move to local assert(data_.rbmap->size() < LOCAL_MAX); - const std::map<unsigned, T>* m = data_.rbmap; + const SPARSE_HASH_MAP<unsigned, T>* m = data_.rbmap; local_size_ = m->size(); int i = 0; - for (typename std::map<unsigned, T>::const_iterator it = m->begin(); + for (typename SPARSE_HASH_MAP<unsigned, T>::const_iterator it = m->begin(); it != m->end(); ++it) { data_.local[i] = *it; ++i; } is_remote_ = false; } else { // data is local, move to rbmap - std::map<unsigned, T>* m = new std::map<unsigned, T>(&data_.local[0], &data_.local[local_size_]); + SPARSE_HASH_MAP<unsigned, T>* m = new SPARSE_HASH_MAP<unsigned, T>( + reinterpret_cast<std::pair<unsigned, T>*>(&data_.local[0]), + reinterpret_cast<std::pair<unsigned, T>*>(&data_.local[local_size_]), local_size_ * 1.5 + 1); data_.rbmap = m; is_remote_ = true; } @@ -371,7 +451,7 @@ class FastSparseVector { union { PairIntT<T> local[LOCAL_MAX]; - std::map<unsigned, T>* rbmap; + SPARSE_HASH_MAP<unsigned, T>* rbmap; } data_; unsigned char local_size_; bool is_remote_; |