summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src/pyp.hh
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src/pyp.hh')
-rw-r--r--gi/pyp-topics/src/pyp.hh477
1 files changed, 477 insertions, 0 deletions
diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh
new file mode 100644
index 00000000..b06c6021
--- /dev/null
+++ b/gi/pyp-topics/src/pyp.hh
@@ -0,0 +1,477 @@
+#ifndef _pyp_hh
+#define _pyp_hh
+
+#include <math.h>
+#include <map>
+#include <tr1/unordered_map>
+
+#include "log_add.h"
+#include "gammadist.h"
+#include "slice-sampler.h"
+#include "mt19937ar.h"
+
+//
+// Pitman-Yor process with customer and table tracking
+//
+
+template <typename Dish, typename Hash=std::tr1::hash<Dish> >
+class PYP : protected std::tr1::unordered_map<Dish, int, Hash>
+{
+public:
+ using std::tr1::unordered_map<Dish,int>::const_iterator;
+ using std::tr1::unordered_map<Dish,int>::iterator;
+ using std::tr1::unordered_map<Dish,int>::begin;
+ using std::tr1::unordered_map<Dish,int>::end;
+
+ PYP(long double a, long double b, Hash hash=Hash());
+
+ int increment(Dish d, long double p0);
+ int decrement(Dish d);
+
+ // lookup functions
+ int count(Dish d) const;
+ long double prob(Dish dish, long double p0) const;
+ long double prob(Dish dish, long double dcd, long double dca,
+ long double dtd, long double dta, long double p0) const;
+
+ int num_customers() const { return _total_customers; }
+ int num_types() const { return std::tr1::unordered_map<Dish,int>::size(); }
+ bool empty() const { return _total_customers == 0; }
+
+ long double log_prob(Dish dish, long double log_p0) const;
+ // nb. d* are NOT logs
+ long double log_prob(Dish dish, long double dcd, long double dca,
+ long double dtd, long double dta, long double log_p0) const;
+
+ int num_tables(Dish dish) const;
+ int num_tables() const;
+
+ long double a() const { return _a; }
+ void set_a(long double a) { _a = a; }
+
+ long double b() const { return _b; }
+ void set_b(long double b) { _b = b; }
+
+ void clear();
+ std::ostream& debug_info(std::ostream& os) const;
+
+ long double log_restaurant_prob() const;
+ long double log_prior() const;
+ static long double log_prior_a(long double a, long double beta_a, long double beta_b);
+ static long double log_prior_b(long double b, long double gamma_c, long double gamma_s);
+
+ void resample_prior();
+ void resample_prior_a();
+ void resample_prior_b();
+
+private:
+ long double _a, _b; // parameters of the Pitman-Yor distribution
+ long double _a_beta_a, _a_beta_b; // parameters of Beta prior on a
+ long double _b_gamma_s, _b_gamma_c; // parameters of Gamma prior on b
+
+ struct TableCounter
+ {
+ TableCounter() : tables(0) {};
+ int tables;
+ std::map<int, int> table_histogram; // num customers at table -> number tables
+ };
+ typedef std::tr1::unordered_map<Dish, TableCounter, Hash> DishTableType;
+ DishTableType _dish_tables;
+ int _total_customers, _total_tables;
+
+ // Function objects for calculating the parts of the log_prob for
+ // the parameters a and b
+ struct resample_a_type {
+ int n, m; long double b, a_beta_a, a_beta_b;
+ const DishTableType& dish_tables;
+ resample_a_type(int n, int m, long double b, long double a_beta_a,
+ long double a_beta_b, const DishTableType& dish_tables)
+ : n(n), m(m), b(b), a_beta_a(a_beta_a), a_beta_b(a_beta_b), dish_tables(dish_tables) {}
+
+ long double operator() (long double proposed_a) const {
+ long double log_prior = log_prior_a(proposed_a, a_beta_a, a_beta_b);
+ long double log_prob = 0.0;
+ long double lgamma1a = lgamma(1.0 - proposed_a);
+ for (typename DishTableType::const_iterator dish_it=dish_tables.begin(); dish_it != dish_tables.end(); ++dish_it)
+ for (std::map<int, int>::const_iterator table_it=dish_it->second.table_histogram.begin();
+ table_it !=dish_it->second.table_histogram.end(); ++table_it)
+ log_prob += (table_it->second * (lgamma(table_it->first - proposed_a) - lgamma1a));
+
+ log_prob += (proposed_a == 0.0 ? (m-1.0)*log(b)
+ : ((m-1.0)*log(proposed_a) + lgamma((m-1.0) + b/proposed_a) - lgamma(b/proposed_a)));
+ assert(std::isfinite(log_prob));
+ return log_prob + log_prior;
+ }
+ };
+
+ struct resample_b_type {
+ int n, m; long double a, b_gamma_c, b_gamma_s;
+ resample_b_type(int n, int m, long double a, long double b_gamma_c, long double b_gamma_s)
+ : n(n), m(m), a(a), b_gamma_c(b_gamma_c), b_gamma_s(b_gamma_s) {}
+
+ long double operator() (long double proposed_b) const {
+ long double log_prior = log_prior_b(proposed_b, b_gamma_c, b_gamma_s);
+ long double log_prob = 0.0;
+ log_prob += (a == 0.0 ? (m-1.0)*log(proposed_b)
+ : ((m-1.0)*log(a) + lgamma((m-1.0) + proposed_b/a) - lgamma(proposed_b/a)));
+ log_prob += (lgamma(1.0+proposed_b) - lgamma(n+proposed_b));
+ return log_prob + log_prior;
+ }
+ };
+};
+
+template <typename Dish, typename Hash>
+PYP<Dish,Hash>::PYP(long double a, long double b, Hash cmp)
+: std::tr1::unordered_map<Dish, int, Hash>(), _a(a), _b(b),
+ _a_beta_a(1), _a_beta_b(1), _b_gamma_s(1), _b_gamma_c(1),
+ //_a_beta_a(1), _a_beta_b(1), _b_gamma_s(10), _b_gamma_c(0.1),
+ _total_customers(0), _total_tables(0)
+{
+// std::cerr << "\t##PYP<Dish,Hash>::PYP(a=" << _a << ",b=" << _b << ")" << std::endl;
+}
+
+template <typename Dish, typename Hash>
+long double
+PYP<Dish,Hash>::prob(Dish dish, long double p0) const
+{
+ int c = count(dish), t = num_tables(dish);
+ long double r = num_tables() * _a + _b;
+ //std::cerr << "\t\t\t\tPYP<Dish,Hash>::prob(" << dish << "," << p0 << ") c=" << c << " r=" << r << std::endl;
+ if (c > 0)
+ return (c - _a * t + r * p0) / (num_customers() + _b);
+ else
+ return r * p0 / (num_customers() + _b);
+}
+
+template <typename Dish, typename Hash>
+long double
+PYP<Dish,Hash>::prob(Dish dish, long double dcd, long double dca,
+ long double dtd, long double dta, long double p0)
+const
+{
+ int c = count(dish) + dcd, t = num_tables(dish) + dtd;
+ long double r = (num_tables() + dta) * _a + _b;
+ if (c > 0)
+ return (c - _a * t + r * p0) / (num_customers() + dca + _b);
+ else
+ return r * p0 / (num_customers() + dca + _b);
+}
+
+template <typename Dish, typename Hash>
+long double
+PYP<Dish,Hash>::log_prob(Dish dish, long double log_p0) const
+{
+ using std::log;
+ int c = count(dish), t = num_tables(dish);
+ long double r = log(num_tables() * _a + b);
+ if (c > 0)
+ return Log<double>::add(log(c - _a * t), r + log_p0)
+ - log(num_customers() + _b);
+ else
+ return r + log_p0 - log(num_customers() + b);
+}
+
+template <typename Dish, typename Hash>
+long double
+PYP<Dish,Hash>::log_prob(Dish dish, long double dcd, long double dca,
+ long double dtd, long double dta, long double log_p0)
+const
+{
+ using std::log;
+ int c = count(dish) + dcd, t = num_tables(dish) + dtd;
+ long double r = log((num_tables() + dta) * _a + b);
+ if (c > 0)
+ return Log<double>::add(log(c - _a * t), r + log_p0)
+ - log(num_customers() + dca + _b);
+ else
+ return r + log_p0 - log(num_customers() + dca + b);
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::increment(Dish dish, long double p0) {
+ int delta = 0;
+ TableCounter &tc = _dish_tables[dish];
+
+ // seated on a new or existing table?
+ int c = count(dish), t = num_tables(dish), T = num_tables();
+ long double pshare = (c > 0) ? (c - _a*t) : 0.0;
+ long double pnew = (_b + _a*T) * p0;
+ assert (pshare >= 0.0);
+ //assert (pnew > 0.0);
+
+ if (mt_genrand_res53() < pnew / (pshare + pnew)) {
+ // assign to a new table
+ tc.tables += 1;
+ tc.table_histogram[1] += 1;
+ _total_tables += 1;
+ delta = 1;
+ }
+ else {
+ // randomly assign to an existing table
+ // remove constant denominator from inner loop
+ long double r = mt_genrand_res53() * (c - _a*t);
+ for (std::map<int,int>::iterator
+ hit = tc.table_histogram.begin();
+ hit != tc.table_histogram.end(); ++hit) {
+ r -= ((hit->first - _a) * hit->second);
+ if (r <= 0) {
+ tc.table_histogram[hit->first+1] += 1;
+ hit->second -= 1;
+ if (hit->second == 0)
+ tc.table_histogram.erase(hit);
+ break;
+ }
+ }
+ if (r > 0) {
+ std::cerr << r << " " << c << " " << _a << " " << t << std::endl;
+ assert(false);
+ }
+ delta = 0;
+ }
+
+ std::tr1::unordered_map<Dish,int,Hash>::operator[](dish) += 1;
+ _total_customers += 1;
+
+ return delta;
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::count(Dish dish) const
+{
+ typename std::tr1::unordered_map<Dish, int>::const_iterator
+ dcit = find(dish);
+ if (dcit != end())
+ return dcit->second;
+ else
+ return 0;
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::decrement(Dish dish)
+{
+ typename std::tr1::unordered_map<Dish, int>::iterator dcit = find(dish);
+ if (dcit == end()) {
+ std::cerr << dish << std::endl;
+ assert(false);
+ }
+
+ int delta = 0;
+
+ typename std::tr1::unordered_map<Dish, TableCounter>::iterator dtit = _dish_tables.find(dish);
+ if (dtit == _dish_tables.end()) {
+ std::cerr << dish << std::endl;
+ assert(false);
+ }
+ TableCounter &tc = dtit->second;
+
+ //std::cerr << "\tdecrement for " << dish << "\n";
+ //std::cerr << "\tBEFORE histogram: " << tc.table_histogram << " ";
+ //std::cerr << "count: " << count(dish) << " ";
+ //std::cerr << "tables: " << tc.tables << "\n";
+
+ long double r = mt_genrand_res53() * count(dish);
+ for (std::map<int,int>::iterator hit = tc.table_histogram.begin();
+ hit != tc.table_histogram.end(); ++hit)
+ {
+ //r -= (hit->first - _a) * hit->second;
+ r -= (hit->first) * hit->second;
+ if (r <= 0)
+ {
+ if (hit->first > 1)
+ tc.table_histogram[hit->first-1] += 1;
+ else
+ {
+ delta = -1;
+ tc.tables -= 1;
+ _total_tables -= 1;
+ }
+
+ hit->second -= 1;
+ if (hit->second == 0) tc.table_histogram.erase(hit);
+ break;
+ }
+ }
+ if (r > 0) {
+ std::cerr << r << " " << count(dish) << " " << _a << " " << num_tables(dish) << std::endl;
+ assert(false);
+ }
+
+ // remove the customer
+ dcit->second -= 1;
+ _total_customers -= 1;
+ assert(dcit->second >= 0);
+ if (dcit->second == 0) {
+ erase(dcit);
+ _dish_tables.erase(dtit);
+ //std::cerr << "\tAFTER histogram: Empty\n";
+ }
+ else {
+ //std::cerr << "\tAFTER histogram: " << _dish_tables[dish].table_histogram << " ";
+ //std::cerr << "count: " << count(dish) << " ";
+ //std::cerr << "tables: " << _dish_tables[dish].tables << "\n";
+ }
+
+ return delta;
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::num_tables(Dish dish) const
+{
+ typename std::tr1::unordered_map<Dish, TableCounter, Hash>::const_iterator
+ dtit = _dish_tables.find(dish);
+
+ //assert(dtit != _dish_tables.end());
+ if (dtit == _dish_tables.end())
+ return 0;
+
+ return dtit->second.tables;
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::num_tables() const
+{
+ return _total_tables;
+}
+
+template <typename Dish, typename Hash>
+std::ostream&
+PYP<Dish,Hash>::debug_info(std::ostream& os) const
+{
+ int hists = 0, tables = 0;
+ for (typename std::tr1::unordered_map<Dish, TableCounter, Hash>::const_iterator
+ dtit = _dish_tables.begin(); dtit != _dish_tables.end(); ++dtit)
+ {
+ hists += dtit->second.table_histogram.size();
+ tables += dtit->second.tables;
+
+ assert(dtit->second.tables > 0);
+ assert(!dtit->second.table_histogram.empty());
+
+ for (std::map<int,int>::const_iterator
+ hit = dtit->second.table_histogram.begin();
+ hit != dtit->second.table_histogram.end(); ++hit)
+ assert(hit->second > 0);
+ }
+
+ os << "restaurant has "
+ << _total_customers << " customers; "
+ << _total_tables << " tables; "
+ << tables << " tables'; "
+ << num_types() << " dishes; "
+ << _dish_tables.size() << " dishes'; and "
+ << hists << " histogram entries\n";
+
+ return os;
+}
+
+template <typename Dish, typename Hash>
+void
+PYP<Dish,Hash>::clear()
+{
+ this->std::tr1::unordered_map<Dish,int,Hash>::clear();
+ _dish_tables.clear();
+ _total_tables = _total_customers = 0;
+}
+
+// log_restaurant_prob returns the log probability of the PYP table configuration.
+// Excludes Hierarchical P0 term which must be calculated separately.
+template <typename Dish, typename Hash>
+long double
+PYP<Dish,Hash>::log_restaurant_prob() const {
+ if (_total_customers < 1)
+ return (long double)0.0;
+
+ long double log_prob = 0.0;
+ long double lgamma1a = lgamma(1.0-_a);
+
+ //std::cerr << "-------------------\n" << std::endl;
+ for (typename DishTableType::const_iterator dish_it=_dish_tables.begin();
+ dish_it != _dish_tables.end(); ++dish_it) {
+ for (std::map<int, int>::const_iterator table_it=dish_it->second.table_histogram.begin();
+ table_it !=dish_it->second.table_histogram.end(); ++table_it) {
+ log_prob += (table_it->second * (lgamma(table_it->first - _a) - lgamma1a));
+ //std::cerr << "|" << dish_it->first->parent << " --> " << dish_it->first->rhs << " " << table_it->first << " " << table_it->second << " " << log_prob;
+ }
+ }
+ //std::cerr << std::endl;
+
+ log_prob += (_a == (long double)0.0 ? (_total_tables-1.0)*log(_b) : (_total_tables-1.0)*log(_a) + lgamma((_total_tables-1.0) + _b/_a) - lgamma(_b/_a));
+ //std::cerr << "\t\t" << log_prob << std::endl;
+ log_prob += (lgamma(1.0 + _b) - lgamma(_total_customers + _b));
+
+ //std::cerr << _total_customers << " " << _total_tables << " " << log_prob << " " << log_prior() << std::endl;
+ //std::cerr << _a << " " << _b << std::endl;
+ if (!std::isfinite(log_prob)) {
+ assert(false);
+ }
+ //return log_prob;
+ return log_prob + log_prior();
+}
+
+template <typename Dish, typename Hash>
+long double
+PYP<Dish,Hash>::log_prior() const {
+ long double prior = 0.0;
+ if (_a_beta_a > 0.0 && _a_beta_b > 0.0 && _a > 0.0)
+ prior += log_prior_a(_a, _a_beta_a, _a_beta_b);
+ if (_b_gamma_s > 0.0 && _b_gamma_c > 0.0)
+ prior += log_prior_b(_b, _b_gamma_c, _b_gamma_s);
+
+ return prior;
+}
+
+template <typename Dish, typename Hash>
+long double
+PYP<Dish,Hash>::log_prior_a(long double a, long double beta_a, long double beta_b) {
+ return lbetadist(a, beta_a, beta_b);
+}
+
+template <typename Dish, typename Hash>
+long double
+PYP<Dish,Hash>::log_prior_b(long double b, long double gamma_c, long double gamma_s) {
+ return lgammadist(b, gamma_c, gamma_s);
+}
+
+template <typename Dish, typename Hash>
+void
+PYP<Dish,Hash>::resample_prior() {
+ for (int num_its=5; num_its >= 0; --num_its) {
+ resample_prior_b();
+ resample_prior_a();
+ }
+ resample_prior_b();
+}
+
+template <typename Dish, typename Hash>
+void
+PYP<Dish,Hash>::resample_prior_b() {
+ if (_total_tables == 0)
+ return;
+
+ int niterations = 10; // number of resampling iterations
+ //std::cerr << "\n## resample_prior_b(), initial a = " << _a << ", b = " << _b << std::endl;
+ resample_b_type b_log_prob(_total_customers, _total_tables, _a, _b_gamma_c, _b_gamma_s);
+ _b = slice_sampler1d(b_log_prob, _b, mt_genrand_res53, (long double) 0.0, std::numeric_limits<long double>::infinity(),
+ (long double) 0.0, niterations, 100*niterations);
+ //std::cerr << "\n## resample_prior_b(), final a = " << _a << ", b = " << _b << std::endl;
+}
+
+template <typename Dish, typename Hash>
+void
+PYP<Dish,Hash>::resample_prior_a() {
+ if (_total_tables == 0)
+ return;
+
+ int niterations = 10;
+ //std::cerr << "\n## Initial a = " << _a << ", b = " << _b << std::endl;
+ resample_a_type a_log_prob(_total_customers, _total_tables, _b, _a_beta_a, _a_beta_b, _dish_tables);
+ _a = slice_sampler1d(a_log_prob, _a, mt_genrand_res53, std::numeric_limits<long double>::min(),
+ (long double) 1.0, (long double) 0.0, niterations, 100*niterations);
+}
+
+#endif