diff options
Diffstat (limited to 'gi/clda/src/ccrp.h')
-rw-r--r-- | gi/clda/src/ccrp.h | 55 |
1 files changed, 38 insertions, 17 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h index e978225d..47f364f2 100644 --- a/gi/clda/src/ccrp.h +++ b/gi/clda/src/ccrp.h @@ -18,16 +18,22 @@ class CCRP { public: CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {} + unsigned num_tables(const Dish& dish) const { + const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); + if (it == dish_locs_.end()) return 0; + return it->second.table_counts_.size(); + } + int increment(const Dish& dish, const double& p0, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; bool share_table = false; - if (loc.dish_count_) { + if (loc.total_dish_count_) { const double p_empty = (concentration_ + num_tables_ * discount_) * p0; - const double p_share = (loc.dish_count_ - loc.table_counts_.size() * discount_); + const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_); share_table = rng->SelectSample(p_empty, p_share); } if (share_table) { - double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_); + double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); ti != loc.table_counts_.end(); ++ti) { r -= (*ti - discount_); @@ -45,23 +51,23 @@ class CCRP { loc.table_counts_.push_back(1u); ++num_tables_; } - ++loc.dish_count_; + ++loc.total_dish_count_; ++num_customers_; return (share_table ? 0 : 1); } int decrement(const Dish& dish, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; - assert(loc.dish_count_); - if (loc.dish_count_ == 1) { + assert(loc.total_dish_count_); + if (loc.total_dish_count_ == 1) { dish_locs_.erase(dish); --num_tables_; --num_customers_; return -1; } else { int delta = 0; - double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_); - --loc.dish_count_; + double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); + --loc.total_dish_count_; for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); ti != loc.table_counts_.end(); ++ti) { r -= (*ti - discount_); @@ -90,30 +96,37 @@ class CCRP { if (it == dish_locs_.end()) { return r * p0 / (num_customers_ + concentration_); } else { - return (it->second.dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) / + return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) / (num_customers_ + concentration_); } } - double llh() const { + // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process + // does not include P_0's + double log_crp_prob() const { + double lp = 0.0; if (num_customers_) { - std::cerr << "not implemented\n"; - return 0.0; - } else { - return 0.0; + const double r = lgamma(1.0 - discount_); + lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) - lgamma(concentration_ / discount_); + for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); + it != dish_locs_.end(); ++it) { + const DishLocations& cur = it->second; + lp += lgamma(cur.total_dish_count_ - discount_) - r; + } } + return lp; } struct DishLocations { - DishLocations() : dish_count_() {} - unsigned dish_count_; + DishLocations() : total_dish_count_() {} + unsigned total_dish_count_; std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want }; void Print(std::ostream* out) const { for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { - (*out) << it->first << " (" << it->second.dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; + (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin(); i != it->second.table_counts_.end(); ++i) { (*out) << " " << *i; @@ -122,6 +135,14 @@ class CCRP { } } + typedef typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator const_iterator; + const_iterator begin() const { + return dish_locs_.begin(); + } + const_iterator end() const { + return dish_locs_.end(); + } + unsigned num_tables_; unsigned num_customers_; std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_; |