diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/ccrp.h | 4 | ||||
-rw-r--r-- | utils/mfcr.h | 19 |
2 files changed, 18 insertions, 5 deletions
diff --git a/utils/ccrp.h b/utils/ccrp.h index e24130ac..439d7e1e 100644 --- a/utils/ccrp.h +++ b/utils/ccrp.h @@ -225,12 +225,12 @@ class CCRP { StrengthResampler sr(*this); for (int iter = 0; iter < nloop; ++iter) { if (has_strength_prior()) { - strength_ = slice_sampler1d(sr, strength_, *rng, -discount_, + strength_ = slice_sampler1d(sr, strength_, *rng, -discount_ + std::numeric_limits<double>::min(), std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations); } if (has_discount_prior()) { double min_discount = std::numeric_limits<double>::min(); - if (strength_ < 0.0) min_discount = -strength_; + if (strength_ < 0.0) min_discount -= strength_; discount_ = slice_sampler1d(dr, discount_, *rng, min_discount, 1.0, 0.0, niterations, 100*niterations); } diff --git a/utils/mfcr.h b/utils/mfcr.h index 6cc0ebf1..886f01ef 100644 --- a/utils/mfcr.h +++ b/utils/mfcr.h @@ -48,7 +48,7 @@ class MFCR { discount_prior_strength_(std::numeric_limits<double>::quiet_NaN()), discount_prior_beta_(std::numeric_limits<double>::quiet_NaN()), strength_prior_shape_(std::numeric_limits<double>::quiet_NaN()), - strength_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {} + strength_prior_rate_(std::numeric_limits<double>::quiet_NaN()) { check_hyperparameters(); } MFCR(double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) : num_tables_(), @@ -58,10 +58,23 @@ class MFCR { discount_prior_strength_(discount_strength), discount_prior_beta_(discount_beta), strength_prior_shape_(strength_shape), - strength_prior_rate_(strength_rate) {} + strength_prior_rate_(strength_rate) { check_hyperparameters(); } + + void check_hyperparameters() { + if (discount_ < 0.0 || discount_ >= 1.0) { + std::cerr << "Bad discount: " << discount_ << std::endl; + abort(); + } + if (strength_ <= -discount_) { + std::cerr << "Bad strength: " << strength_ << " (discount=" << discount_ << ")" << std::endl; + abort(); + } + } double discount() const { return discount_; } double strength() const { return strength_; } + void set_discount(double d) { discount_ = d; check_hyperparameters(); } + void set_strength(double a) { strength_ = a; check_hyperparameters(); } bool has_discount_prior() const { return !std::isnan(discount_prior_strength_); @@ -275,7 +288,7 @@ class MFCR { } if (has_discount_prior()) { double min_discount = std::numeric_limits<double>::min(); - if (strength_ < 0.0) min_discount = -strength_; + if (strength_ < 0.0) min_discount -= strength_; discount_ = slice_sampler1d(dr, discount_, *rng, min_discount, 1.0, 0.0, niterations, 100*niterations); } |