From ffd4f84d33109e4146ad0f3fd5c6100e02150888 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 12 Aug 2012 02:14:57 -0400 Subject: fix sparse vector api, and add crp helper class --- utils/crp_table_manager.h | 110 +++++++++++++++++++++++++++++++++++++++++++++ utils/fast_sparse_vector.h | 7 +-- 2 files changed, 114 insertions(+), 3 deletions(-) create mode 100644 utils/crp_table_manager.h (limited to 'utils') diff --git a/utils/crp_table_manager.h b/utils/crp_table_manager.h new file mode 100644 index 00000000..32840ad8 --- /dev/null +++ b/utils/crp_table_manager.h @@ -0,0 +1,110 @@ +#ifndef _CRP_TABLE_MANAGER_H_ +#define _CRP_TABLE_MANAGER_H_ + +#include +#include "sparse_vector.h" +#include "sampler.h" + +// these are helper classes for implementing token-based CRP samplers +// basically the data structures recommended by Blunsom et al. in the Note. + +struct CRPHistogram { + //typedef std::map MAPTYPE; + typedef SparseVector MAPTYPE; + typedef MAPTYPE::const_iterator const_iterator; + + inline void increment(unsigned bin, unsigned delta = 1u) { + data[bin] += delta; + } + inline void decrement(unsigned bin, unsigned delta = 1u) { + unsigned r = data[bin] -= delta; + if (!r) data.erase(bin); + } + inline void move(unsigned from_bin, unsigned to_bin, unsigned delta = 1u) { + decrement(from_bin, delta); + increment(to_bin, delta); + } + inline const_iterator begin() const { return data.begin(); } + inline const_iterator end() const { return data.end(); } + + private: + MAPTYPE data; +}; + +// A CRPTableManager tracks statistics about all customers +// and tables serving some dish in a CRP and can correctly sample what +// table to remove a customer from and what table to join +struct CRPTableManager { + CRPTableManager() : customers(), tables() {} + + inline unsigned num_tables() const { + return tables; + } + + inline unsigned num_customers() const { + return customers; + } + + inline void create_table() { + h.increment(1); + ++tables; + ++customers; + } + + // seat a customer at a table proportional to the number of customers seated at a table, less the discount + // *new tables are never created by this function! + inline void share_table(const double discount, MT19937* rng) { + const double z = customers - discount * num_tables(); + double r = z * rng->next(); + const CRPHistogram::const_iterator end = h.end(); + CRPHistogram::const_iterator it = h.begin(); + for (; it != end; ++it) { + // it->first = number of customers at table + // it->second = number of such tables + double thresh = (it->first - discount) * it->second; + if (thresh > r) break; + r -= thresh; + } + h.move(it->first, it->first + 1); + ++customers; + } + + // randomly sample a customer + // *tables may be removed + // returns -1 if a table is removed, 0 otherwise + inline int remove_customer(MT19937* rng) { + int r = rng->next() * num_customers(); + const CRPHistogram::const_iterator end = h.end(); + CRPHistogram::const_iterator it = h.begin(); + for (; it != end; ++it) { + int thresh = it->first * it->second; + if (thresh > r) break; + r -= thresh; + } + --customers; + const unsigned tc = it->first; + if (tc == 1) { + h.decrement(1); + --tables; + return -1; + } else { + h.move(tc, tc - 1); + return 0; + } + } + + unsigned customers; + unsigned tables; + CRPHistogram h; +}; + +std::ostream& operator<<(std::ostream& os, const CRPTableManager& tm) { + os << '[' << tm.num_customers() << " total customers at " << tm.num_tables() << " tables ||| "; + for (CRPHistogram::const_iterator it = tm.h.begin(); it != tm.h.end(); ++it) { + if (it != tm.h.begin()) os << " -- "; + os << '(' << it->first << ") x " << it->second; + } + return os << ']'; +} + +#endif diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h index 433a5cc5..5647a2a9 100644 --- a/utils/fast_sparse_vector.h +++ b/utils/fast_sparse_vector.h @@ -194,18 +194,19 @@ class FastSparseVector { data_.rbmap = new SPARSE_HASH_MAP(first, last); } } - void erase(int k) { + void erase(unsigned k) { if (is_remote_) { data_.rbmap->erase(k); } else { - for (int i = 0; i < local_size_; ++i) { + for (unsigned i = 0; i < local_size_; ++i) { if (data_.local[i].first() == k) { - for (int j = i+1; j < local_size_; ++j) { + for (unsigned j = i+1; j < local_size_; ++j) { data_.local[j-1].first() = data_.local[j].first(); data_.local[j-1].second() = data_.local[j].second(); } } } + --local_size_; } } const FastSparseVector& operator=(const FastSparseVector& other) { -- cgit v1.2.3