diff options
Diffstat (limited to 'gi')
| -rw-r--r-- | gi/clda/src/ccrp.h | 33 | 
1 files changed, 21 insertions, 12 deletions
| diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h index 74d5be29..a7c2825c 100644 --- a/gi/clda/src/ccrp.h +++ b/gi/clda/src/ccrp.h @@ -12,8 +12,7 @@  #include "sampler.h"  #include "slice_sampler.h" -// Chinese restaurant process (Pitman-Yor parameters) with explicit table -// tracking. +// Chinese restaurant process (Pitman-Yor parameters) with table tracking.  template <typename Dish, typename DishHash = boost::hash<Dish> >  class CCRP { @@ -65,6 +64,12 @@ class CCRP {      return num_customers_;    } +  unsigned num_customers(const Dish& dish) const { +    const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); +    if (it == dish_locs_.end()) return 0; +    return it->total_dish_count_; +  } +    // returns +1 or 0 indicating whether a new table was opened    int increment(const Dish& dish, const double& p0, MT19937* rng) {      DishLocations& loc = dish_locs_[dish]; @@ -177,17 +182,21 @@ class CCRP {        lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_);      assert(lp <= 0.0);      if (num_customers_) { -      const double r = lgamma(1.0 - discount); -      lp += lgamma(concentration) - lgamma(concentration + num_customers_) -           + num_tables_ * discount + lgamma(concentration / discount + num_tables_) -           - lgamma(concentration / discount); -      assert(std::isfinite(lp)); -      for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); -           it != dish_locs_.end(); ++it) { -        const DishLocations& cur = it->second; -        for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) { -          lp += lgamma(*ti - discount) - r; +      if (discount > 0.0) { +        const double r = lgamma(1.0 - discount); +        lp += lgamma(concentration) - lgamma(concentration + num_customers_) +             + num_tables_ * log(discount) + lgamma(concentration / discount + num_tables_) +             - lgamma(concentration / discount); +        assert(std::isfinite(lp)); +        for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); +             it != dish_locs_.end(); ++it) { +          const DishLocations& cur = it->second; +          for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) { +            lp += lgamma(*ti - discount) - r; +          }          } +      } else { +        assert(!"not implemented yet");        }      }      assert(std::isfinite(lp)); | 
