diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/fast_sparse_vector.h | 64 |
1 files changed, 55 insertions, 9 deletions
diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h index 4aae2039..1301581a 100644 --- a/utils/fast_sparse_vector.h +++ b/utils/fast_sparse_vector.h @@ -7,6 +7,8 @@ // important: indexes are integers // important: iterators may return elements in any order +#include "config.h" + #include <cmath> #include <cstring> #include <climits> @@ -16,6 +18,12 @@ #include <boost/static_assert.hpp> +#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP +#include <boost/serialization/map.hpp> +#endif + +#include "fdict.h" + // this is architecture dependent, it should be // detected in some way but it's probably easiest (for me) // to just set it @@ -235,6 +243,13 @@ class FastSparseVector { } return *this; } + FastSparseVector<T> erase_zeros(const T& EPSILON = 1e-4) const { + FastSparseVector<T> o; + for (const_iterator it = begin(); it != end(); ++it) { + if (fabs(it->second) > EPSILON) o.set_value(it->first, it->second); + } + return o; + } const_iterator begin() const { return const_iterator(*this, false); } @@ -327,8 +342,45 @@ class FastSparseVector { } data_; unsigned char local_size_; bool is_remote_; + +#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP + private: + friend class boost::serialization::access; + template<class Archive> + void save(Archive & ar, const unsigned int version) const { + (void) version; + int eff_size = size(); + const_iterator it = this->begin(); + if (eff_size > 0) { + // 0 index is reserved as empty + if (it->first == 0) { ++it; --eff_size; } + } + ar & eff_size; + while (it != this->end()) { + const std::pair<const std::string&, const T&> wire_pair(FD::Convert(it->first), it->second); + ar & wire_pair; + ++it; + } + } + template<class Archive> + void load(Archive & ar, const unsigned int version) { + (void) version; + this->clear(); + int sz; ar & sz; + for (int i = 0; i < sz; ++i) { + std::pair<std::string, T> wire_pair; + ar & wire_pair; + this->set_value(FD::Convert(wire_pair.first), wire_pair.second); + } + } + BOOST_SERIALIZATION_SPLIT_MEMBER() +#endif }; +#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP +BOOST_CLASS_TRACKING(FastSparseVector<double>,track_never) +#endif + template <typename T> const FastSparseVector<T> operator+(const FastSparseVector<T>& x, const FastSparseVector<T>& y) { if (x.size() > y.size()) { @@ -344,15 +396,9 @@ const FastSparseVector<T> operator+(const FastSparseVector<T>& x, const FastSpar template <typename T> const FastSparseVector<T> operator-(const FastSparseVector<T>& x, const FastSparseVector<T>& y) { - if (x.size() > y.size()) { - FastSparseVector<T> res(x); - res -= y; - return res; - } else { - FastSparseVector<T> res(y); - res -= x; - return res; - } + FastSparseVector<T> res(x); + res -= y; + return res; } template <class T> |