summaryrefslogtreecommitdiff
path: root/gi/clda/src/ccrp.h
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-25 16:16:27 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-25 16:16:27 +0000
commitaa2a66378777d3825cc31640dff2e206491ccee3 (patch)
treedeed44f7528e70ed7710627e31f58ce50adb978a /gi/clda/src/ccrp.h
parenteb7dd39a10474a58fda0ab237e92b30f9b352d55 (diff)
unit tests for PYP sampler
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@620 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/clda/src/ccrp.h')
-rw-r--r--gi/clda/src/ccrp.h21
1 files changed, 17 insertions, 4 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h
index 47f364f2..9b1c5284 100644
--- a/gi/clda/src/ccrp.h
+++ b/gi/clda/src/ccrp.h
@@ -18,12 +18,19 @@ class CCRP {
public:
CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {}
+ void clear() {
+ num_tables_ = 0;
+ num_customers_ = 0;
+ dish_locs_.clear();
+ }
+
unsigned num_tables(const Dish& dish) const {
const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
if (it == dish_locs_.end()) return 0;
return it->second.table_counts_.size();
}
+ // returns +1 or 0 indicating whether a new table was opened
int increment(const Dish& dish, const double& p0, MT19937* rng) {
DishLocations& loc = dish_locs_[dish];
bool share_table = false;
@@ -56,6 +63,7 @@ class CCRP {
return (share_table ? 0 : 1);
}
+ // returns -1 or 0, indicating whether a table was closed
int decrement(const Dish& dish, MT19937* rng) {
DishLocations& loc = dish_locs_[dish];
assert(loc.total_dish_count_);
@@ -66,11 +74,13 @@ class CCRP {
return -1;
} else {
int delta = 0;
- double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
+ // sample customer to remove UNIFORMLY. that is, do NOT use the discount
+ // here. if you do, it will introduce (unwanted) bias!
+ double r = rng->next() * loc.total_dish_count_;
--loc.total_dish_count_;
for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
ti != loc.table_counts_.end(); ++ti) {
- r -= (*ti - discount_);
+ r -= *ti;
if (r <= 0.0) {
if ((--(*ti)) == 0) {
--num_tables_;
@@ -107,7 +117,9 @@ class CCRP {
double lp = 0.0;
if (num_customers_) {
const double r = lgamma(1.0 - discount_);
- lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) - lgamma(concentration_ / discount_);
+ lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_)
+ + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_)
+ - lgamma(concentration_ / discount_);
for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
it != dish_locs_.end(); ++it) {
const DishLocations& cur = it->second;
@@ -119,8 +131,9 @@ class CCRP {
struct DishLocations {
DishLocations() : total_dish_count_() {}
- unsigned total_dish_count_;
+ unsigned total_dish_count_; // customers at all tables with this dish
std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want
+ // .size() is the number of tables for this dish
};
void Print(std::ostream* out) const {