summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
Diffstat (limited to 'utils')
-rw-r--r--utils/ccrp.h68
1 files changed, 43 insertions, 25 deletions
diff --git a/utils/ccrp.h b/utils/ccrp.h
index 68769635..c883c027 100644
--- a/utils/ccrp.h
+++ b/utils/ccrp.h
@@ -19,29 +19,44 @@ template <typename Dish, typename DishHash = boost::hash<Dish> >
class CCRP {
public:
CCRP(double disc, double alpha) :
- num_tables_(),
- num_customers_(),
- discount_(disc),
- alpha_(alpha),
- discount_prior_alpha_(std::numeric_limits<double>::quiet_NaN()),
- discount_prior_beta_(std::numeric_limits<double>::quiet_NaN()),
- alpha_prior_shape_(std::numeric_limits<double>::quiet_NaN()),
- alpha_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {}
+ num_tables_(),
+ num_customers_(),
+ discount_(disc),
+ alpha_(alpha),
+ discount_prior_alpha_(std::numeric_limits<double>::quiet_NaN()),
+ discount_prior_beta_(std::numeric_limits<double>::quiet_NaN()),
+ alpha_prior_shape_(std::numeric_limits<double>::quiet_NaN()),
+ alpha_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {
+ check_hyperparameters();
+ }
CCRP(double d_alpha, double d_beta, double c_shape, double c_rate, double d = 0.9, double c = 1.0) :
- num_tables_(),
- num_customers_(),
- discount_(d),
- alpha_(c),
- discount_prior_alpha_(d_alpha),
- discount_prior_beta_(d_beta),
- alpha_prior_shape_(c_shape),
- alpha_prior_rate_(c_rate) {}
+ num_tables_(),
+ num_customers_(),
+ discount_(d),
+ alpha_(c),
+ discount_prior_alpha_(d_alpha),
+ discount_prior_beta_(d_beta),
+ alpha_prior_shape_(c_shape),
+ alpha_prior_rate_(c_rate) {
+ check_hyperparameters();
+ }
+
+ void check_hyperparameters() {
+ if (discount_ < 0.0 || discount_ >= 1.0) {
+ std::cerr << "Bad discount: " << discount_ << std::endl;
+ abort();
+ }
+ if (alpha_ <= -discount_) {
+ std::cerr << "Bad strength: " << alpha_ << " (discount=" << discount_ << ")" << std::endl;
+ abort();
+ }
+ }
double discount() const { return discount_; }
double alpha() const { return alpha_; }
- void set_discount(double d) { discount_ = d; }
- void set_alpha(double a) { alpha_ = a; }
+ void set_discount(double d) { discount_ = d; check_hyperparameters(); }
+ void set_alpha(double a) { alpha_ = a; check_hyperparameters(); }
bool has_discount_prior() const {
return !std::isnan(discount_prior_alpha_);
@@ -215,14 +230,15 @@ class CCRP {
if (has_discount_prior())
lp = Md::log_beta_density(discount, discount_prior_alpha_, discount_prior_beta_);
if (has_alpha_prior())
- lp += Md::log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_);
+ lp += Md::log_gamma_density(alpha + discount, alpha_prior_shape_, alpha_prior_rate_);
assert(lp <= 0.0);
if (num_customers_) {
if (discount > 0.0) {
const double r = lgamma(1.0 - discount);
- lp += lgamma(alpha) - lgamma(alpha + num_customers_)
- + num_tables_ * log(discount) + lgamma(alpha / discount + num_tables_)
- - lgamma(alpha / discount);
+ if (alpha)
+ lp += lgamma(alpha) - lgamma(alpha / discount);
+ lp += - lgamma(alpha + num_customers_)
+ + num_tables_ * log(discount) + lgamma(alpha / discount + num_tables_);
assert(std::isfinite(lp));
for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
it != dish_locs_.end(); ++it) {
@@ -245,15 +261,17 @@ class CCRP {
StrengthResampler sr(*this);
for (int iter = 0; iter < nloop; ++iter) {
if (has_alpha_prior()) {
- alpha_ = slice_sampler1d(sr, alpha_, *rng, 0.0,
+ alpha_ = slice_sampler1d(sr, alpha_, *rng, -discount_,
std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
}
if (has_discount_prior()) {
- discount_ = slice_sampler1d(dr, discount_, *rng, std::numeric_limits<double>::min(),
+ double min_discount = std::numeric_limits<double>::min();
+ if (alpha_ < 0.0) min_discount = -alpha_;
+ discount_ = slice_sampler1d(dr, discount_, *rng, min_discount,
1.0, 0.0, niterations, 100*niterations);
}
}
- alpha_ = slice_sampler1d(sr, alpha_, *rng, 0.0,
+ alpha_ = slice_sampler1d(sr, alpha_, *rng, -discount_,
std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
}