summaryrefslogtreecommitdiff
path: root/utils/fast_sparse_vector.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils/fast_sparse_vector.h')
-rw-r--r--utils/fast_sparse_vector.h64
1 files changed, 55 insertions, 9 deletions
diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h
index 4aae2039..1301581a 100644
--- a/utils/fast_sparse_vector.h
+++ b/utils/fast_sparse_vector.h
@@ -7,6 +7,8 @@
// important: indexes are integers
// important: iterators may return elements in any order
+#include "config.h"
+
#include <cmath>
#include <cstring>
#include <climits>
@@ -16,6 +18,12 @@
#include <boost/static_assert.hpp>
+#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP
+#include <boost/serialization/map.hpp>
+#endif
+
+#include "fdict.h"
+
// this is architecture dependent, it should be
// detected in some way but it's probably easiest (for me)
// to just set it
@@ -235,6 +243,13 @@ class FastSparseVector {
}
return *this;
}
+ FastSparseVector<T> erase_zeros(const T& EPSILON = 1e-4) const {
+ FastSparseVector<T> o;
+ for (const_iterator it = begin(); it != end(); ++it) {
+ if (fabs(it->second) > EPSILON) o.set_value(it->first, it->second);
+ }
+ return o;
+ }
const_iterator begin() const {
return const_iterator(*this, false);
}
@@ -327,8 +342,45 @@ class FastSparseVector {
} data_;
unsigned char local_size_;
bool is_remote_;
+
+#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP
+ private:
+ friend class boost::serialization::access;
+ template<class Archive>
+ void save(Archive & ar, const unsigned int version) const {
+ (void) version;
+ int eff_size = size();
+ const_iterator it = this->begin();
+ if (eff_size > 0) {
+ // 0 index is reserved as empty
+ if (it->first == 0) { ++it; --eff_size; }
+ }
+ ar & eff_size;
+ while (it != this->end()) {
+ const std::pair<const std::string&, const T&> wire_pair(FD::Convert(it->first), it->second);
+ ar & wire_pair;
+ ++it;
+ }
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int version) {
+ (void) version;
+ this->clear();
+ int sz; ar & sz;
+ for (int i = 0; i < sz; ++i) {
+ std::pair<std::string, T> wire_pair;
+ ar & wire_pair;
+ this->set_value(FD::Convert(wire_pair.first), wire_pair.second);
+ }
+ }
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
+#endif
};
+#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP
+BOOST_CLASS_TRACKING(FastSparseVector<double>,track_never)
+#endif
+
template <typename T>
const FastSparseVector<T> operator+(const FastSparseVector<T>& x, const FastSparseVector<T>& y) {
if (x.size() > y.size()) {
@@ -344,15 +396,9 @@ const FastSparseVector<T> operator+(const FastSparseVector<T>& x, const FastSpar
template <typename T>
const FastSparseVector<T> operator-(const FastSparseVector<T>& x, const FastSparseVector<T>& y) {
- if (x.size() > y.size()) {
- FastSparseVector<T> res(x);
- res -= y;
- return res;
- } else {
- FastSparseVector<T> res(y);
- res -= x;
- return res;
- }
+ FastSparseVector<T> res(x);
+ res -= y;
+ return res;
}
template <class T>