summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src/pyp.hh
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src/pyp.hh')
-rw-r--r--gi/pyp-topics/src/pyp.hh45
1 files changed, 29 insertions, 16 deletions
diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh
index 84decb0f..19cd6be8 100644
--- a/gi/pyp-topics/src/pyp.hh
+++ b/gi/pyp-topics/src/pyp.hh
@@ -71,9 +71,12 @@ public:
static double log_prior_a(double a, double beta_a, double beta_b);
static double log_prior_b(double b, double gamma_c, double gamma_s);
- void resample_prior();
- void resample_prior_a();
- void resample_prior_b();
+ template <typename Uniform01>
+ void resample_prior(Uniform01& rnd);
+ template <typename Uniform01>
+ void resample_prior_a(Uniform01& rnd);
+ template <typename Uniform01>
+ void resample_prior_b(Uniform01& rnd);
protected:
double _a, _b; // parameters of the Pitman-Yor distribution
@@ -402,13 +405,18 @@ PYP<Dish,Hash>::debug_info(std::ostream& os) const
hists += dtit->second.table_histogram.size();
tables += dtit->second.tables;
+// if (dtit->second.tables <= 0)
+// std::cerr << dtit->first << " " << count(dtit->first) << std::endl;
assert(dtit->second.tables > 0);
assert(!dtit->second.table_histogram.empty());
+// os << "Dish " << dtit->first << " has " << count(dtit->first) << " customers, and is sitting at " << dtit->second.tables << " tables.\n";
for (std::map<int,int>::const_iterator
hit = dtit->second.table_histogram.begin();
- hit != dtit->second.table_histogram.end(); ++hit)
+ hit != dtit->second.table_histogram.end(); ++hit) {
+// os << " " << hit->second << " tables with " << hit->first << " customers." << std::endl;
assert(hit->second > 0);
+ }
}
os << "restaurant has "
@@ -510,41 +518,46 @@ long double PYP<Dish,Hash>::lgammadist(long double x, long double alpha, long do
template <typename Dish, typename Hash>
+ template <typename Uniform01>
void
-PYP<Dish,Hash>::resample_prior() {
+PYP<Dish,Hash>::resample_prior(Uniform01& rnd) {
for (int num_its=5; num_its >= 0; --num_its) {
- resample_prior_b();
- resample_prior_a();
+ resample_prior_b(rnd);
+ resample_prior_a(rnd);
}
- resample_prior_b();
+ resample_prior_b(rnd);
}
template <typename Dish, typename Hash>
+ template <typename Uniform01>
void
-PYP<Dish,Hash>::resample_prior_b() {
+PYP<Dish,Hash>::resample_prior_b(Uniform01& rnd) {
if (_total_tables == 0)
return;
- int niterations = 10; // number of resampling iterations
+ //int niterations = 10; // number of resampling iterations
+ int niterations = 5; // number of resampling iterations
//std::cerr << "\n## resample_prior_b(), initial a = " << _a << ", b = " << _b << std::endl;
resample_b_type b_log_prob(_total_customers, _total_tables, _a, _b_gamma_c, _b_gamma_s);
- //_b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits<double>::infinity(),
- _b = slice_sampler1d(b_log_prob, _b, mt_genrand_res53, (double) 0.0, std::numeric_limits<double>::infinity(),
+ _b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits<double>::infinity(),
+ //_b = slice_sampler1d(b_log_prob, _b, mt_genrand_res53, (double) 0.0, std::numeric_limits<double>::infinity(),
(double) 0.0, niterations, 100*niterations);
//std::cerr << "\n## resample_prior_b(), final a = " << _a << ", b = " << _b << std::endl;
}
template <typename Dish, typename Hash>
+ template <typename Uniform01>
void
-PYP<Dish,Hash>::resample_prior_a() {
+PYP<Dish,Hash>::resample_prior_a(Uniform01& rnd) {
if (_total_tables == 0)
return;
- int niterations = 10;
+ //int niterations = 10;
+ int niterations = 5;
//std::cerr << "\n## Initial a = " << _a << ", b = " << _b << std::endl;
resample_a_type a_log_prob(_total_customers, _total_tables, _b, _a_beta_a, _a_beta_b, _dish_tables);
- //_a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits<double>::min(),
- _a = slice_sampler1d(a_log_prob, _a, mt_genrand_res53, std::numeric_limits<double>::min(),
+ _a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits<double>::min(),
+ //_a = slice_sampler1d(a_log_prob, _a, mt_genrand_res53, std::numeric_limits<double>::min(),
(double) 1.0, (double) 0.0, niterations, 100*niterations);
}