summaryrefslogtreecommitdiff
path: root/utils/ccrp.h
diff options
context:
space:
mode:
authorPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-04-07 16:58:55 +0200
committerPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-04-07 16:58:55 +0200
commite91553ae70907e243a554e4a549c53df57b78478 (patch)
treea4d044093f5937d0152b573c99914746b5a2b8ef /utils/ccrp.h
parentfb714888562845a8ae10fd4411cf199961193833 (diff)
parent2fe4323cbfc34de906a2869f98c017b41e4ccae7 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'utils/ccrp.h')
-rw-r--r--utils/ccrp.h10
1 files changed, 9 insertions, 1 deletions
diff --git a/utils/ccrp.h b/utils/ccrp.h
index 4a8b80e7..8635b422 100644
--- a/utils/ccrp.h
+++ b/utils/ccrp.h
@@ -55,6 +55,10 @@ class CCRP {
double discount() const { return discount_; }
double strength() const { return strength_; }
+ void set_hyperparameters(double d, double s) {
+ discount_ = d; strength_ = s;
+ check_hyperparameters();
+ }
void set_discount(double d) { discount_ = d; check_hyperparameters(); }
void set_strength(double a) { strength_ = a; check_hyperparameters(); }
@@ -93,8 +97,10 @@ class CCRP {
}
// returns +1 or 0 indicating whether a new table was opened
+ // p = probability with which the particular table was selected
+ // excluding p0
template <typename T>
- int increment(const Dish& dish, const T& p0, MT19937* rng) {
+ int increment(const Dish& dish, const T& p0, MT19937* rng, T* p = NULL) {
DishLocations& loc = dish_locs_[dish];
bool share_table = false;
if (loc.total_dish_count_) {
@@ -108,6 +114,7 @@ class CCRP {
ti != loc.table_counts_.end(); ++ti) {
r -= (*ti - discount_);
if (r <= 0.0) {
+ if (p) { *p = T(*ti - discount_) / T(strength_ + num_customers_); }
++(*ti);
break;
}
@@ -119,6 +126,7 @@ class CCRP {
}
} else {
loc.table_counts_.push_back(1u);
+ if (p) { *p = T(strength_ + discount_ * num_tables_) / T(strength_ + num_customers_); }
++num_tables_;
}
++loc.total_dish_count_;