diff options
Diffstat (limited to 'utils/logval.h')
-rw-r--r-- | utils/logval.h | 87 |
1 files changed, 68 insertions, 19 deletions
diff --git a/utils/logval.h b/utils/logval.h index b337cf0e..868146de 100644 --- a/utils/logval.h +++ b/utils/logval.h @@ -8,20 +8,32 @@ #include <cstdlib> #include <cmath> #include <limits> +#include "semiring.h" +//TODO: template for supporting negation or not - most uses are for nonnegative "probs" only; probably some 10-20% speedup available template <class T> class LogVal { public: - LogVal() : s_(), v_(-std::numeric_limits<T>::infinity()) {} + typedef LogVal<T> Self; + + LogVal() : s_(), v_(LOGVAL_LOG0) {} explicit LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {} + LogVal(init_minus_1) : s_(true),v_(0) { } + LogVal(init_1) : s_(),v_(0) { } + LogVal(init_0) : s_(),v_(LOGVAL_LOG0) { } LogVal(int x) : s_(x<0), v_(s_ ? std::log(-x) : std::log(x)) {} LogVal(unsigned x) : s_(0), v_(std::log(x)) { } LogVal(double lnx,bool sign) : s_(sign),v_(lnx) {} - static LogVal<T> exp(T lnx) { return LogVal(lnx,false); } + LogVal(double lnx,init_lnx) : s_(),v_(lnx) {} + static Self exp(T lnx) { return Self(lnx,false); } + + // maybe the below are faster than == 1 and == 0. i don't know. + bool is_1() const { return v_==0&&s_==0; } + bool is_0() const { return v_==LOGVAL_LOG0; } - static LogVal<T> One() { return LogVal(1); } - static LogVal<T> Zero() { return LogVal(); } - static LogVal<T> e() { return LogVal(1,false); } + static Self One() { return Self(1); } + static Self Zero() { return Self(); } + static Self e() { return Self(1,false); } void logeq(const T& v) { s_ = false; v_ = v; } std::size_t hash_impl() const { @@ -29,8 +41,21 @@ class LogVal { return hash_value(v_)+s_; } - LogVal& operator+=(const LogVal& a) { - if (a.v_ == -std::numeric_limits<T>::infinity()) return *this; + // just like std::signbit, negative means true. weird, i know + bool signbit() const { + return s_; + } + friend inline bool signbit(Self const& x) { return x.signbit(); } + + Self& besteq(const Self& a) { + assert(!a.s_ && !s_); + if (a.v_ < v_) + v_=a.v_; + return *this; + } + + Self& operator+=(const Self& a) { + if (a.is_0()) return *this; if (a.s_ == s_) { if (a.v_ < v_) { v_ = v_ + log1p(std::exp(a.v_ - v_)); @@ -48,31 +73,31 @@ class LogVal { return *this; } - LogVal& operator*=(const LogVal& a) { + Self& operator*=(const Self& a) { s_ = (s_ != a.s_); v_ += a.v_; return *this; } - LogVal& operator/=(const LogVal& a) { + Self& operator/=(const Self& a) { s_ = (s_ != a.s_); v_ -= a.v_; return *this; } - LogVal& operator-=(const LogVal& a) { - LogVal b = a; - b.invert(); + Self& operator-=(const Self& a) { + Self b = a; + b.negate(); return *this += b; } - // LogVal(fabs(log(x)),x.s_) - friend LogVal abslog(LogVal x) { + // Self(fabs(log(x)),x.s_) + friend Self abslog(Self x) { if (x.v_<0) x.v_=-x.v_; return x; } - LogVal& poweq(const T& power) { + Self& poweq(const T& power) { #if LOGVAL_CHECK_NEG if (s_) { std::cerr << "poweq(T) not implemented when s_ is true\n"; @@ -83,26 +108,50 @@ class LogVal { return *this; } - void invert() { s_ = !s_; } + //remember, s_ means negative. + inline bool lt(Self const& o) const { + return s_ ? (!o.s_ || o.v_<v_) : (o.s_ || v_<o.v_); + } + inline bool gt(Self const& o) const { + return s_ ? (o.s_ && v_<o.v_) : (!o.s_ && o.v_<v_); + } - LogVal pow(const T& power) const { - LogVal res = *this; + Self operator-() const { + return Self(v_,-s_); + } + void negate() { s_ = !s_; } + + Self inverse() const { return Self(-v_,s_); } + + Self pow(const T& power) const { + Self res = *this; res.poweq(power); return res; } - LogVal root(const T& root) const { + Self root(const T& root) const { return pow(1/root); } operator T() const { if (s_) return -std::exp(v_); else return std::exp(v_); } + T as_float() const { + if (s_) return -std::exp(v_); else return std::exp(v_); + } bool s_; T v_; }; +template <class T> +struct semiring_traits<LogVal<T> > : default_semiring_traits<LogVal<T> > { + static const bool has_logplus=true; + static const bool has_besteq=true; + static const bool has_subtract=true; + static const bool has_negative=true; +}; + // copy elision - as opposed to explicit copy of LogVal<T> const& o1, we should be able to construct Logval r=a+(b+c) as a single result in place in r. todo: return std::move(o1) - C++0x template<class T> LogVal<T> operator+(LogVal<T> o1, const LogVal<T>& o2) { |