summaryrefslogtreecommitdiff
path: root/gi
diff options
context:
space:
mode:
Diffstat (limited to 'gi')
-rw-r--r--gi/clda/src/ccrp.h33
1 files changed, 21 insertions, 12 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h
index 74d5be29..a7c2825c 100644
--- a/gi/clda/src/ccrp.h
+++ b/gi/clda/src/ccrp.h
@@ -12,8 +12,7 @@
#include "sampler.h"
#include "slice_sampler.h"
-// Chinese restaurant process (Pitman-Yor parameters) with explicit table
-// tracking.
+// Chinese restaurant process (Pitman-Yor parameters) with table tracking.
template <typename Dish, typename DishHash = boost::hash<Dish> >
class CCRP {
@@ -65,6 +64,12 @@ class CCRP {
return num_customers_;
}
+ unsigned num_customers(const Dish& dish) const {
+ const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+ if (it == dish_locs_.end()) return 0;
+ return it->total_dish_count_;
+ }
+
// returns +1 or 0 indicating whether a new table was opened
int increment(const Dish& dish, const double& p0, MT19937* rng) {
DishLocations& loc = dish_locs_[dish];
@@ -177,17 +182,21 @@ class CCRP {
lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_);
assert(lp <= 0.0);
if (num_customers_) {
- const double r = lgamma(1.0 - discount);
- lp += lgamma(concentration) - lgamma(concentration + num_customers_)
- + num_tables_ * discount + lgamma(concentration / discount + num_tables_)
- - lgamma(concentration / discount);
- assert(std::isfinite(lp));
- for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
- it != dish_locs_.end(); ++it) {
- const DishLocations& cur = it->second;
- for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) {
- lp += lgamma(*ti - discount) - r;
+ if (discount > 0.0) {
+ const double r = lgamma(1.0 - discount);
+ lp += lgamma(concentration) - lgamma(concentration + num_customers_)
+ + num_tables_ * log(discount) + lgamma(concentration / discount + num_tables_)
+ - lgamma(concentration / discount);
+ assert(std::isfinite(lp));
+ for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+ it != dish_locs_.end(); ++it) {
+ const DishLocations& cur = it->second;
+ for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) {
+ lp += lgamma(*ti - discount) - r;
+ }
}
+ } else {
+ assert(!"not implemented yet");
}
}
assert(std::isfinite(lp));