diff options
-rw-r--r-- | gi/clda/src/ccrp.h | 33 |
1 files changed, 21 insertions, 12 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h index 74d5be29..a7c2825c 100644 --- a/gi/clda/src/ccrp.h +++ b/gi/clda/src/ccrp.h @@ -12,8 +12,7 @@ #include "sampler.h" #include "slice_sampler.h" -// Chinese restaurant process (Pitman-Yor parameters) with explicit table -// tracking. +// Chinese restaurant process (Pitman-Yor parameters) with table tracking. template <typename Dish, typename DishHash = boost::hash<Dish> > class CCRP { @@ -65,6 +64,12 @@ class CCRP { return num_customers_; } + unsigned num_customers(const Dish& dish) const { + const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); + if (it == dish_locs_.end()) return 0; + return it->total_dish_count_; + } + // returns +1 or 0 indicating whether a new table was opened int increment(const Dish& dish, const double& p0, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; @@ -177,17 +182,21 @@ class CCRP { lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_); assert(lp <= 0.0); if (num_customers_) { - const double r = lgamma(1.0 - discount); - lp += lgamma(concentration) - lgamma(concentration + num_customers_) - + num_tables_ * discount + lgamma(concentration / discount + num_tables_) - - lgamma(concentration / discount); - assert(std::isfinite(lp)); - for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); - it != dish_locs_.end(); ++it) { - const DishLocations& cur = it->second; - for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) { - lp += lgamma(*ti - discount) - r; + if (discount > 0.0) { + const double r = lgamma(1.0 - discount); + lp += lgamma(concentration) - lgamma(concentration + num_customers_) + + num_tables_ * log(discount) + lgamma(concentration / discount + num_tables_) + - lgamma(concentration / discount); + assert(std::isfinite(lp)); + for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); + it != dish_locs_.end(); ++it) { + const DishLocations& cur = it->second; + for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) { + lp += lgamma(*ti - discount) - r; + } } + } else { + assert(!"not implemented yet"); } } assert(std::isfinite(lp)); |