From 465abfdb07652238291e807f709292a9ff066366 Mon Sep 17 00:00:00 2001
From: Chris Dyer <cdyer@cs.cmu.edu>
Date: Sun, 12 Aug 2012 02:53:28 -0400
Subject: clean up CRP code to use faster data structure

---
 utils/ccrp.h              | 103 +++++++++++++---------------------------------
 utils/crp_table_manager.h |   6 ++-
 2 files changed, 33 insertions(+), 76 deletions(-)

(limited to 'utils')

diff --git a/utils/ccrp.h b/utils/ccrp.h
index 1d41a3ef..f5d3fc78 100644
--- a/utils/ccrp.h
+++ b/utils/ccrp.h
@@ -11,6 +11,7 @@
 #include <boost/functional/hash.hpp>
 #include "sampler.h"
 #include "slice_sampler.h"
+#include "crp_table_manager.h"
 #include "m.h"
 
 // Chinese restaurant process (Pitman-Yor parameters) with table tracking.
@@ -81,9 +82,9 @@ class CCRP {
   }
 
   unsigned num_tables(const Dish& dish) const {
-    const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+    const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish);
     if (it == dish_locs_.end()) return 0;
-    return it->second.table_counts_.size();
+    return it->second.num_tables();
   }
 
   unsigned num_customers() const {
@@ -91,9 +92,9 @@ class CCRP {
   }
 
   unsigned num_customers(const Dish& dish) const {
-    const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+    const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish);
     if (it == dish_locs_.end()) return 0;
-    return it->total_dish_count_;
+    return it->num_customers();
   }
 
   // returns +1 or 0 indicating whether a new table was opened
@@ -101,84 +102,48 @@ class CCRP {
   //       excluding p0
   template <typename T>
   int increment(const Dish& dish, const T& p0, MT19937* rng, T* p = NULL) {
-    DishLocations& loc = dish_locs_[dish];
+    CRPTableManager& loc = dish_locs_[dish];
     bool share_table = false;
-    if (loc.total_dish_count_) {
+    if (loc.num_customers()) {
       const T p_empty = T(strength_ + num_tables_ * discount_) * p0;
-      const T p_share = T(loc.total_dish_count_ - loc.table_counts_.size() * discount_);
+      const T p_share = T(loc.num_customers() - loc.num_tables() * discount_);
       share_table = rng->SelectSample(p_empty, p_share);
     }
     if (share_table) {
-      double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
-      for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
-           ti != loc.table_counts_.end(); ++ti) {
-        r -= (*ti - discount_);
-        if (r <= 0.0) {
-          if (p) { *p = T(*ti - discount_) / T(strength_ + num_customers_); }
-          ++(*ti);
-          break;
-        }
-      }
-      if (r > 0.0) {
-        std::cerr << "Serious error: r=" << r << std::endl;
-        Print(&std::cerr);
-        assert(r <= 0.0);
-      }
+      loc.share_table(discount_, rng);
     } else {
-      loc.table_counts_.push_back(1u);
-      if (p) { *p = T(strength_ + discount_ * num_tables_) / T(strength_ + num_customers_); }
+      loc.create_table();
       ++num_tables_;
     }
-    ++loc.total_dish_count_;
     ++num_customers_;
     return (share_table ? 0 : 1);
   }
 
   // returns -1 or 0, indicating whether a table was closed
   int decrement(const Dish& dish, MT19937* rng) {
-    DishLocations& loc = dish_locs_[dish];
-    assert(loc.total_dish_count_);
-    if (loc.total_dish_count_ == 1) {
+    CRPTableManager& loc = dish_locs_[dish];
+    assert(loc.num_customers());
+    if (loc.num_customers() == 1) {
       dish_locs_.erase(dish);
       --num_tables_;
       --num_customers_;
       return -1;
     } else {
-      int delta = 0;
-      // sample customer to remove UNIFORMLY. that is, do NOT use the discount
-      // here. if you do, it will introduce (unwanted) bias!
-      double r = rng->next() * loc.total_dish_count_;
-      --loc.total_dish_count_;
-      for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
-           ti != loc.table_counts_.end(); ++ti) {
-        r -= *ti;
-        if (r <= 0.0) {
-          if ((--(*ti)) == 0) {
-            --num_tables_;
-            delta = -1;
-            loc.table_counts_.erase(ti);
-          }
-          break;
-        }
-      }
-      if (r > 0.0) {
-        std::cerr << "Serious error: r=" << r << std::endl;
-        Print(&std::cerr);
-        assert(r <= 0.0);
-      }
+      int delta = loc.remove_customer(rng);
       --num_customers_;
+      if (delta) --num_tables_;
       return delta;
     }
   }
 
   template <typename T>
   T prob(const Dish& dish, const T& p0) const {
-    const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+    const typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.find(dish);
     const T r = T(num_tables_ * discount_ + strength_);
     if (it == dish_locs_.end()) {
       return r * p0 / T(num_customers_ + strength_);
     } else {
-      return (T(it->second.total_dish_count_ - discount_ * it->second.table_counts_.size()) + r * p0) /
+      return (T(it->second.num_customers() - discount_ * it->second.num_tables()) + r * p0) /
                T(num_customers_ + strength_);
     }
   }
@@ -204,20 +169,20 @@ class CCRP {
         lp += - lgamma(strength + num_customers_)
              + num_tables_ * log(discount) + lgamma(strength / discount + num_tables_);
         assert(std::isfinite(lp));
-        for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+        for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin();
              it != dish_locs_.end(); ++it) {
-          const DishLocations& cur = it->second;
-          for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) {
-            lp += lgamma(*ti - discount) - r;
+          const CRPTableManager& cur = it->second;  // TODO check
+          for (CRPTableManager::const_iterator ti = cur.begin(); ti != cur.end(); ++ti) {
+            lp += (lgamma(ti->first - discount) - r) * ti->second;
           }
         }
       } else if (!discount) { // discount == 0.0
         lp += lgamma(strength) + num_tables_ * log(strength) - lgamma(strength + num_tables_);
         assert(std::isfinite(lp));
-        for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+        for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin();
              it != dish_locs_.end(); ++it) {
-          const DishLocations& cur = it->second;
-          lp += lgamma(cur.table_counts_.size());
+          const CRPTableManager& cur = it->second;
+          lp += lgamma(cur.num_tables());
         }
       } else {
         assert(!"discount less than 0 detected!");
@@ -264,27 +229,15 @@ class CCRP {
     }
   };
 
