diff options
-rw-r--r-- | utils/ccrp.h | 103 | ||||
-rw-r--r-- | utils/crp_table_manager.h | 6 |
2 files changed, 33 insertions, 76 deletions
diff --git a/utils/ccrp.h b/utils/ccrp.h index 1d41a3ef..f5d3fc78 100644 --- a/utils/ccrp.h +++ b/utils/ccrp.h @@ -11,6 +11,7 @@ #include <boost/functional/hash.hpp> #include "sampler.h" #include "slice_sampler.h" +#include "crp_table_manager.h" #include "m.h" // Chinese restaurant process (Pitman-Yor parameters) with table tracking. @@ -81,9 +82,9 @@ class CCRP { } unsigned num_tables(const Dish& dish) const { - const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); + const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish); if (it == dish_locs_.end()) return 0; - return it->second.table_counts_.size(); + return it->second.num_tables(); } unsigned num_customers() const { @@ -91,9 +92,9 @@ class CCRP { } unsigned num_customers(const Dish& dish) const { - const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); + const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish); if (it == dish_locs_.end()) return 0; - return it->total_dish_count_; + return it->num_customers(); } // returns +1 or 0 indicating whether a new table was opened @@ -101,84 +102,48 @@ class CCRP { // excluding p0 template <typename T> int increment(const Dish& dish, const T& p0, MT19937* rng, T* p = NULL) { - DishLocations& loc = dish_locs_[dish]; + CRPTableManager& loc = dish_locs_[dish]; bool share_table = false; - if (loc.total_dish_count_) { + if (loc.num_customers()) { const T p_empty = T(strength_ + num_tables_ * discount_) * p0; - const T p_share = T(loc.total_dish_count_ - loc.table_counts_.size() * discount_); + const T p_share = T(loc.num_customers() - loc.num_tables() * discount_); share_table = rng->SelectSample(p_empty, p_share); } if (share_table) { - double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); - for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); - ti != loc.table_counts_.end(); ++ti) { - r -= (*ti - discount_); - if (r <= 0.0) { - if (p) { *p = T(*ti - discount_) / T(strength_ + num_customers_); } - ++(*ti); - break; - } - } - if (r > 0.0) { - std::cerr << "Serious error: r=" << r << std::endl; - Print(&std::cerr); - assert(r <= 0.0); - } + loc.share_table(discount_, rng); } else { - loc.table_counts_.push_back(1u); - if (p) { *p = T(strength_ + discount_ * num_tables_) / T(strength_ + num_customers_); } + loc.create_table(); ++num_tables_; } - ++loc.total_dish_count_; ++num_customers_; return (share_table ? 0 : 1); } // returns -1 or 0, indicating whether a table was closed int decrement(const Dish& dish, MT19937* rng) { - DishLocations& loc = dish_locs_[dish]; - assert(loc.total_dish_count_); - if (loc.total_dish_count_ == 1) { + CRPTableManager& loc = dish_locs_[dish]; + assert(loc.num_customers()); + if (loc.num_customers() == 1) { dish_locs_.erase(dish); --num_tables_; --num_customers_; return -1; } else { - int delta = 0; - // sample customer to remove UNIFORMLY. that is, do NOT use the discount - // here. if you do, it will introduce (unwanted) bias! - double r = rng->next() * loc.total_dish_count_; - --loc.total_dish_count_; - for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); - ti != loc.table_counts_.end(); ++ti) { - r -= *ti; - if (r <= 0.0) { - if ((--(*ti)) == 0) { - --num_tables_; - delta = -1; - loc.table_counts_.erase(ti); - } - break; - } - } - if (r > 0.0) { - std::cerr << "Serious error: r=" << r << std::endl; - Print(&std::cerr); - assert(r <= 0.0); - } + int delta = loc.remove_customer(rng); --num_customers_; + if (delta) --num_tables_; return delta; } } template <typename T> T prob(const Dish& dish, const T& p0) const { - const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); + const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish); const T r = T(num_tables_ * discount_ + strength_); if (it == dish_locs_.end()) { return r * p0 / T(num_customers_ + strength_); } else { - return (T(it->second.total_dish_count_ - discount_ * it->second.table_counts_.size()) + r * p0) / + return (T(it->second.num_customers() - discount_ * it->second.num_tables()) + r * p0) / T(num_customers_ + strength_); } } @@ -204,20 +169,20 @@ class CCRP { lp += - lgamma(strength + num_customers_) + num_tables_ * log(discount) + lgamma(strength / discount + num_tables_); assert(std::isfinite(lp)); - for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); + for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { - const DishLocations& cur = it->second; - for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) { - lp += lgamma(*ti - discount) - r; + const CRPTableManager& cur = it->second; // TODO check + for (CRPTableManager::const_iterator ti = cur.begin(); ti != cur.end(); ++ti) { + lp += (lgamma(ti->first - discount) - r) * ti->second; } } } else if (!discount) { // discount == 0.0 lp += lgamma(strength) + num_tables_ * log(strength) - lgamma(strength + num_tables_); assert(std::isfinite(lp)); - for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); + for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { - const DishLocations& cur = it->second; - lp += lgamma(cur.table_counts_.size()); + const CRPTableManager& cur = it->second; + lp += lgamma(cur.num_tables()); } } else { assert(!"discount less than 0 detected!"); @@ -264,27 +229,15 @@ class CCRP { } }; - struct DishLocations { - DishLocations() : total_dish_count_() {} - unsigned total_dish_count_; // customers at all tables with this dish - std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want - // .size() is the number of tables for this dish - }; - void Print(std::ostream* out) const { std::cerr << "PYP(d=" << discount_ << ",c=" << strength_ << ") customers=" << num_customers_ << std::endl; - for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); + for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { - (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; - for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin(); - i != it->second.table_counts_.end(); ++i) { - (*out) << " " << *i; - } - (*out) << std::endl; + (*out) << it->first << " : " << it->second << std::endl; } } - typedef typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator const_iterator; + typedef typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator const_iterator; const_iterator begin() const { return dish_locs_.begin(); } @@ -294,7 +247,7 @@ class CCRP { unsigned num_tables_; unsigned num_customers_; - std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_; + std::tr1::unordered_map<Dish, CRPTableManager, DishHash> dish_locs_; double discount_; double strength_; diff --git a/utils/crp_table_manager.h b/utils/crp_table_manager.h index 32840ad8..753e721f 100644 --- a/utils/crp_table_manager.h +++ b/utils/crp_table_manager.h @@ -93,6 +93,10 @@ struct CRPTableManager { } } + typedef CRPHistogram::const_iterator const_iterator; + const_iterator begin() const { return h.begin(); } + const_iterator end() const { return h.end(); } + unsigned customers; unsigned tables; CRPHistogram h; @@ -100,7 +104,7 @@ struct CRPTableManager { std::ostream& operator<<(std::ostream& os, const CRPTableManager& tm) { os << '[' << tm.num_customers() << " total customers at " << tm.num_tables() << " tables ||| "; - for (CRPHistogram::const_iterator it = tm.h.begin(); it != tm.h.end(); ++it) { + for (CRPHistogram::const_iterator it = tm.begin(); it != tm.end(); ++it) { if (it != tm.h.begin()) os << " -- "; os << '(' << it->first << ") x " << it->second; } |