diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-08-12 02:14:57 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-08-12 02:14:57 -0400 |
commit | 21373c28b7786b51d1e91529ebb189bfbc4c6dd8 (patch) | |
tree | f59f23568b5df4f7f82f8c0bb0f0d98ad3c971fd /utils | |
parent | 7527592aaf4245749845500aca6a7fcc97eb2f17 (diff) |
fix sparse vector api, and add crp helper class
Diffstat (limited to 'utils')
-rw-r--r-- | utils/crp_table_manager.h | 110 | ||||
-rw-r--r-- | utils/fast_sparse_vector.h | 7 |
2 files changed, 114 insertions, 3 deletions
diff --git a/utils/crp_table_manager.h b/utils/crp_table_manager.h new file mode 100644 index 00000000..32840ad8 --- /dev/null +++ b/utils/crp_table_manager.h @@ -0,0 +1,110 @@ +#ifndef _CRP_TABLE_MANAGER_H_ +#define _CRP_TABLE_MANAGER_H_ + +#include <iostream> +#include "sparse_vector.h" +#include "sampler.h" + +// these are helper classes for implementing token-based CRP samplers +// basically the data structures recommended by Blunsom et al. in the Note. + +struct CRPHistogram { + //typedef std::map<unsigned, unsigned> MAPTYPE; + typedef SparseVector<unsigned> MAPTYPE; + typedef MAPTYPE::const_iterator const_iterator; + + inline void increment(unsigned bin, unsigned delta = 1u) { + data[bin] += delta; + } + inline void decrement(unsigned bin, unsigned delta = 1u) { + unsigned r = data[bin] -= delta; + if (!r) data.erase(bin); + } + inline void move(unsigned from_bin, unsigned to_bin, unsigned delta = 1u) { + decrement(from_bin, delta); + increment(to_bin, delta); + } + inline const_iterator begin() const { return data.begin(); } + inline const_iterator end() const { return data.end(); } + + private: + MAPTYPE data; +}; + +// A CRPTableManager tracks statistics about all customers +// and tables serving some dish in a CRP and can correctly sample what +// table to remove a customer from and what table to join +struct CRPTableManager { + CRPTableManager() : customers(), tables() {} + + inline unsigned num_tables() const { + return tables; + } + + inline unsigned num_customers() const { + return customers; + } + + inline void create_table() { + h.increment(1); + ++tables; + ++customers; + } + + // seat a customer at a table proportional to the number of customers seated at a table, less the discount + // *new tables are never created by this function! + inline void share_table(const double discount, MT19937* rng) { + const double z = customers - discount * num_tables(); + double r = z * rng->next(); + const CRPHistogram::const_iterator end = h.end(); + CRPHistogram::const_iterator it = h.begin(); + for (; it != end; ++it) { + // it->first = number of customers at table + // it->second = number of such tables + double thresh = (it->first - discount) * it->second; + if (thresh > r) break; + r -= thresh; + } + h.move(it->first, it->first + 1); + ++customers; + } + + // randomly sample a customer + // *tables may be removed + // returns -1 if a table is removed, 0 otherwise + inline int remove_customer(MT19937* rng) { + int r = rng->next() * num_customers(); + const CRPHistogram::const_iterator end = h.end(); + CRPHistogram::const_iterator it = h.begin(); + for (; it != end; ++it) { + int thresh = it->first * it->second; + if (thresh > r) break; + r -= thresh; + } + --customers; + const unsigned tc = it->first; + if (tc == 1) { + h.decrement(1); + --tables; + return -1; + } else { + h.move(tc, tc - 1); + return 0; + } + } + + unsigned customers; + unsigned tables; + CRPHistogram h; +}; + +std::ostream& operator<<(std::ostream& os, const CRPTableManager& tm) { + os << '[' << tm.num_customers() << " total customers at " << tm.num_tables() << " tables ||| "; + for (CRPHistogram::const_iterator it = tm.h.begin(); it != tm.h.end(); ++it) { + if (it != tm.h.begin()) os << " -- "; + os << '(' << it->first << ") x " << it->second; + } + return os << ']'; +} + +#endif diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h index 433a5cc5..5647a2a9 100644 --- a/utils/fast_sparse_vector.h +++ b/utils/fast_sparse_vector.h @@ -194,18 +194,19 @@ class FastSparseVector { data_.rbmap = new SPARSE_HASH_MAP<unsigned, T>(first, last); } } - void erase(int k) { + void erase(unsigned k) { if (is_remote_) { data_.rbmap->erase(k); } else { - for (int i = 0; i < local_size_; ++i) { + for (unsigned i = 0; i < local_size_; ++i) { if (data_.local[i].first() == k) { - for (int j = i+1; j < local_size_; ++j) { + for (unsigned j = i+1; j < local_size_; ++j) { data_.local[j-1].first() = data_.local[j].first(); data_.local[j-1].second() = data_.local[j].second(); } } } + --local_size_; } } const FastSparseVector<T>& operator=(const FastSparseVector<T>& other) { |