diff options
Diffstat (limited to 'gi/clda/src/ccrp.h')
| -rw-r--r-- | gi/clda/src/ccrp.h | 21 | 
1 files changed, 17 insertions, 4 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h index 47f364f2..9b1c5284 100644 --- a/gi/clda/src/ccrp.h +++ b/gi/clda/src/ccrp.h @@ -18,12 +18,19 @@ class CCRP {   public:    CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {} +  void clear() { +    num_tables_ = 0; +    num_customers_ = 0; +    dish_locs_.clear(); +  } +    unsigned num_tables(const Dish& dish) const {      const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);      if (it == dish_locs_.end()) return 0;      return it->second.table_counts_.size();    } +  // returns +1 or 0 indicating whether a new table was opened    int increment(const Dish& dish, const double& p0, MT19937* rng) {      DishLocations& loc = dish_locs_[dish];      bool share_table = false; @@ -56,6 +63,7 @@ class CCRP {      return (share_table ? 0 : 1);    } +  // returns -1 or 0, indicating whether a table was closed    int decrement(const Dish& dish, MT19937* rng) {      DishLocations& loc = dish_locs_[dish];      assert(loc.total_dish_count_); @@ -66,11 +74,13 @@ class CCRP {        return -1;      } else {        int delta = 0; -      double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); +      // sample customer to remove UNIFORMLY. that is, do NOT use the discount +      // here. if you do, it will introduce (unwanted) bias! +      double r = rng->next() * loc.total_dish_count_;        --loc.total_dish_count_;        for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();             ti != loc.table_counts_.end(); ++ti) { -        r -= (*ti - discount_); +        r -= *ti;          if (r <= 0.0) {            if ((--(*ti)) == 0) {              --num_tables_; @@ -107,7 +117,9 @@ class CCRP {      double lp = 0.0;      if (num_customers_) {        const double r = lgamma(1.0 - discount_); -      lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) - lgamma(concentration_ / discount_); +      lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) +           + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) +           - lgamma(concentration_ / discount_);        for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();             it != dish_locs_.end(); ++it) {          const DishLocations& cur = it->second; @@ -119,8 +131,9 @@ class CCRP {    struct DishLocations {      DishLocations() : total_dish_count_() {} -    unsigned total_dish_count_; +    unsigned total_dish_count_;        // customers at all tables with this dish      std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want +                                       // .size() is the number of tables for this dish    };    void Print(std::ostream* out) const {  | 
