diff options
Diffstat (limited to 'utils/sparse_vector.h')
-rw-r--r-- | utils/sparse_vector.h | 136 |
1 files changed, 119 insertions, 17 deletions
diff --git a/utils/sparse_vector.h b/utils/sparse_vector.h index e8e9c2f7..7ac85d1d 100644 --- a/utils/sparse_vector.h +++ b/utils/sparse_vector.h @@ -1,16 +1,13 @@ #ifndef _SPARSE_VECTOR_H_ #define _SPARSE_VECTOR_H_ +/* +TODO: specialize for int value types, where it probably makes sense to check if adding/subtracting brings a value to 0, and remove it from the map (e.g. in a gibbs sampler). or add a separate policy argument for that. + */ + //#define SPARSE_VECTOR_HASH +// if defined, use hash_map rather than map. map is probably faster/smaller for small vectors -#ifdef SPARSE_VECTOR_HASH -#include "hash.h" -# define SPARSE_VECTOR_MAP HASH_MAP -# define SPARSE_VECTOR_MAP_RESERVED(h,empty,deleted) HASH_MAP_RESERVED(h,empty,deleted) -#else -# define SPARSE_VECTOR_MAP std::map -# define SPARSE_VECTOR_MAP_RESERVED(h,empty,deleted) -#endif /* use SparseVectorList (pair smallvector) for feat funcs / hypergraphs (you rarely need random access; just append a feature to the list) */ @@ -38,6 +35,17 @@ // this is a modified version of code originally written // by Phil Blunsom +#include <boost/functional/hash.hpp> +#include <stdexcept> +#ifdef SPARSE_VECTOR_HASH +#include "hash.h" +# define SPARSE_VECTOR_MAP HASH_MAP +# define SPARSE_VECTOR_MAP_RESERVED(h,empty,deleted) HASH_MAP_RESERVED(h,empty,deleted) +#else +# define SPARSE_VECTOR_MAP std::map +# define SPARSE_VECTOR_MAP_RESERVED(h,empty,deleted) +#endif + #include <iostream> #include <map> #include <tr1/unordered_map> @@ -46,6 +54,7 @@ #include "fdict.h" #include "small_vector.h" +#include "string_to.h" template <class T> inline T & extend_vector(std::vector<T> &v,int i) { @@ -54,7 +63,7 @@ inline T & extend_vector(std::vector<T> &v,int i) { return v[i]; } -template <typename T> +template <class T> class SparseVector { void init_reserved() { SPARSE_VECTOR_MAP_RESERVED(values_,-1,-2); @@ -71,17 +80,97 @@ public: SparseVector() { init_reserved(); } + typedef typename MapType::value_type value_type; + typedef typename MapType::iterator iterator; explicit SparseVector(std::vector<T> const& v) { init_reserved(); - typename MapType::iterator p=values_.begin(); + iterator p=values_.begin(); const T z=0; for (unsigned i=0;i<v.size();++i) { T const& t=v[i]; if (t!=z) - p=values_.insert(p,typename MapType::value_type(i,t)); //hint makes insertion faster + p=values_.insert(p,value_type(i,t)); //hint makes insertion faster + } + } + + typedef char const* Str; + template <class O> + void print(O &o,Str pre="",Str post="",Str kvsep="=",Str pairsep=" ") const { + o << pre; + bool first=true; + for (const_iterator i=values_.begin(),e=values_.end();i!=e;++i) { + if (first) + first=false; + else + o<<pairsep; + o<<FD::Convert(i->first)<<kvsep<<i->second; } + o << post; } + static void error(std::string const& msg) { + throw std::runtime_error("SparseVector: "+msg); + } + + enum DupPolicy { + NO_DUPS, + KEEP_FIRST, + KEEP_LAST, + SUM + }; + + // either key val alternating whitespace sep, or key=val (kvsep char is '='). end at eof or terminator (non-ws) char + template <class S> + void read(S &s,DupPolicy dp=NO_DUPS,bool use_kvsep=true,char kvsep='=',bool stop_at_terminator=false,char terminator=')') { + values_.clear(); + std::string id; + WordID k; + T v; +#undef SPARSE_MUST_READ +#define SPARSE_MUST_READ(x) if (!(x)) error(#x); + int ki; + while (s) { + if (stop_at_terminator) { + char c; + if (!(s>>c)) goto eof; + s.unget(); + if (c==terminator) return; + } + if (!(s>>id)) goto eof; + if (use_kvsep && (ki=id.find(kvsep))!=std::string::npos) { + k=FD::Convert(std::string(id,0,ki)); + string_into(id.c_str()+ki+1,v); + } else { + k=FD::Convert(id); + if (!(s>>v)) error("reading value failed"); + } + std::pair<iterator,bool> vi=values_.insert(value_type(k,v)); + if (vi.second) { + T &oldv=vi.first->second; + switch(dp) { + case NO_DUPS: error("read duplicate key with NO_DUPS. key=" + +FD::Convert(k)+" val="+to_string(v)+" old-val="+to_string(oldv)); + break; + case KEEP_FIRST: break; + case KEEP_LAST: oldv=v; break; + case SUM: oldv+=v; break; + } + } + } + return; + eof: + if (!s.eof()) error("reading key failed (before EOF)"); + } + + friend inline std::ostream & operator<<(std::ostream &o,Self const& s) { + s.print(o); + return o; + } + + friend inline std::istream & operator>>(std::istream &o,Self & s) { + s.read(o); + return o; + } void init_vector(std::vector<T> *vp) const { init_vector(*vp); @@ -118,6 +207,10 @@ public: return values_[index]; } + inline void maybe_set_value(int index, const T &value) { + if (value) values_[index] = value; + } + inline void set_value(int index, const T &value) { values_[index] = value; } @@ -352,6 +445,10 @@ public: return size()==other.size() && contains_keys_of(other) && other.contains_i(*this); } + std::size_t hash_impl() const { + return boost::hash_range(begin(),end()); + } + bool contains(Self const &o) const { return size()>o.size() && contains(o); } @@ -371,7 +468,7 @@ public: bool contains_keys_of(Self const& o) const { for (typename MapType::const_iterator i=o.begin(),e=o.end();i!=e;++i) - if (values_.find(i)==values_.end()) + if (values_.find(i->first)==values_.end()) return false; return true; } @@ -478,31 +575,36 @@ private: List p; }; -template <typename T> +template <class T> +std::size_t hash_value(SparseVector<T> const& x) { + return x.hash_impl(); +} + +template <class T> SparseVector<T> operator+(const SparseVector<T>& a, const SparseVector<T>& b) { SparseVector<T> result = a; return result += b; } -template <typename T> +template <class T> SparseVector<T> operator*(const SparseVector<T>& a, const double& b) { SparseVector<T> result = a; return result *= b; } -template <typename T> +template <class T> SparseVector<T> operator*(const SparseVector<T>& a, const T& b) { SparseVector<T> result = a; return result *= b; } -template <typename T> +template <class T> SparseVector<T> operator*(const double& a, const SparseVector<T>& b) { SparseVector<T> result = b; return result *= a; } -template <typename T> +template <class T> std::ostream &operator<<(std::ostream &out, const SparseVector<T> &vec) { return vec.operator<<(out); |