summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-03-09 22:23:50 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-03-09 22:23:50 -0500
commit113317266853abff2e1c0c3e889017d0eee55c93 (patch)
tree3fb77e29acaf45e1a9a006f8f11fb2b021b5987b /utils
parent78bf1457f606dd3880c2bc912201c4945d3f85b4 (diff)
moar
Diffstat (limited to 'utils')
-rw-r--r--utils/ccrp_nt.h17
1 files changed, 6 insertions, 11 deletions
diff --git a/utils/ccrp_nt.h b/utils/ccrp_nt.h
index 79321493..6efbfc78 100644
--- a/utils/ccrp_nt.h
+++ b/utils/ccrp_nt.h
@@ -11,6 +11,7 @@
#include <boost/functional/hash.hpp>
#include "sampler.h"
#include "slice_sampler.h"
+#include "m.h"
// Chinese restaurant process (1 parameter)
template <typename Dish, typename DishHash = boost::hash<Dish> >
@@ -29,6 +30,7 @@ class CCRP_NoTable {
alpha_prior_rate_(c_rate) {}
double alpha() const { return alpha_; }
+ void set_alpha(const double& alpha) { alpha_ = alpha; assert(alpha_ > 0.0); }
bool has_alpha_prior() const {
return !std::isnan(alpha_prior_shape_);
@@ -71,9 +73,10 @@ class CCRP_NoTable {
return table_diff;
}
- double prob(const Dish& dish, const double& p0) const {
+ template <typename F>
+ F prob(const Dish& dish, const F& p0) const {
const unsigned at_table = num_customers(dish);
- return (at_table + p0 * alpha_) / (num_customers_ + alpha_);
+ return (F(at_table) + p0 * F(alpha_)) / F(num_customers_ + alpha_);
}
double logprob(const Dish& dish, const double& logp0) const {
@@ -85,20 +88,12 @@ class CCRP_NoTable {
return log_crp_prob(alpha_);
}
- static double log_gamma_density(const double& x, const double& shape, const double& rate) {
- assert(x >= 0.0);
- assert(shape > 0.0);
- assert(rate > 0.0);
- const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape);
- return lp;
- }
-
// taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process
// does not include P_0's
double log_crp_prob(const double& alpha) const {
double lp = 0.0;
if (has_alpha_prior())
- lp += log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_);
+ lp += Md::log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_);
assert(lp <= 0.0);
if (num_customers_) {
lp += lgamma(alpha) - lgamma(alpha + num_customers_) +