From 21373c28b7786b51d1e91529ebb189bfbc4c6dd8 Mon Sep 17 00:00:00 2001
From: Chris Dyer <cdyer@cs.cmu.edu>
Date: Sun, 12 Aug 2012 02:14:57 -0400
Subject: fix sparse vector api, and add crp helper class

---
 utils/crp_table_manager.h  | 110 +++++++++++++++++++++++++++++++++++++++++++++
 utils/fast_sparse_vector.h |   7 +--
 2 files changed, 114 insertions(+), 3 deletions(-)
 create mode 100644 utils/crp_table_manager.h

(limited to 'utils')

diff --git a/utils/crp_table_manager.h b/utils/crp_table_manager.h
new file mode 100644
index 00000000..32840ad8
--- /dev/null
+++ b/utils/crp_table_manager.h
@@ -0,0 +1,110 @@
+#ifndef _CRP_TABLE_MANAGER_H_
+#define _CRP_TABLE_MANAGER_H_
+
+#include <iostream>
+#include "sparse_vector.h"
+#include "sampler.h"
+
+// these are helper classes for implementing token-based CRP samplers
+// basically the data structures recommended by Blunsom et al. in the Note.
+
+struct CRPHistogram {
+  //typedef std::map<unsigned, unsigned> MAPTYPE;
+  typedef SparseVector<unsigned> MAPTYPE;
+  typedef MAPTYPE::const_iterator const_iterator;
+
+  inline void increment(unsigned bin, unsigned delta = 1u) {
+    data[bin] += delta;
+  }
+  inline void decrement(unsigned bin, unsigned delta = 1u) {
+    unsigned r = data[bin] -= delta;
+    if (!r) data.erase(bin);
+  }
+  inline void move(unsigned from_bin, unsigned to_bin, unsigned delta = 1u) {
+    decrement(from_bin, delta);
+    increment(to_bin, delta);
+  }
+  inline const_iterator begin() const { return data.begin(); }
+  inline const_iterator end() const { return data.end(); }
+
+ private:
+  MAPTYPE data;
+};
+
+// A CRPTableManager tracks statistics about all customers
+// and tables serving some dish in a CRP and can correctly sample what
+// table to remove a customer from and what table to join
+struct CRPTableManager {
+  CRPTableManager() : customers(), tables() {}
+
+  inline unsigned num_tables() const {
+    return tables;
+  }
+
+  inline unsigned num_customers() const {
+    return customers;
+  }
+
+  inline void create_table() {
+    h.increment(1);
+    ++tables;
+    ++customers;
+  }
+
+  // seat a customer at a table proportional to the number of customers seated at a table, less the discount
+  // *new tables are never created by this function!
+  inline void share_table(const double discount, MT19937* rng) {
+    const double z = customers - discount * num_tables();
+    double r = z * rng->next();
+    const CRPHistogram::const_iterator end = h.end();
+    CRPHistogram::const_iterator it = h.begin();
+    for (; it != end; ++it) {
+      // it->first = number of customers at table
+      // it->second = number of such tables
+      double thresh = (it->first - discount) * it->second;
+      if (thresh > r) break;
+      r -= thresh;
+    }
+    h.move(it->first, it->first + 1);
+    ++customers;
+  }
+
+  // randomly sample a customer
+  // *tables may be removed
+  // returns -1 if a table is removed, 0 otherwise
+  inline int remove_customer(MT19937* rng) {
+    int r = rng->next() * num_customers();
+    const CRPHistogram::const_iterator end = h.end();
+    CRPHistogram::const_iterator it = h.begin();
+    for (; it != end; ++it) {
+      int thresh = it->first * it->second;
+      if (thresh > r) break;
+      r -= thresh;
+    }
+    --customers;
+    const unsigned tc = it->first;
+    if (tc == 1) {
+      h.decrement(1);
+      --tables;
+      return -1;
+    } else {
+      h.move(tc, tc - 1);
+      return 0;
+    }
+  }
+
+  unsigned customers;
+  unsigned tables;
+  CRPHistogram h;
+};
+
+std::ostream& operator<<(std::ostream& os, const CRPTableManager& tm) {
+  os << '[' << tm.num_customers() << " total customers at " << tm.num_tables() << " tables ||| ";
+  for (CRPHistogram::const_iterator it = tm.h.begin(); it != tm.h.end(); ++it) {
+    if (it != tm.h.begin()) os << "  --  ";
+    os << '(' << it->first << ") x " << it->second;
+  }
+  return os << ']';
+}
+
+#endif
diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h
index 433a5cc5..5647a2a9 100644
--- a/utils/fast_sparse_vector.h
+++ b/utils/fast_sparse_vector.h
@@ -194,18 +194,19 @@ class FastSparseVector {
       data_.rbmap = new SPARSE_HASH_MAP<unsigned, T>(first, last);
     }
   }
-  void erase(int k) {
+  void erase(unsigned k) {
     if (is_remote_) {
       data_.rbmap->erase(k);
     } else {
-      for (int i = 0; i < local_size_; ++i) {
+      for (unsigned i = 0; i < local_size_; ++i) {
         if (data_.local[i].first() == k) {
-          for (int j = i+1; j < local_size_; ++j) {
+          for (unsigned j = i+1; j < local_size_; ++j) {
             data_.local[j-1].first() = data_.local[j].first();
             data_.local[j-1].second() = data_.local[j].second();
           }
         }
       }
+      --local_size_;
     }
   }
   const FastSparseVector<T>& operator=(const FastSparseVector<T>& other) {
-- 
cgit v1.2.3