From 671c21451542e2dd20e45b4033d44d8e8735f87b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 3 Dec 2009 16:33:55 -0500 Subject: initial check in --- src/logval.h | 136 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 src/logval.h (limited to 'src/logval.h') diff --git a/src/logval.h b/src/logval.h new file mode 100644 index 00000000..a8ca620c --- /dev/null +++ b/src/logval.h @@ -0,0 +1,136 @@ +#ifndef LOGVAL_H_ +#define LOGVAL_H_ + +#include +#include + +template +class LogVal { + public: + LogVal() : v_(-std::numeric_limits::infinity()) {} + explicit LogVal(double x) : v_(std::log(x)) {} + LogVal(const LogVal& o) : v_(o.v_) {} + static LogVal One() { return LogVal(1); } + static LogVal Zero() { return LogVal(); } + + void logeq(const T& v) { v_ = v; } + + LogVal& operator+=(const LogVal& a) { + if (a.v_ == -std::numeric_limits::infinity()) return *this; + if (a.v_ < v_) { + v_ = v_ + log1p(std::exp(a.v_ - v_)); + } else { + v_ = a.v_ + log1p(std::exp(v_ - a.v_)); + } + return *this; + } + + LogVal& operator*=(const LogVal& a) { + v_ += a.v_; + return *this; + } + + LogVal& operator*=(const T& a) { + v_ += log(a); + return *this; + } + + LogVal& operator/=(const LogVal& a) { + v_ -= a.v_; + return *this; + } + + LogVal& poweq(const T& power) { + if (power == 0) v_ = 0; else v_ *= power; + return *this; + } + + LogVal pow(const T& power) const { + LogVal res = *this; + res.poweq(power); + return res; + } + + operator T() const { + return std::exp(v_); + } + + T v_; +}; + +template +LogVal operator+(const LogVal& o1, const LogVal& o2) { + LogVal res(o1); + res += o2; + return res; +} + +template +LogVal operator*(const LogVal& o1, const LogVal& o2) { + LogVal res(o1); + res *= o2; + return res; +} + +template +LogVal operator*(const LogVal& o1, const T& o2) { + LogVal res(o1); + res *= o2; + return res; +} + +template +LogVal operator*(const T& o1, const LogVal& o2) { + LogVal res(o2); + res *= o1; + return res; +} + +template +LogVal operator/(const LogVal& o1, const LogVal& o2) { + LogVal res(o1); + res /= o2; + return res; +} + +template +T log(const LogVal& o) { + return o.v_; +} + +template +LogVal pow(const LogVal& b, const T& e) { + return b.pow(e); +} + +template +bool operator<(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ < rhs.v_); +} + +template +bool operator<=(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ <= rhs.v_); +} + +template +bool operator>(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ > rhs.v_); +} + +template +bool operator>=(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ >= rhs.v_); +} + +template +bool operator==(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ == rhs.v_); +} + +template +bool operator!=(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ != rhs.v_); +} + +#endif -- cgit v1.2.3