diff options
Diffstat (limited to 'utils/sparse_vector.h')
| -rw-r--r-- | utils/sparse_vector.h | 136 | 
1 files changed, 119 insertions, 17 deletions
| diff --git a/utils/sparse_vector.h b/utils/sparse_vector.h index e8e9c2f7..7ac85d1d 100644 --- a/utils/sparse_vector.h +++ b/utils/sparse_vector.h @@ -1,16 +1,13 @@  #ifndef _SPARSE_VECTOR_H_  #define _SPARSE_VECTOR_H_ +/* +TODO: specialize for int value types, where it probably makes sense to check if adding/subtracting brings a value to 0, and remove it from the map (e.g. in a gibbs sampler).  or add a separate policy argument for that. + */ +  //#define SPARSE_VECTOR_HASH +// if defined, use hash_map rather than map.  map is probably faster/smaller for small vectors -#ifdef SPARSE_VECTOR_HASH -#include "hash.h" -# define SPARSE_VECTOR_MAP HASH_MAP -# define SPARSE_VECTOR_MAP_RESERVED(h,empty,deleted) HASH_MAP_RESERVED(h,empty,deleted) -#else -# define SPARSE_VECTOR_MAP std::map -# define SPARSE_VECTOR_MAP_RESERVED(h,empty,deleted) -#endif  /*     use SparseVectorList (pair smallvector) for feat funcs / hypergraphs (you rarely need random access; just append a feature to the list)  */ @@ -38,6 +35,17 @@  // this is a modified version of code originally written  // by Phil Blunsom +#include <boost/functional/hash.hpp> +#include <stdexcept> +#ifdef SPARSE_VECTOR_HASH +#include "hash.h" +# define SPARSE_VECTOR_MAP HASH_MAP +# define SPARSE_VECTOR_MAP_RESERVED(h,empty,deleted) HASH_MAP_RESERVED(h,empty,deleted) +#else +# define SPARSE_VECTOR_MAP std::map +# define SPARSE_VECTOR_MAP_RESERVED(h,empty,deleted) +#endif +  #include <iostream>  #include <map>  #include <tr1/unordered_map> @@ -46,6 +54,7 @@  #include "fdict.h"  #include "small_vector.h" +#include "string_to.h"  template <class T>  inline T & extend_vector(std::vector<T> &v,int i) { @@ -54,7 +63,7 @@ inline T & extend_vector(std::vector<T> &v,int i) {    return v[i];  } -template <typename T> +template <class T>  class SparseVector {    void init_reserved() {      SPARSE_VECTOR_MAP_RESERVED(values_,-1,-2); @@ -71,17 +80,97 @@ public:    SparseVector() {      init_reserved();    } +  typedef typename MapType::value_type value_type; +  typedef typename MapType::iterator iterator;    explicit SparseVector(std::vector<T> const& v) {      init_reserved(); -    typename MapType::iterator p=values_.begin(); +    iterator p=values_.begin();      const T z=0;      for (unsigned i=0;i<v.size();++i) {        T const& t=v[i];        if (t!=z) -        p=values_.insert(p,typename MapType::value_type(i,t)); //hint makes insertion faster +        p=values_.insert(p,value_type(i,t)); //hint makes insertion faster +    } +  } + +  typedef char const* Str; +  template <class O> +  void print(O &o,Str pre="",Str post="",Str kvsep="=",Str pairsep=" ") const { +    o << pre; +    bool first=true; +    for (const_iterator i=values_.begin(),e=values_.end();i!=e;++i) { +      if (first) +        first=false; +      else +        o<<pairsep; +      o<<FD::Convert(i->first)<<kvsep<<i->second;      } +    o << post;    } +  static void error(std::string const& msg) { +    throw std::runtime_error("SparseVector: "+msg); +  } + +  enum DupPolicy { +    NO_DUPS, +    KEEP_FIRST, +    KEEP_LAST, +    SUM +  }; + +  // either key val alternating whitespace sep, or key=val (kvsep char is '=').  end at eof or terminator (non-ws) char +  template <class S> +  void read(S &s,DupPolicy dp=NO_DUPS,bool use_kvsep=true,char kvsep='=',bool stop_at_terminator=false,char terminator=')') { +    values_.clear(); +    std::string id; +    WordID k; +    T v; +#undef SPARSE_MUST_READ +#define SPARSE_MUST_READ(x) if (!(x)) error(#x); +    int ki; +    while (s) { +      if (stop_at_terminator) { +        char c; +        if (!(s>>c)) goto eof; +        s.unget(); +        if (c==terminator) return; +      } +      if (!(s>>id)) goto eof; +      if (use_kvsep && (ki=id.find(kvsep))!=std::string::npos) { +        k=FD::Convert(std::string(id,0,ki)); +        string_into(id.c_str()+ki+1,v); +      } else { +        k=FD::Convert(id); +        if (!(s>>v)) error("reading value failed"); +      } +      std::pair<iterator,bool> vi=values_.insert(value_type(k,v)); +      if (vi.second) { +        T &oldv=vi.first->second; +        switch(dp) { +        case NO_DUPS: error("read duplicate key with NO_DUPS.  key=" +                            +FD::Convert(k)+" val="+to_string(v)+" old-val="+to_string(oldv)); +          break; +        case KEEP_FIRST: break; +        case KEEP_LAST: oldv=v; break; +        case SUM: oldv+=v; break; +        } +      } +    } +    return; +  eof: +    if (!s.eof()) error("reading key failed (before EOF)"); +  } + +  friend inline std::ostream & operator<<(std::ostream &o,Self const& s) { +    s.print(o); +    return o; +  } + +  friend inline std::istream & operator>>(std::istream &o,Self & s) { +    s.read(o); +    return o; +  }    void init_vector(std::vector<T> *vp) const {      init_vector(*vp); @@ -118,6 +207,10 @@ public:      return values_[index];    } +  inline void maybe_set_value(int index, const T &value) { +    if (value) values_[index] = value; +  } +    inline void set_value(int index, const T &value) {      values_[index] = value;    } @@ -352,6 +445,10 @@ public:      return size()==other.size() && contains_keys_of(other) && other.contains_i(*this);    } +  std::size_t hash_impl() const { +    return boost::hash_range(begin(),end()); +  } +    bool contains(Self const &o) const {      return size()>o.size() && contains(o);    } @@ -371,7 +468,7 @@ public:    bool contains_keys_of(Self const& o) const {      for (typename MapType::const_iterator i=o.begin(),e=o.end();i!=e;++i) -      if (values_.find(i)==values_.end()) +      if (values_.find(i->first)==values_.end())          return false;      return true;    } @@ -478,31 +575,36 @@ private:    List p;  }; -template <typename T> +template <class T> +std::size_t hash_value(SparseVector<T> const& x) { +  return x.hash_impl(); +} + +template <class T>  SparseVector<T> operator+(const SparseVector<T>& a, const SparseVector<T>& b) {    SparseVector<T> result = a;    return result += b;  } -template <typename T> +template <class T>  SparseVector<T> operator*(const SparseVector<T>& a, const double& b) {    SparseVector<T> result = a;    return result *= b;  } -template <typename T> +template <class T>  SparseVector<T> operator*(const SparseVector<T>& a, const T& b) {    SparseVector<T> result = a;    return result *= b;  } -template <typename T> +template <class T>  SparseVector<T> operator*(const double& a, const SparseVector<T>& b) {    SparseVector<T> result = b;    return result *= a;  } -template <typename T> +template <class T>  std::ostream &operator<<(std::ostream &out, const SparseVector<T> &vec)  {      return vec.operator<<(out); | 
