From 926cedc80db13840263d373f17bdec762b9fd551 Mon Sep 17 00:00:00 2001 From: redpony Date: Mon, 30 Aug 2010 22:52:15 +0000 Subject: fix llh git-svn-id: https://ws10smt.googlecode.com/svn/trunk@633 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/clda/src/ccrp.h | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) (limited to 'gi') diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h index 74d5be29..a7c2825c 100644 --- a/gi/clda/src/ccrp.h +++ b/gi/clda/src/ccrp.h @@ -12,8 +12,7 @@ #include "sampler.h" #include "slice_sampler.h" -// Chinese restaurant process (Pitman-Yor parameters) with explicit table -// tracking. +// Chinese restaurant process (Pitman-Yor parameters) with table tracking. template > class CCRP { @@ -65,6 +64,12 @@ class CCRP { return num_customers_; } + unsigned num_customers(const Dish& dish) const { + const typename std::tr1::unordered_map::const_iterator it = dish_locs_.find(dish); + if (it == dish_locs_.end()) return 0; + return it->total_dish_count_; + } + // returns +1 or 0 indicating whether a new table was opened int increment(const Dish& dish, const double& p0, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; @@ -177,17 +182,21 @@ class CCRP { lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_); assert(lp <= 0.0); if (num_customers_) { - const double r = lgamma(1.0 - discount); - lp += lgamma(concentration) - lgamma(concentration + num_customers_) - + num_tables_ * discount + lgamma(concentration / discount + num_tables_) - - lgamma(concentration / discount); - assert(std::isfinite(lp)); - for (typename std::tr1::unordered_map::const_iterator it = dish_locs_.begin(); - it != dish_locs_.end(); ++it) { - const DishLocations& cur = it->second; - for (std::list::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) { - lp += lgamma(*ti - discount) - r; + if (discount > 0.0) { + const double r = lgamma(1.0 - discount); + lp += lgamma(concentration) - lgamma(concentration + num_customers_) + + num_tables_ * log(discount) + lgamma(concentration / discount + num_tables_) + - lgamma(concentration / discount); + assert(std::isfinite(lp)); + for (typename std::tr1::unordered_map::const_iterator it = dish_locs_.begin(); + it != dish_locs_.end(); ++it) { + const DishLocations& cur = it->second; + for (std::list::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) { + lp += lgamma(*ti - discount) - r; + } } + } else { + assert(!"not implemented yet"); } } assert(std::isfinite(lp)); -- cgit v1.2.3