From 836470428c6398ddd5ca86023ba9b48517110c58 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 4 Mar 2012 23:15:51 -0500 Subject: support full range of hyperparameter values for PYP (including strength <= 0) --- utils/ccrp.h | 68 ++++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/utils/ccrp.h b/utils/ccrp.h index 68769635..c883c027 100644 --- a/utils/ccrp.h +++ b/utils/ccrp.h @@ -19,29 +19,44 @@ template > class CCRP { public: CCRP(double disc, double alpha) : - num_tables_(), - num_customers_(), - discount_(disc), - alpha_(alpha), - discount_prior_alpha_(std::numeric_limits<double>::quiet_NaN()), - discount_prior_beta_(std::numeric_limits<double>::quiet_NaN()), - alpha_prior_shape_(std::numeric_limits<double>::quiet_NaN()), - alpha_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {} + num_tables_(), + num_customers_(), + discount_(disc), + alpha_(alpha), + discount_prior_alpha_(std::numeric_limits<double>::quiet_NaN()), + discount_prior_beta_(std::numeric_limits<double>::quiet_NaN()), + alpha_prior_shape_(std::numeric_limits<double>::quiet_NaN()), + alpha_prior_rate_(std::numeric_limits<double>::quiet_NaN()) { + check_hyperparameters(); + } CCRP(double d_alpha, double d_beta, double c_shape, double c_rate, double d = 0.9, double c = 1.0) : - num_tables_(), - num_customers_(), - discount_(d), - alpha_(c), - discount_prior_alpha_(d_alpha), - discount_prior_beta_(d_beta), - alpha_prior_shape_(c_shape), - alpha_prior_rate_(c_rate) {} + num_tables_(), + num_customers_(), + discount_(d), + alpha_(c), + discount_prior_alpha_(d_alpha), + discount_prior_beta_(d_beta), + alpha_prior_shape_(c_shape), + alpha_prior_rate_(c_rate) { + check_hyperparameters(); + } + + void check_hyperparameters() { + if (discount_ < 0.0 || discount_ >= 1.0) { + std::cerr << "Bad discount: " << discount_ << std::endl; + abort(); + } + if (alpha_ <= -discount_) { + std::cerr << "Bad strength: " << alpha_ << " (discount=" << discount_ << ")" << std::endl; + abort(); + } + } double 
discount() const { return discount_; } double alpha() const { return alpha_; } - void set_discount(double d) { discount_ = d; } - void set_alpha(double a) { alpha_ = a; } + void set_discount(double d) { discount_ = d; check_hyperparameters(); } + void set_alpha(double a) { alpha_ = a; check_hyperparameters(); } bool has_discount_prior() const { return !std::isnan(discount_prior_alpha_); @@ -215,14 +230,15 @@ class CCRP { if (has_discount_prior()) lp = Md::log_beta_density(discount, discount_prior_alpha_, discount_prior_beta_); if (has_alpha_prior()) - lp += Md::log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_); + lp += Md::log_gamma_density(alpha + discount, alpha_prior_shape_, alpha_prior_rate_); assert(lp <= 0.0); if (num_customers_) { if (discount > 0.0) { const double r = lgamma(1.0 - discount); - lp += lgamma(alpha) - lgamma(alpha + num_customers_) - + num_tables_ * log(discount) + lgamma(alpha / discount + num_tables_) - - lgamma(alpha / discount); + if (alpha) + lp += lgamma(alpha) - lgamma(alpha / discount); + lp += - lgamma(alpha + num_customers_) + + num_tables_ * log(discount) + lgamma(alpha / discount + num_tables_); assert(std::isfinite(lp)); for (typename std::tr1::unordered_map::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { @@ -245,15 +261,17 @@ class CCRP { StrengthResampler sr(*this); for (int iter = 0; iter < nloop; ++iter) { if (has_alpha_prior()) { - alpha_ = slice_sampler1d(sr, alpha_, *rng, 0.0, + alpha_ = slice_sampler1d(sr, alpha_, *rng, -discount_, std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations); } if (has_discount_prior()) { - discount_ = slice_sampler1d(dr, discount_, *rng, std::numeric_limits<double>::min(), + double min_discount = std::numeric_limits<double>::min(); + if (alpha_ < 0.0) min_discount = -alpha_; + discount_ = slice_sampler1d(dr, discount_, *rng, min_discount, 1.0, 0.0, niterations, 100*niterations); } } - alpha_ = 
slice_sampler1d(sr, alpha_, *rng, -discount_, std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations); } -- cgit v1.2.3