diff options
Diffstat (limited to 'utils/ccrp.h')
-rw-r--r-- | utils/ccrp.h | 103 |
1 files changed, 28 insertions, 75 deletions
diff --git a/utils/ccrp.h b/utils/ccrp.h index 1d41a3ef..f5d3fc78 100644 --- a/utils/ccrp.h +++ b/utils/ccrp.h @@ -11,6 +11,7 @@ #include <boost/functional/hash.hpp> #include "sampler.h" #include "slice_sampler.h" +#include "crp_table_manager.h" #include "m.h" // Chinese restaurant process (Pitman-Yor parameters) with table tracking. @@ -81,9 +82,9 @@ class CCRP { } unsigned num_tables(const Dish& dish) const { - const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); + const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish); if (it == dish_locs_.end()) return 0; - return it->second.table_counts_.size(); + return it->second.num_tables(); } unsigned num_customers() const { @@ -91,9 +92,9 @@ class CCRP { } unsigned num_customers(const Dish& dish) const { - const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); + const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish); if (it == dish_locs_.end()) return 0; - return it->total_dish_count_; + return it->num_customers(); } // returns +1 or 0 indicating whether a new table was opened @@ -101,84 +102,48 @@ class CCRP { // excluding p0 template <typename T> int increment(const Dish& dish, const T& p0, MT19937* rng, T* p = NULL) { - DishLocations& loc = dish_locs_[dish]; + CRPTableManager& loc = dish_locs_[dish]; bool share_table = false; - if (loc.total_dish_count_) { + if (loc.num_customers()) { const T p_empty = T(strength_ + num_tables_ * discount_) * p0; - const T p_share = T(loc.total_dish_count_ - loc.table_counts_.size() * discount_); + const T p_share = T(loc.num_customers() - loc.num_tables() * discount_); share_table = rng->SelectSample(p_empty, p_share); } if (share_table) { - double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); - for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); - ti != loc.table_counts_.end(); ++ti) { - r -= (*ti - discount_); - if (r <= 0.0) { - if (p) { *p = T(*ti - discount_) / T(strength_ + num_customers_); } - ++(*ti); - break; - } - } - if (r > 0.0) { - std::cerr << "Serious error: r=" << r << std::endl; - Print(&std::cerr); - assert(r <= 0.0); - } + loc.share_table(discount_, rng); } else { - loc.table_counts_.push_back(1u); - if (p) { *p = T(strength_ + discount_ * num_tables_) / T(strength_ + num_customers_); } + loc.create_table(); ++num_tables_; } - ++loc.total_dish_count_; ++num_customers_; return (share_table ? 0 : 1); } // returns -1 or 0, indicating whether a table was closed int decrement(const Dish& dish, MT19937* rng) { - DishLocations& loc = dish_locs_[dish]; - assert(loc.total_dish_count_); - if (loc.total_dish_count_ == 1) { + CRPTableManager& loc = dish_locs_[dish]; + assert(loc.num_customers()); + if (loc.num_customers() == 1) { dish_locs_.erase(dish); --num_tables_; --num_customers_; return -1; } else { - int delta = 0; - // sample customer to remove UNIFORMLY. that is, do NOT use the discount - // here. if you do, it will introduce (unwanted) bias! - double r = rng->next() * loc.total_dish_count_; - --loc.total_dish_count_; - for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); - ti != loc.table_counts_.end(); ++ti) { - r -= *ti; - if (r <= 0.0) { - if ((--(*ti)) == 0) { - --num_tables_; - delta = -1; - loc.table_counts_.erase(ti); - } - break; - } - } - if (r > 0.0) { - std::cerr << "Serious error: r=" << r << std::endl; - Print(&std::cerr); - assert(r <= 0.0); - } + int delta = loc.remove_customer(rng); --num_customers_; + if (delta) --num_tables_; return delta; } } template <typename T> T prob(const Dish& dish, const T& p0) const { - const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); + const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish); const T r = T(num_tables_ * discount_ + strength_); if (it == dish_locs_.end()) { return r * p0 / T(num_customers_ + strength_); } else { - return (T(it->second.total_dish_count_ - discount_ * it->second.table_counts_.size()) + r * p0) / + return (T(it->second.num_customers() - discount_ * it->second.num_tables()) + r * p0) / T(num_customers_ + strength_); } } @@ -204,20 +169,20 @@ class CCRP { lp += - lgamma(strength + num_customers_) + num_tables_ * log(discount) + lgamma(strength / discount + num_tables_); assert(std::isfinite(lp)); - for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); + for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { - const DishLocations& cur = it->second; - for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) { - lp += lgamma(*ti - discount) - r; + const CRPTableManager& cur = it->second; // TODO check + for (CRPTableManager::const_iterator ti = cur.begin(); ti != cur.end(); ++ti) { + lp += (lgamma(ti->first - discount) - r) * ti->second; } } } else if (!discount) { // discount == 0.0 lp += lgamma(strength) + num_tables_ * log(strength) - lgamma(strength + num_tables_); assert(std::isfinite(lp)); - for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); + for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { - const DishLocations& cur = it->second; - lp += lgamma(cur.table_counts_.size()); + const CRPTableManager& cur = it->second; + lp += lgamma(cur.num_tables()); } } else { assert(!"discount less than 0 detected!"); @@ -264,27 +229,15 @@ class CCRP { } }; - struct DishLocations { - DishLocations() : total_dish_count_() {} - unsigned total_dish_count_; // customers at all tables with this dish - std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want - // .size() is the number of tables for this dish - }; - void Print(std::ostream* out) const { std::cerr << "PYP(d=" << discount_ << ",c=" << strength_ << ") customers=" << num_customers_ << std::endl; - for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); + for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { - (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; - for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin(); - i != it->second.table_counts_.end(); ++i) { - (*out) << " " << *i; - } - (*out) << std::endl; + (*out) << it->first << " : " << it->second << std::endl; } } - typedef typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator const_iterator; + typedef typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator const_iterator; const_iterator begin() const { return dish_locs_.begin(); } @@ -294,7 +247,7 @@ class CCRP { unsigned num_tables_; unsigned num_customers_; - std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_; + std::tr1::unordered_map<Dish, CRPTableManager, DishHash> dish_locs_; double discount_; double strength_; |