diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-03-09 22:23:50 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-03-09 22:23:50 -0500 |
commit | 113317266853abff2e1c0c3e889017d0eee55c93 (patch) | |
tree | 3fb77e29acaf45e1a9a006f8f11fb2b021b5987b /utils | |
parent | 78bf1457f606dd3880c2bc912201c4945d3f85b4 (diff) |
moar
Diffstat (limited to 'utils')
-rw-r--r-- | utils/ccrp_nt.h | 17 |
1 files changed, 6 insertions, 11 deletions
diff --git a/utils/ccrp_nt.h b/utils/ccrp_nt.h index 79321493..6efbfc78 100644 --- a/utils/ccrp_nt.h +++ b/utils/ccrp_nt.h @@ -11,6 +11,7 @@ #include <boost/functional/hash.hpp> #include "sampler.h" #include "slice_sampler.h" +#include "m.h" // Chinese restaurant process (1 parameter) template <typename Dish, typename DishHash = boost::hash<Dish> > @@ -29,6 +30,7 @@ class CCRP_NoTable { alpha_prior_rate_(c_rate) {} double alpha() const { return alpha_; } + void set_alpha(const double& alpha) { alpha_ = alpha; assert(alpha_ > 0.0); } bool has_alpha_prior() const { return !std::isnan(alpha_prior_shape_); @@ -71,9 +73,10 @@ class CCRP_NoTable { return table_diff; } - double prob(const Dish& dish, const double& p0) const { + template <typename F> + F prob(const Dish& dish, const F& p0) const { const unsigned at_table = num_customers(dish); - return (at_table + p0 * alpha_) / (num_customers_ + alpha_); + return (F(at_table) + p0 * F(alpha_)) / F(num_customers_ + alpha_); } double logprob(const Dish& dish, const double& logp0) const { @@ -85,20 +88,12 @@ class CCRP_NoTable { return log_crp_prob(alpha_); } - static double log_gamma_density(const double& x, const double& shape, const double& rate) { - assert(x >= 0.0); - assert(shape > 0.0); - assert(rate > 0.0); - const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape); - return lp; - } - // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process // does not include P_0's double log_crp_prob(const double& alpha) const { double lp = 0.0; if (has_alpha_prior()) - lp += log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_); + lp += Md::log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_); assert(lp <= 0.0); if (num_customers_) { lp += lgamma(alpha) - lgamma(alpha + num_customers_) + |