diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-25 16:16:27 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-25 16:16:27 +0000 |
commit | 9981b620b08704211d498c01a8808a3cc523fe2b (patch) | |
tree | 5e3612e1ba586d66f1918fdf49f5496c33871155 /gi/clda/src/ccrp.h | |
parent | ea718017bab2e9ec0b3e131b0fe0e6f5a112cd16 (diff) |
unit tests for PYP sampler
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@620 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/clda/src/ccrp.h')
-rw-r--r-- | gi/clda/src/ccrp.h | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h index 47f364f2..9b1c5284 100644 --- a/gi/clda/src/ccrp.h +++ b/gi/clda/src/ccrp.h @@ -18,12 +18,19 @@ class CCRP { public: CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {} + void clear() { + num_tables_ = 0; + num_customers_ = 0; + dish_locs_.clear(); + } + unsigned num_tables(const Dish& dish) const { const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); if (it == dish_locs_.end()) return 0; return it->second.table_counts_.size(); } + // returns +1 or 0 indicating whether a new table was opened int increment(const Dish& dish, const double& p0, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; bool share_table = false; @@ -56,6 +63,7 @@ class CCRP { return (share_table ? 0 : 1); } + // returns -1 or 0, indicating whether a table was closed int decrement(const Dish& dish, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; assert(loc.total_dish_count_); @@ -66,11 +74,13 @@ class CCRP { return -1; } else { int delta = 0; - double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); + // sample customer to remove UNIFORMLY. that is, do NOT use the discount + // here. if you do, it will introduce (unwanted) bias! + double r = rng->next() * loc.total_dish_count_; --loc.total_dish_count_; for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); ti != loc.table_counts_.end(); ++ti) { - r -= (*ti - discount_); + r -= *ti; if (r <= 0.0) { if ((--(*ti)) == 0) { --num_tables_; @@ -107,7 +117,9 @@ class CCRP { double lp = 0.0; if (num_customers_) { const double r = lgamma(1.0 - discount_); - lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) - lgamma(concentration_ / discount_); + lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) + - lgamma(concentration_ / discount_); for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { const DishLocations& cur = it->second; @@ -119,8 +131,9 @@ class CCRP { struct DishLocations { DishLocations() : total_dish_count_() {} - unsigned total_dish_count_; + unsigned total_dish_count_; // customers at all tables with this dish std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want + // .size() is the number of tables for this dish }; void Print(std::ostream* out) const { |