summaryrefslogtreecommitdiff
path: root/utils/fast_sparse_vector.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils/fast_sparse_vector.h')
-rw-r--r--utils/fast_sparse_vector.h108
1 files changed, 94 insertions, 14 deletions
diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h
index e86cbdc1..433a5cc5 100644
--- a/utils/fast_sparse_vector.h
+++ b/utils/fast_sparse_vector.h
@@ -66,6 +66,60 @@ BOOST_STATIC_ASSERT(sizeof(PairIntT<float>) == sizeof(std::pair<unsigned,float>)
template <typename T, unsigned LOCAL_MAX = (sizeof(T) == sizeof(float) ? 15u : 7u)>
class FastSparseVector {
public:
+ struct iterator {
+ iterator(FastSparseVector<T>& v, const bool is_end) : local_(!v.is_remote_) {
+ if (local_) {
+ local_it_ = &v.data_.local[is_end ? v.local_size_ : 0];
+ } else {
+ if (is_end)
+ remote_it_ = v.data_.rbmap->end();
+ else
+ remote_it_ = v.data_.rbmap->begin();
+ }
+ }
+ iterator(FastSparseVector<T>& v, const bool, const unsigned k) : local_(!v.is_remote_) {
+ if (local_) {
+ unsigned i = 0;
+ while(i < v.local_size_ && v.data_.local[i].first() != k) { ++i; }
+ local_it_ = &v.data_.local[i];
+ } else {
+ remote_it_ = v.data_.rbmap->find(k);
+ }
+ }
+ const bool local_;
+ PairIntT<T>* local_it_;
+ typename SPARSE_HASH_MAP<unsigned, T>::iterator remote_it_;
+ std::pair<const unsigned, T>& operator*() const {
+ if (local_)
+ return *reinterpret_cast<std::pair<const unsigned, T>*>(local_it_);
+ else
+ return *remote_it_;
+ }
+
+ std::pair<const unsigned, T>* operator->() const {
+ if (local_)
+ return reinterpret_cast<std::pair<const unsigned, T>*>(local_it_);
+ else
+ return &*remote_it_;
+ }
+
+ iterator& operator++() {
+ if (local_) ++local_it_; else ++remote_it_;
+ return *this;
+ }
+
+ inline bool operator==(const iterator& o) const {
+ if (o.local_ != local_) return false;
+ if (local_) {
+ return local_it_ == o.local_it_;
+ } else {
+ return remote_it_ == o.remote_it_;
+ }
+ }
+ inline bool operator!=(const iterator& o) const {
+ return !(o == *this);
+ }
+ };
struct const_iterator {
const_iterator(const FastSparseVector<T>& v, const bool is_end) : local_(!v.is_remote_) {
if (local_) {
@@ -77,12 +131,21 @@ class FastSparseVector {
remote_it_ = v.data_.rbmap->begin();
}
}
+ const_iterator(const FastSparseVector<T>& v, const bool, const unsigned k) : local_(!v.is_remote_) {
+ if (local_) {
+ unsigned i = 0;
+ while(i < v.local_size_ && v.data_.local[i].first() != k) { ++i; }
+ local_it_ = &v.data_.local[i];
+ } else {
+ remote_it_ = v.data_.rbmap->find(k);
+ }
+ }
const bool local_;
const PairIntT<T>* local_it_;
- typename std::map<unsigned, T>::const_iterator remote_it_;
+ typename SPARSE_HASH_MAP<unsigned, T>::const_iterator remote_it_;
const std::pair<const unsigned, T>& operator*() const {
if (local_)
- return *reinterpret_cast<const std::pair<const unsigned, float>*>(local_it_);
+ return *reinterpret_cast<const std::pair<const unsigned, T>*>(local_it_);
else
return *remote_it_;
}
@@ -118,7 +181,7 @@ class FastSparseVector {
}
FastSparseVector(const FastSparseVector& other) {
std::memcpy(this, &other, sizeof(FastSparseVector));
- if (is_remote_) data_.rbmap = new std::map<unsigned, T>(*data_.rbmap);
+ if (is_remote_) data_.rbmap = new SPARSE_HASH_MAP<unsigned, T>(*data_.rbmap);
}
FastSparseVector(std::pair<unsigned, T>* first, std::pair<unsigned, T>* last) {
const ptrdiff_t n = last - first;
@@ -128,7 +191,7 @@ class FastSparseVector {
std::memcpy(data_.local, first, sizeof(std::pair<unsigned, T>) * n);
} else {
is_remote_ = true;
- data_.rbmap = new std::map<unsigned, T>(first, last);
+ data_.rbmap = new SPARSE_HASH_MAP<unsigned, T>(first, last);
}
}
void erase(int k) {
@@ -150,7 +213,7 @@ class FastSparseVector {
clear();
std::memcpy(this, &other, sizeof(FastSparseVector));
if (is_remote_)
- data_.rbmap = new std::map<unsigned, T>(*data_.rbmap);
+ data_.rbmap = new SPARSE_HASH_MAP<unsigned, T>(*data_.rbmap);
return *this;
}
T const& get_singleton() const {
@@ -160,6 +223,9 @@ class FastSparseVector {
bool nonzero(unsigned k) const {
return static_cast<bool>(value(k));
}
+ inline T& operator[](unsigned k) {
+ return get_or_create_bin(k);
+ }
inline void set_value(unsigned k, const T& v) {
get_or_create_bin(k) = v;
}
@@ -171,7 +237,7 @@ class FastSparseVector {
}
inline T value(unsigned k) const {
if (is_remote_) {
- typename std::map<unsigned, T>::const_iterator it = data_.rbmap->find(k);
+ typename SPARSE_HASH_MAP<unsigned, T>::const_iterator it = data_.rbmap->find(k);
if (it != data_.rbmap->end()) return it->second;
} else {
for (unsigned i = 0; i < local_size_; ++i) {
@@ -256,8 +322,8 @@ class FastSparseVector {
}
inline FastSparseVector& operator*=(const T& scalar) {
if (is_remote_) {
- const typename std::map<unsigned, T>::iterator end = data_.rbmap->end();
- for (typename std::map<unsigned, T>::iterator it = data_.rbmap->begin(); it != end; ++it)
+ const typename SPARSE_HASH_MAP<unsigned, T>::iterator end = data_.rbmap->end();
+ for (typename SPARSE_HASH_MAP<unsigned, T>::iterator it = data_.rbmap->begin(); it != end; ++it)
it->second *= scalar;
} else {
for (int i = 0; i < local_size_; ++i)
@@ -267,8 +333,8 @@ class FastSparseVector {
}
inline FastSparseVector& operator/=(const T& scalar) {
if (is_remote_) {
- const typename std::map<unsigned, T>::iterator end = data_.rbmap->end();
- for (typename std::map<unsigned, T>::iterator it = data_.rbmap->begin(); it != end; ++it)
+ const typename SPARSE_HASH_MAP<unsigned, T>::iterator end = data_.rbmap->end();
+ for (typename SPARSE_HASH_MAP<unsigned, T>::iterator it = data_.rbmap->begin(); it != end; ++it)
it->second /= scalar;
} else {
for (int i = 0; i < local_size_; ++i)
@@ -283,6 +349,18 @@ class FastSparseVector {
}
return o;
}
+ iterator find(unsigned k) {
+ return iterator(*this, false, k);
+ }
+ iterator begin() {
+ return iterator(*this, false);
+ }
+ iterator end() {
+ return iterator(*this, true);
+ }
+ const_iterator find(unsigned k) const {
+ return const_iterator(*this, false, k);
+ }
const_iterator begin() const {
return const_iterator(*this, false);
}
@@ -353,17 +431,19 @@ class FastSparseVector {
void swap_local_rbmap() {
if (is_remote_) { // data is in rbmap, move to local
assert(data_.rbmap->size() < LOCAL_MAX);
- const std::map<unsigned, T>* m = data_.rbmap;
+ const SPARSE_HASH_MAP<unsigned, T>* m = data_.rbmap;
local_size_ = m->size();
int i = 0;
- for (typename std::map<unsigned, T>::const_iterator it = m->begin();
+ for (typename SPARSE_HASH_MAP<unsigned, T>::const_iterator it = m->begin();
it != m->end(); ++it) {
data_.local[i] = *it;
++i;
}
is_remote_ = false;
} else { // data is local, move to rbmap
- std::map<unsigned, T>* m = new std::map<unsigned, T>(&data_.local[0], &data_.local[local_size_]);
+ SPARSE_HASH_MAP<unsigned, T>* m = new SPARSE_HASH_MAP<unsigned, T>(
+ reinterpret_cast<std::pair<unsigned, T>*>(&data_.local[0]),
+ reinterpret_cast<std::pair<unsigned, T>*>(&data_.local[local_size_]), local_size_ * 1.5 + 1);
data_.rbmap = m;
is_remote_ = true;
}
@@ -371,7 +451,7 @@ class FastSparseVector {
union {
PairIntT<T> local[LOCAL_MAX];
- std::map<unsigned, T>* rbmap;
+ SPARSE_HASH_MAP<unsigned, T>* rbmap;
} data_;
unsigned char local_size_;
bool is_remote_;