summaryrefslogtreecommitdiff
path: root/phrasinator/ccrp_nt.h
diff options
context:
space:
mode:
Diffstat (limited to 'phrasinator/ccrp_nt.h')
-rw-r--r--phrasinator/ccrp_nt.h24
1 files changed, 20 insertions, 4 deletions
diff --git a/phrasinator/ccrp_nt.h b/phrasinator/ccrp_nt.h
index 163b643a..811bce73 100644
--- a/phrasinator/ccrp_nt.h
+++ b/phrasinator/ccrp_nt.h
@@ -50,15 +50,26 @@ class CCRP_NoTable {
return it->second;
}
- void increment(const Dish& dish) {
- ++custs_[dish];
+ int increment(const Dish& dish) {
+ int table_diff = 0;
+ if (++custs_[dish] == 1)
+ table_diff = 1;
++num_customers_;
+ return table_diff;
}
- void decrement(const Dish& dish) {
- if ((--custs_[dish]) == 0)
+ int decrement(const Dish& dish) {
+ int table_diff = 0;
+ int nc = --custs_[dish];
+ if (nc == 0) {
custs_.erase(dish);
+ table_diff = -1;
+ } else if (nc < 0) {
+ std::cerr << "Dish counts dropped below zero for: " << dish << std::endl;
+ abort();
+ }
--num_customers_;
+ return table_diff;
}
double prob(const Dish& dish, const double& p0) const {
@@ -66,6 +77,11 @@ class CCRP_NoTable {
return (at_table + p0 * concentration_) / (num_customers_ + concentration_);
}
+ double logprob(const Dish& dish, const double& logp0) const {
+ const unsigned at_table = num_customers(dish);
+ return log(at_table + exp(logp0 + log(concentration_))) - log(num_customers_ + concentration_);
+ }
+
double log_crp_prob() const {
return log_crp_prob(concentration_);
}