summaryrefslogtreecommitdiff
path: root/utils/mfcr.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils/mfcr.h')
-rw-r--r--utils/mfcr.h19
1 files changed, 16 insertions, 3 deletions
diff --git a/utils/mfcr.h b/utils/mfcr.h
index 6cc0ebf1..886f01ef 100644
--- a/utils/mfcr.h
+++ b/utils/mfcr.h
@@ -48,7 +48,7 @@ class MFCR {
discount_prior_strength_(std::numeric_limits<double>::quiet_NaN()),
discount_prior_beta_(std::numeric_limits<double>::quiet_NaN()),
strength_prior_shape_(std::numeric_limits<double>::quiet_NaN()),
- strength_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {}
+ strength_prior_rate_(std::numeric_limits<double>::quiet_NaN()) { check_hyperparameters(); }
MFCR(double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) :
num_tables_(),
@@ -58,10 +58,23 @@ class MFCR {
discount_prior_strength_(discount_strength),
discount_prior_beta_(discount_beta),
strength_prior_shape_(strength_shape),
- strength_prior_rate_(strength_rate) {}
+ strength_prior_rate_(strength_rate) { check_hyperparameters(); }
+
+ void check_hyperparameters() {
+ if (discount_ < 0.0 || discount_ >= 1.0) {
+ std::cerr << "Bad discount: " << discount_ << std::endl;
+ abort();
+ }
+ if (strength_ <= -discount_) {
+ std::cerr << "Bad strength: " << strength_ << " (discount=" << discount_ << ")" << std::endl;
+ abort();
+ }
+ }
double discount() const { return discount_; }
double strength() const { return strength_; }
+ void set_discount(double d) { discount_ = d; check_hyperparameters(); }
+ void set_strength(double a) { strength_ = a; check_hyperparameters(); }
bool has_discount_prior() const {
return !std::isnan(discount_prior_strength_);
@@ -275,7 +288,7 @@ class MFCR {
}
if (has_discount_prior()) {
double min_discount = std::numeric_limits<double>::min();
- if (strength_ < 0.0) min_discount = -strength_;
+ if (strength_ < 0.0) min_discount -= strength_;
discount_ = slice_sampler1d(dr, discount_, *rng, min_discount,
1.0, 0.0, niterations, 100*niterations);
}