diff options
Diffstat (limited to 'utils')
| -rw-r--r-- | utils/ccrp_nt.h | 17 | 
1 files changed, 6 insertions, 11 deletions
diff --git a/utils/ccrp_nt.h b/utils/ccrp_nt.h index 79321493..6efbfc78 100644 --- a/utils/ccrp_nt.h +++ b/utils/ccrp_nt.h @@ -11,6 +11,7 @@  #include <boost/functional/hash.hpp>  #include "sampler.h"  #include "slice_sampler.h" +#include "m.h"  // Chinese restaurant process (1 parameter)  template <typename Dish, typename DishHash = boost::hash<Dish> > @@ -29,6 +30,7 @@ class CCRP_NoTable {      alpha_prior_rate_(c_rate) {}    double alpha() const { return alpha_; } +  void set_alpha(const double& alpha) { alpha_ = alpha; assert(alpha_ > 0.0); }    bool has_alpha_prior() const {      return !std::isnan(alpha_prior_shape_); @@ -71,9 +73,10 @@ class CCRP_NoTable {      return table_diff;    } -  double prob(const Dish& dish, const double& p0) const { +  template <typename F> +  F prob(const Dish& dish, const F& p0) const {      const unsigned at_table = num_customers(dish); -    return (at_table + p0 * alpha_) / (num_customers_ + alpha_); +    return (F(at_table) + p0 * F(alpha_)) / F(num_customers_ + alpha_);    }    double logprob(const Dish& dish, const double& logp0) const { @@ -85,20 +88,12 @@ class CCRP_NoTable {      return log_crp_prob(alpha_);    } -  static double log_gamma_density(const double& x, const double& shape, const double& rate) { -    assert(x >= 0.0); -    assert(shape > 0.0); -    assert(rate > 0.0); -    const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape); -    return lp; -  } -    // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process    // does not include P_0's    double log_crp_prob(const double& alpha) const {      double lp = 0.0;      if (has_alpha_prior()) -      lp += log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_); +      lp += Md::log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_);      assert(lp <= 0.0);      if (num_customers_) {        lp += lgamma(alpha) - lgamma(alpha + num_customers_) +  | 
