From 9be9a5dde934577de314ce8ac6fb3eb0ba787503 Mon Sep 17 00:00:00 2001 From: redpony Date: Tue, 22 Jun 2010 22:31:28 +0000 Subject: chris's crappy lda git-svn-id: https://ws10smt.googlecode.com/svn/trunk@6 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/clda/src/crp.h | 216 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 216 insertions(+) create mode 100644 gi/clda/src/crp.h (limited to 'gi/clda/src/crp.h') diff --git a/gi/clda/src/crp.h b/gi/clda/src/crp.h new file mode 100644 index 00000000..13596cbf --- /dev/null +++ b/gi/clda/src/crp.h @@ -0,0 +1,216 @@ +#ifndef _CRP_H_ +#define _CRP_H_ + +// shamelessly adapted from code by Phil Blunsom and Trevor Cohn +// There are TWO CRP classes here: CRPWithTableTracking tracks the +// (expected) number of customers per table, and CRP just tracks +// the number of customers / dish. +// If you are implementing a HDP model, you should use CRP for the +// base distribution and CRPWithTableTracking for the dependent +// distribution. + +#include +#include +#include +#include + +#include "prob.h" +#include "sampler.h" // RNG + +template > +class CRP { + public: + CRP(double alpha) : alpha_(alpha), palpha_(alpha), total_customers_() {} + void increment(const DishType& dish); + void decrement(const DishType& dish); + void erase(const DishType& dish) { + counts_.erase(dish); + } + inline int count(const DishType& dish) const { + const typename MapType::const_iterator i = counts_.find(dish); + if (i == counts_.end()) return 0; else return i->second; + } + inline prob_t prob(const DishType& dish) const { + return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_); + } + inline prob_t prob(const DishType& dish, const prob_t& p0) const { + return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_); + } + private: + typedef std::tr1::unordered_map MapType; + MapType counts_; + const double alpha_; + const prob_t palpha_; + int total_customers_; +}; + +template +void CRP::increment(const Dish& dish) { + ++counts_[dish]; + ++total_customers_; +} + +template +void CRP::decrement(const Dish& dish) { + typename MapType::iterator i = counts_.find(dish); + assert(i != counts_.end()); + if (--i->second == 0) + counts_.erase(i); + --total_customers_; +} + +template , typename RNG = MT19937> +class CRPWithTableTracking { + public: + CRPWithTableTracking(double alpha, RNG* rng) : + alpha_(alpha), palpha_(alpha), total_customers_(), + total_tables_(), rng_(rng) {} + + // seat a customer for dish d, returns the delta in tables + // with customers + int increment(const DishType& d, const prob_t& p0 = prob_t::One()); + int decrement(const DishType& d); + void erase(const DishType& dish); + + inline int count(const DishType& dish) const { + const typename MapType::const_iterator i = counts_.find(dish); + if (i == counts_.end()) return 0; else return i->second.count_; + } + inline prob_t prob(const DishType& dish) const { + return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_); + } + inline prob_t prob(const DishType& dish, const prob_t& p0) const { + return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_); + } + private: + struct TableInfo { + TableInfo() : count_(), tables_() {} + int count_; // total customers eating dish + int tables_; // total tables labeled with dish + std::map table_histogram_; // num customers at table -> number tables + }; + typedef std::tr1::unordered_map MapType; + + inline prob_t prob_share_table(const double& customer_count) const { + if (customer_count) + return prob_t(customer_count) / prob_t(customer_count + alpha_); + else + return prob_t::Zero(); + } + inline prob_t prob_new_table(const double& customer_count, const prob_t& p0) const { + if (customer_count) + return palpha_ * p0 / prob_t(customer_count + alpha_); + else + return p0; + } + + MapType counts_; + const double alpha_; + const prob_t palpha_; + int total_customers_; + int total_tables_; + RNG* rng_; +}; + +template +int CRPWithTableTracking::increment(const Dish& dish, const prob_t& p0) { + TableInfo& tc = counts_[dish]; + + //std::cerr << "\nincrement for " << dish << " with p0 " << p0 << "\n"; + //std::cerr << "\tBEFORE histogram: " << tc.table_histogram_ << " "; + //std::cerr << "count: " << tc.count_ << " "; + //std::cerr << "tables: " << tc.tables_ << "\n"; + + // seated at a new or existing table? + prob_t pshare = prob_share_table(tc.count_); + prob_t pnew = prob_new_table(tc.count_, p0); + + //std::cerr << "\t\tP0 " << p0 << " count(dish) " << count(dish) + // << " tables " << tc + // << " p(share) " << pshare << " p(new) " << pnew << "\n"; + + int delta = 0; + if (tc.count_ == 0 || rng_->SelectSample(pshare, pnew) == 1) { + // assign to a new table + ++tc.tables_; + ++tc.table_histogram_[1]; + ++total_tables_; + delta = 1; + } else { + // can't share a table if there are no other customers + assert(tc.count_ > 0); + + // randomly assign to an existing table + // remove constant denominator from inner loop + int r = static_cast(rng_->next() * tc.count_); + for (std::map::iterator hit = tc.table_histogram_.begin(); + hit != tc.table_histogram_.end(); ++hit) { + r -= hit->first * hit->second; + if (r <= 0) { + ++tc.table_histogram_[hit->first+1]; + --hit->second; + if (hit->second == 0) + tc.table_histogram_.erase(hit); + break; + } + } + if (r > 0) { + std::cerr << "CONSISTENCY ERROR: " << tc.count_ << std::endl; + std::cerr << pshare << std::endl; + std::cerr << pnew << std::endl; + std::cerr << r << std::endl; + abort(); + } + } + ++tc.count_; + ++total_customers_; + return delta; +} + +template +int CRPWithTableTracking::decrement(const Dish& dish) { + typename MapType::iterator i = counts_.find(dish); + if(i == counts_.end()) { + std::cerr << "MISSING DISH: " << dish << std::endl; + abort(); + } + + int delta = 0; + TableInfo &tc = i->second; + + //std::cout << "\ndecrement for " << dish << " with p0 " << p0 << "\n"; + //std::cout << "\tBEFORE histogram: " << tc.table_histogram << " "; + //std::cout << "count: " << count(dish) << " "; + //std::cout << "tables: " << tc.tables << "\n"; + + int r = static_cast(rng_->next() * tc.count_); + //std::cerr << "FOO: " << r << std::endl; + for (std::map::iterator hit = tc.table_histogram_.begin(); + hit != tc.table_histogram_.end(); ++hit) { + r -= (hit->first * hit->second); + if (r <= 0) { + if (hit->first > 1) + tc.table_histogram_[hit->first-1] += 1; + else { + --delta; + --tc.tables_; + --total_tables_; + } + + --hit->second; + if (hit->second == 0) tc.table_histogram_.erase(hit); + break; + } + } + + assert(r <= 0); + + // remove the customer + --tc.count_; + --total_customers_; + assert(tc.count_ >= 0); + if (tc.count_ == 0) counts_.erase(i); + return delta; +} + +#endif -- cgit v1.2.3