summaryrefslogtreecommitdiff
path: root/utils/ccrp_nt.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils/ccrp_nt.h')
-rw-r--r--utils/ccrp_nt.h65
1 files changed, 30 insertions, 35 deletions
diff --git a/utils/ccrp_nt.h b/utils/ccrp_nt.h
index 63b6f4c2..6efbfc78 100644
--- a/utils/ccrp_nt.h
+++ b/utils/ccrp_nt.h
@@ -11,6 +11,7 @@
#include <boost/functional/hash.hpp>
#include "sampler.h"
#include "slice_sampler.h"
+#include "m.h"
// Chinese restaurant process (1 parameter)
template <typename Dish, typename DishHash = boost::hash<Dish> >
@@ -18,20 +19,21 @@ class CCRP_NoTable {
public:
explicit CCRP_NoTable(double conc) :
num_customers_(),
- concentration_(conc),
- concentration_prior_shape_(std::numeric_limits<double>::quiet_NaN()),
- concentration_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {}
+ alpha_(conc),
+ alpha_prior_shape_(std::numeric_limits<double>::quiet_NaN()),
+ alpha_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {}
CCRP_NoTable(double c_shape, double c_rate, double c = 10.0) :
num_customers_(),
- concentration_(c),
- concentration_prior_shape_(c_shape),
- concentration_prior_rate_(c_rate) {}
+ alpha_(c),
+ alpha_prior_shape_(c_shape),
+ alpha_prior_rate_(c_rate) {}
- double concentration() const { return concentration_; }
+ double alpha() const { return alpha_; }
+ void set_alpha(const double& alpha) { alpha_ = alpha; assert(alpha_ > 0.0); }
- bool has_concentration_prior() const {
- return !std::isnan(concentration_prior_shape_);
+ bool has_alpha_prior() const {
+ return !std::isnan(alpha_prior_shape_);
}
void clear() {
@@ -71,38 +73,31 @@ class CCRP_NoTable {
return table_diff;
}
- double prob(const Dish& dish, const double& p0) const {
+ template <typename F>
+ F prob(const Dish& dish, const F& p0) const {
const unsigned at_table = num_customers(dish);
- return (at_table + p0 * concentration_) / (num_customers_ + concentration_);
+ return (F(at_table) + p0 * F(alpha_)) / F(num_customers_ + alpha_);
}
double logprob(const Dish& dish, const double& logp0) const {
const unsigned at_table = num_customers(dish);
- return log(at_table + exp(logp0 + log(concentration_))) - log(num_customers_ + concentration_);
+ return log(at_table + exp(logp0 + log(alpha_))) - log(num_customers_ + alpha_);
}
double log_crp_prob() const {
- return log_crp_prob(concentration_);
- }
-
- static double log_gamma_density(const double& x, const double& shape, const double& rate) {
- assert(x >= 0.0);
- assert(shape > 0.0);
- assert(rate > 0.0);
- const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape);
- return lp;
+ return log_crp_prob(alpha_);
}
// taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process
// does not include P_0's
- double log_crp_prob(const double& concentration) const {
+ double log_crp_prob(const double& alpha) const {
double lp = 0.0;
- if (has_concentration_prior())
- lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_);
+ if (has_alpha_prior())
+ lp += Md::log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_);
assert(lp <= 0.0);
if (num_customers_) {
- lp += lgamma(concentration) - lgamma(concentration + num_customers_) +
- custs_.size() * log(concentration);
+ lp += lgamma(alpha) - lgamma(alpha + num_customers_) +
+ custs_.size() * log(alpha);
assert(std::isfinite(lp));
for (typename std::tr1::unordered_map<Dish, unsigned, DishHash>::const_iterator it = custs_.begin();
it != custs_.end(); ++it) {
@@ -114,10 +109,10 @@ class CCRP_NoTable {
}
void resample_hyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) {
- assert(has_concentration_prior());
+ assert(has_alpha_prior());
ConcentrationResampler cr(*this);
for (int iter = 0; iter < nloop; ++iter) {
- concentration_ = slice_sampler1d(cr, concentration_, *rng, 0.0,
+ alpha_ = slice_sampler1d(cr, alpha_, *rng, 0.0,
std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
}
}
@@ -125,13 +120,13 @@ class CCRP_NoTable {
struct ConcentrationResampler {
ConcentrationResampler(const CCRP_NoTable& crp) : crp_(crp) {}
const CCRP_NoTable& crp_;
- double operator()(const double& proposed_concentration) const {
- return crp_.log_crp_prob(proposed_concentration);
+ double operator()(const double& proposed_alpha) const {
+ return crp_.log_crp_prob(proposed_alpha);
}
};
void Print(std::ostream* out) const {
- (*out) << "DP(alpha=" << concentration_ << ") customers=" << num_customers_ << std::endl;
+ (*out) << "DP(alpha=" << alpha_ << ") customers=" << num_customers_ << std::endl;
int cc = 0;
for (typename std::tr1::unordered_map<Dish, unsigned, DishHash>::const_iterator it = custs_.begin();
it != custs_.end(); ++it) {
@@ -153,11 +148,11 @@ class CCRP_NoTable {
return custs_.end();
}
- double concentration_;
+ double alpha_;
- // optional gamma prior on concentration_ (NaN if no prior)
- double concentration_prior_shape_;
- double concentration_prior_rate_;
+ // optional gamma prior on alpha_ (NaN if no prior)
+ double alpha_prior_shape_;
+ double alpha_prior_rate_;
};
template <typename T,typename H>