diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-23 04:06:56 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-23 04:06:56 +0000 |
commit | e011e33a5ba5f1c7ef213009d2eec8a30cacd679 (patch) | |
tree | 02f5685b66f7c210269b6cae15306c99456c4b01 /gi/clda/src/crp.h | |
parent | a9c9f9f823cacb8e94d88838d08cf7daa8c4c82e (diff) |
clean up
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@9 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/clda/src/crp.h')
-rw-r--r-- | gi/clda/src/crp.h | 162 |
1 files changed, 0 insertions, 162 deletions
diff --git a/gi/clda/src/crp.h b/gi/clda/src/crp.h index 13596cbf..865c41ac 100644 --- a/gi/clda/src/crp.h +++ b/gi/clda/src/crp.h @@ -2,20 +2,12 @@ #define _CRP_H_ // shamelessly adapted from code by Phil Blunsom and Trevor Cohn -// There are TWO CRP classes here: CRPWithTableTracking tracks the -// (expected) number of customers per table, and CRP just tracks -// the number of customers / dish. -// If you are implementing a HDP model, you should use CRP for the -// base distribution and CRPWithTableTracking for the dependent -// distribution. -#include <iostream> #include <map> #include <boost/functional/hash.hpp> #include <tr1/unordered_map> #include "prob.h" -#include "sampler.h" // RNG template <typename DishType, typename Hash = boost::hash<DishType> > class CRP { @@ -59,158 +51,4 @@ void CRP<Dish,Hash>::decrement(const Dish& dish) { --total_customers_; } -template <typename DishType, typename Hash = boost::hash<DishType>, typename RNG = MT19937> -class CRPWithTableTracking { - public: - CRPWithTableTracking(double alpha, RNG* rng) : - alpha_(alpha), palpha_(alpha), total_customers_(), - total_tables_(), rng_(rng) {} - - // seat a customer for dish d, returns the delta in tables - // with customers - int increment(const DishType& d, const prob_t& p0 = prob_t::One()); - int decrement(const DishType& d); - void erase(const DishType& dish); - - inline int count(const DishType& dish) const { - const typename MapType::const_iterator i = counts_.find(dish); - if (i == counts_.end()) return 0; else return i->second.count_; - } - inline prob_t prob(const DishType& dish) const { - return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_); - } - inline prob_t prob(const DishType& dish, const prob_t& p0) const { - return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_); - } - private: - struct TableInfo { - TableInfo() : count_(), tables_() {} - int count_; // total customers eating dish - int tables_; // total tables labeled with dish - std::map<int, int> table_histogram_; // num customers at table -> number tables - }; - typedef std::tr1::unordered_map<DishType, TableInfo, Hash> MapType; - - inline prob_t prob_share_table(const double& customer_count) const { - if (customer_count) - return prob_t(customer_count) / prob_t(customer_count + alpha_); - else - return prob_t::Zero(); - } - inline prob_t prob_new_table(const double& customer_count, const prob_t& p0) const { - if (customer_count) - return palpha_ * p0 / prob_t(customer_count + alpha_); - else - return p0; - } - - MapType counts_; - const double alpha_; - const prob_t palpha_; - int total_customers_; - int total_tables_; - RNG* rng_; -}; - -template <typename Dish, typename Hash, typename RNG> -int CRPWithTableTracking<Dish,Hash,RNG>::increment(const Dish& dish, const prob_t& p0) { - TableInfo& tc = counts_[dish]; - - //std::cerr << "\nincrement for " << dish << " with p0 " << p0 << "\n"; - //std::cerr << "\tBEFORE histogram: " << tc.table_histogram_ << " "; - //std::cerr << "count: " << tc.count_ << " "; - //std::cerr << "tables: " << tc.tables_ << "\n"; - - // seated at a new or existing table? - prob_t pshare = prob_share_table(tc.count_); - prob_t pnew = prob_new_table(tc.count_, p0); - - //std::cerr << "\t\tP0 " << p0 << " count(dish) " << count(dish) - // << " tables " << tc - // << " p(share) " << pshare << " p(new) " << pnew << "\n"; - - int delta = 0; - if (tc.count_ == 0 || rng_->SelectSample(pshare, pnew) == 1) { - // assign to a new table - ++tc.tables_; - ++tc.table_histogram_[1]; - ++total_tables_; - delta = 1; - } else { - // can't share a table if there are no other customers - assert(tc.count_ > 0); - - // randomly assign to an existing table - // remove constant denominator from inner loop - int r = static_cast<int>(rng_->next() * tc.count_); - for (std::map<int,int>::iterator hit = tc.table_histogram_.begin(); - hit != tc.table_histogram_.end(); ++hit) { - r -= hit->first * hit->second; - if (r <= 0) { - ++tc.table_histogram_[hit->first+1]; - --hit->second; - if (hit->second == 0) - tc.table_histogram_.erase(hit); - break; - } - } - if (r > 0) { - std::cerr << "CONSISTENCY ERROR: " << tc.count_ << std::endl; - std::cerr << pshare << std::endl; - std::cerr << pnew << std::endl; - std::cerr << r << std::endl; - abort(); - } - } - ++tc.count_; - ++total_customers_; - return delta; -} - -template <typename Dish, typename Hash, typename RNG> -int CRPWithTableTracking<Dish,Hash,RNG>::decrement(const Dish& dish) { - typename MapType::iterator i = counts_.find(dish); - if(i == counts_.end()) { - std::cerr << "MISSING DISH: " << dish << std::endl; - abort(); - } - - int delta = 0; - TableInfo &tc = i->second; - - //std::cout << "\ndecrement for " << dish << " with p0 " << p0 << "\n"; - //std::cout << "\tBEFORE histogram: " << tc.table_histogram << " "; - //std::cout << "count: " << count(dish) << " "; - //std::cout << "tables: " << tc.tables << "\n"; - - int r = static_cast<int>(rng_->next() * tc.count_); - //std::cerr << "FOO: " << r << std::endl; - for (std::map<int,int>::iterator hit = tc.table_histogram_.begin(); - hit != tc.table_histogram_.end(); ++hit) { - r -= (hit->first * hit->second); - if (r <= 0) { - if (hit->first > 1) - tc.table_histogram_[hit->first-1] += 1; - else { - --delta; - --tc.tables_; - --total_tables_; - } - - --hit->second; - if (hit->second == 0) tc.table_histogram_.erase(hit); - break; - } - } - - assert(r <= 0); - - // remove the customer - --tc.count_; - --total_customers_; - assert(tc.count_ >= 0); - if (tc.count_ == 0) counts_.erase(i); - return delta; -} - #endif |