summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
Diffstat (limited to 'utils')
-rw-r--r--utils/crp_table_manager.h110
-rw-r--r--utils/fast_sparse_vector.h7
2 files changed, 114 insertions, 3 deletions
diff --git a/utils/crp_table_manager.h b/utils/crp_table_manager.h
new file mode 100644
index 00000000..32840ad8
--- /dev/null
+++ b/utils/crp_table_manager.h
@@ -0,0 +1,110 @@
+#ifndef _CRP_TABLE_MANAGER_H_
+#define _CRP_TABLE_MANAGER_H_
+
+#include <iostream>
+#include "sparse_vector.h"
+#include "sampler.h"
+
+// these are helper classes for implementing token-based CRP samplers
+// basically the data structures recommended by Blunsom et al. in the Note.
+
+struct CRPHistogram {
+ //typedef std::map<unsigned, unsigned> MAPTYPE;
+ typedef SparseVector<unsigned> MAPTYPE;
+ typedef MAPTYPE::const_iterator const_iterator;
+
+ inline void increment(unsigned bin, unsigned delta = 1u) {
+ data[bin] += delta;
+ }
+ inline void decrement(unsigned bin, unsigned delta = 1u) {
+ unsigned r = data[bin] -= delta;
+ if (!r) data.erase(bin);
+ }
+ inline void move(unsigned from_bin, unsigned to_bin, unsigned delta = 1u) {
+ decrement(from_bin, delta);
+ increment(to_bin, delta);
+ }
+ inline const_iterator begin() const { return data.begin(); }
+ inline const_iterator end() const { return data.end(); }
+
+ private:
+ MAPTYPE data;
+};
+
+// A CRPTableManager tracks statistics about all customers
+// and tables serving some dish in a CRP and can correctly sample what
+// table to remove a customer from and what table to join
+struct CRPTableManager {
+ CRPTableManager() : customers(), tables() {}
+
+ inline unsigned num_tables() const {
+ return tables;
+ }
+
+ inline unsigned num_customers() const {
+ return customers;
+ }
+
+ inline void create_table() {
+ h.increment(1);
+ ++tables;
+ ++customers;
+ }
+
+ // seat a customer at a table proportional to the number of customers seated at a table, less the discount
+ // *new tables are never created by this function!
+ inline void share_table(const double discount, MT19937* rng) {
+ const double z = customers - discount * num_tables();
+ double r = z * rng->next();
+ const CRPHistogram::const_iterator end = h.end();
+ CRPHistogram::const_iterator it = h.begin();
+ for (; it != end; ++it) {
+ // it->first = number of customers at table
+ // it->second = number of such tables
+ double thresh = (it->first - discount) * it->second;
+ if (thresh > r) break;
+ r -= thresh;
+ }
+ h.move(it->first, it->first + 1);
+ ++customers;
+ }
+
+ // randomly sample a customer
+ // *tables may be removed
+ // returns -1 if a table is removed, 0 otherwise
+ inline int remove_customer(MT19937* rng) {
+ int r = rng->next() * num_customers();
+ const CRPHistogram::const_iterator end = h.end();
+ CRPHistogram::const_iterator it = h.begin();
+ for (; it != end; ++it) {
+ int thresh = it->first * it->second;
+ if (thresh > r) break;
+ r -= thresh;
+ }
+ --customers;
+ const unsigned tc = it->first;
+ if (tc == 1) {
+ h.decrement(1);
+ --tables;
+ return -1;
+ } else {
+ h.move(tc, tc - 1);
+ return 0;
+ }
+ }
+
+ unsigned customers;
+ unsigned tables;
+ CRPHistogram h;
+};
+
+std::ostream& operator<<(std::ostream& os, const CRPTableManager& tm) {
+ os << '[' << tm.num_customers() << " total customers at " << tm.num_tables() << " tables ||| ";
+ for (CRPHistogram::const_iterator it = tm.h.begin(); it != tm.h.end(); ++it) {
+ if (it != tm.h.begin()) os << " -- ";
+ os << '(' << it->first << ") x " << it->second;
+ }
+ return os << ']';
+}
+
+#endif
diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h
index 433a5cc5..5647a2a9 100644
--- a/utils/fast_sparse_vector.h
+++ b/utils/fast_sparse_vector.h
@@ -194,18 +194,19 @@ class FastSparseVector {
data_.rbmap = new SPARSE_HASH_MAP<unsigned, T>(first, last);
}
}
- void erase(int k) {
+ void erase(unsigned k) {
if (is_remote_) {
data_.rbmap->erase(k);
} else {
- for (int i = 0; i < local_size_; ++i) {
+ for (unsigned i = 0; i < local_size_; ++i) {
if (data_.local[i].first() == k) {
- for (int j = i+1; j < local_size_; ++j) {
+ for (unsigned j = i+1; j < local_size_; ++j) {
data_.local[j-1].first() = data_.local[j].first();
data_.local[j-1].second() = data_.local[j].second();
}
}
}
+ --local_size_;
}
}
const FastSparseVector<T>& operator=(const FastSparseVector<T>& other) {