-  struct DishLocations {
-    DishLocations() : total_dish_count_() {}
-    unsigned total_dish_count_;        // customers at all tables with this dish
-    std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want
-                                       // .size() is the number of tables for this dish
-  };
-
   void Print(std::ostream* out) const {
     std::cerr << "PYP(d=" << discount_ << ",c=" << strength_ << ") customers=" << num_customers_ << std::endl;
-    for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+    for (typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator it = dish_locs_.begin();
          it != dish_locs_.end(); ++it) {
-      (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): ";
-      for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin();
-           i != it->second.table_counts_.end(); ++i) {
-        (*out) << " " << *i;
-      }
-      (*out) << std::endl;
+      (*out) << it->first << " : " << it->second << std::endl;
     }
   }
 
-  typedef typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator const_iterator;
+  typedef typename std::tr1::unordered_map<Dish, CRPTableManager, DishHash>::const_iterator const_iterator;
   const_iterator begin() const {
     return dish_locs_.begin();
   }
@@ -294,7 +247,7 @@ class CCRP {
 
   unsigned num_tables_;
   unsigned num_customers_;
-  std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_;
+  std::tr1::unordered_map<Dish, CRPTableManager, DishHash> dish_locs_;
 
   double discount_;
   double strength_;
diff --git a/utils/crp_table_manager.h b/utils/crp_table_manager.h
index 32840ad8..753e721f 100644
--- a/utils/crp_table_manager.h
+++ b/utils/crp_table_manager.h
@@ -93,6 +93,10 @@ struct CRPTableManager {
     }
   }
 
+  typedef CRPHistogram::const_iterator const_iterator;
+  const_iterator begin() const { return h.begin(); }
+  const_iterator end() const { return h.end(); }
+
   unsigned customers;
   unsigned tables;
   CRPHistogram h;
@@ -100,7 +104,7 @@ struct CRPTableManager {
 
 std::ostream& operator<<(std::ostream& os, const CRPTableManager& tm) {
   os << '[' << tm.num_customers() << " total customers at " << tm.num_tables() << " tables ||| ";
-  for (CRPHistogram::const_iterator it = tm.h.begin(); it != tm.h.end(); ++it) {
+  for (CRPHistogram::const_iterator it = tm.begin(); it != tm.end(); ++it) {
     if (it != tm.h.begin()) os << "  --  ";
     os << '(' << it->first << ") x " << it->second;
   }
-- 
cgit v1.2.3