summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--utils/ccrp.h103
-rw-r--r--utils/crp_table_manager.h6
2 files changed, 33 insertions, 76 deletions
diff --git a/utils/ccrp.h b/utils/ccrp.h
index 1d41a3ef..f5d3fc78 100644
--- a/utils/ccrp.h
+++ b/utils/ccrp.h
@@ -11,6 +11,7 @@
#include <boost/functional/hash.hpp>
#include "sampler.h"
#include "slice_sampler.h"
+#include "crp_table_manager.h"
#include "m.h"
// Chinese restaurant process (Pitman-Yor parameters) with table tracking.
@@ -81,9 +82,9 @@ class CCRP {
}
unsigned num_tables(const Dish& dish) const {
- const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+ const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish);
if (it == dish_locs_.end()) return 0;
- return it->second.table_counts_.size();
+ return it->second.num_tables();
}
unsigned num_customers() const {
@@ -91,9 +92,9 @@ class CCRP {
}
unsigned num_customers(const Dish& dish) const {
- const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+ const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish);
if (it == dish_locs_.end()) return 0;
- return it->total_dish_count_;
+ return it->num_customers();
}
// returns +1 or 0 indicating whether a new table was opened
@@ -101,84 +102,48 @@ class CCRP {
// excluding p0
template <typename T>
int increment(const Dish& dish, const T& p0, MT19937* rng, T* p = NULL) {
- DishLocations& loc = dish_locs_[dish];
+ CRPTableManager& loc = dish_locs_[dish];
bool share_table = false;
- if (loc.total_dish_count_) {
+ if (loc.num_customers()) {
const T p_empty = T(strength_ + num_tables_ * discount_) * p0;
- const T p_share = T(loc.total_dish_count_ - loc.table_counts_.size() * discount_);
+ const T p_share = T(loc.num_customers() - loc.num_tables() * discount_);
share_table = rng->SelectSample(p_empty, p_share);
}
if (share_table) {
- double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
- for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
- ti != loc.table_counts_.end(); ++ti) {
- r -= (*ti - discount_);
- if (r <= 0.0) {
- if (p) { *p = T(*ti - discount_) / T(strength_ + num_customers_); }
- ++(*ti);
- break;
- }
- }
- if (r > 0.0) {
- std::cerr << "Serious error: r=" << r << std::endl;
- Print(&std::cerr);
- assert(r <= 0.0);
- }
+ loc.share_table(discount_, rng);
} else {
- loc.table_counts_.push_back(1u);
- if (p) { *p = T(strength_ + discount_ * num_tables_) / T(strength_ + num_customers_); }
+ loc.create_table();
++num_tables_;
}
- ++loc.total_dish_count_;
++num_customers_;
return (share_table ? 0 : 1);
}
// returns -1 or 0, indicating whether a table was closed
int decrement(const Dish& dish, MT19937* rng) {
- DishLocations& loc = dish_locs_[dish];
- assert(loc.total_dish_count_);
- if (loc.total_dish_count_ == 1) {
+ CRPTableManager& loc = dish_locs_[dish];
+ assert(loc.num_customers());
+ if (loc.num_customers() == 1) {
dish_locs_.erase(dish);
--num_tables_;
--num_customers_;
return -1;
} else {
- int delta = 0;
- // sample customer to remove UNIFORMLY. that is, do NOT use the discount
- // here. if you do, it will introduce (unwanted) bias!
- double r = rng->next() * loc.total_dish_count_;
- --loc.total_dish_count_;
- for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
- ti != loc.table_counts_.end(); ++ti) {
- r -= *ti;
- if (r <= 0.0) {
- if ((--(*ti)) == 0) {
- --num_tables_;
- delta = -1;
- loc.table_counts_.erase(ti);
- }
- break;
- }
- }
- if (r > 0.0) {
- std::cerr << "Serious error: r=" << r << std::endl;
- Print(&std::cerr);
- assert(r <= 0.0);
- }
+ int delta = loc.remove_customer(rng);
--num_customers_;
+ if (delta) --num_tables_;
return delta;
}
}
template <typename T>
T prob(const Dish& dish, const T& p0) const {
- const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+ const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish);
const T r = T(num_tables_ * discount_ + strength_);
if (it == dish_locs_.end()) {
return r * p0 / T(num_customers_ + strength_);
} else {
- return (T(it->second.total_dish_count_ - discount_ * it->second.table_counts_.size()) + r * p0) /
+ return (T(it->second.num_customers() - discount_ * it->second.num_tables()) + r * p0) /
T(num_customers_ + strength_);
}
}
@@ -204,20 +169,20 @@ class CCRP {
lp += - lgamma(strength + num_customers_)
+ num_tables_ * log(discount) + lgamma(strength / discount + num_tables_);
assert(std::isfinite(lp));
- for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+ for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin();
it != dish_locs_.end(); ++it) {
- const DishLocations& cur = it->second;
- for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) {
- lp += lgamma(*ti - discount) - r;
+ const CRPTableManager& cur = it->second; // TODO check
+ for (CRPTableManager::const_iterator ti = cur.begin(); ti != cur.end(); ++ti) {
+ lp += (lgamma(ti->first - discount) - r) * ti->second;
}
}
} else if (!discount) { // discount == 0.0
lp += lgamma(strength) + num_tables_ * log(strength) - lgamma(strength + num_tables_);
assert(std::isfinite(lp));
- for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+ for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin();
it != dish_locs_.end(); ++it) {
- const DishLocations& cur = it->second;
- lp += lgamma(cur.table_counts_.size());
+ const CRPTableManager& cur = it->second;
+ lp += lgamma(cur.num_tables());
}
} else {
assert(!"discount less than 0 detected!");
@@ -264,27 +229,15 @@ class CCRP {
}
};
- struct DishLocations {
- DishLocations() : total_dish_count_() {}
- unsigned total_dish_count_; // customers at all tables with this dish
- std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want
- // .size() is the number of tables for this dish
- };
-
void Print(std::ostream* out) const {
std::cerr << "PYP(d=" << discount_ << ",c=" << strength_ << ") customers=" << num_customers_ << std::endl;
- for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+ for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin();
it != dish_locs_.end(); ++it) {
- (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): ";
- for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin();
- i != it->second.table_counts_.end(); ++i) {
- (*out) << " " << *i;
- }
- (*out) << std::endl;
+ (*out) << it->first << " : " << it->second << std::endl;
}
}
- typedef typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator const_iterator;
+ typedef typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator const_iterator;
const_iterator begin() const {
return dish_locs_.begin();
}
@@ -294,7 +247,7 @@ class CCRP {
unsigned num_tables_;
unsigned num_customers_;
- std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_;
+ std::tr1::unordered_map<Dish, CRPTableManager, DishHash> dish_locs_;
double discount_;
double strength_;
diff --git a/utils/crp_table_manager.h b/utils/crp_table_manager.h
index 32840ad8..753e721f 100644
--- a/utils/crp_table_manager.h
+++ b/utils/crp_table_manager.h
@@ -93,6 +93,10 @@ struct CRPTableManager {
}
}
+ typedef CRPHistogram::const_iterator const_iterator;
+ const_iterator begin() const { return h.begin(); }
+ const_iterator end() const { return h.end(); }
+
unsigned customers;
unsigned tables;
CRPHistogram h;
@@ -100,7 +104,7 @@ struct CRPTableManager {
std::ostream& operator<<(std::ostream& os, const CRPTableManager& tm) {
os << '[' << tm.num_customers() << " total customers at " << tm.num_tables() << " tables ||| ";
- for (CRPHistogram::const_iterator it = tm.h.begin(); it != tm.h.end(); ++it) {
+ for (CRPHistogram::const_iterator it = tm.begin(); it != tm.end(); ++it) {
if (it != tm.h.begin()) os << " -- ";
os << '(' << it->first << ") x " << it->second;
}