diff options
Diffstat (limited to 'gi/pyp-topics/src/pyp.hh')
-rw-r--r-- | gi/pyp-topics/src/pyp.hh | 45 |
1 files changed, 29 insertions, 16 deletions
diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh index 84decb0f..19cd6be8 100644 --- a/gi/pyp-topics/src/pyp.hh +++ b/gi/pyp-topics/src/pyp.hh @@ -71,9 +71,12 @@ public: static double log_prior_a(double a, double beta_a, double beta_b); static double log_prior_b(double b, double gamma_c, double gamma_s); - void resample_prior(); - void resample_prior_a(); - void resample_prior_b(); + template <typename Uniform01> + void resample_prior(Uniform01& rnd); + template <typename Uniform01> + void resample_prior_a(Uniform01& rnd); + template <typename Uniform01> + void resample_prior_b(Uniform01& rnd); protected: double _a, _b; // parameters of the Pitman-Yor distribution @@ -402,13 +405,18 @@ PYP<Dish,Hash>::debug_info(std::ostream& os) const hists += dtit->second.table_histogram.size(); tables += dtit->second.tables; +// if (dtit->second.tables <= 0) +// std::cerr << dtit->first << " " << count(dtit->first) << std::endl; assert(dtit->second.tables > 0); assert(!dtit->second.table_histogram.empty()); +// os << "Dish " << dtit->first << " has " << count(dtit->first) << " customers, and is sitting at " << dtit->second.tables << " tables.\n"; for (std::map<int,int>::const_iterator hit = dtit->second.table_histogram.begin(); - hit != dtit->second.table_histogram.end(); ++hit) + hit != dtit->second.table_histogram.end(); ++hit) { +// os << " " << hit->second << " tables with " << hit->first << " customers." << std::endl; assert(hit->second > 0); + } } os << "restaurant has " @@ -510,41 +518,46 @@ long double PYP<Dish,Hash>::lgammadist(long double x, long double alpha, long do template <typename Dish, typename Hash> + template <typename Uniform01> void -PYP<Dish,Hash>::resample_prior() { +PYP<Dish,Hash>::resample_prior(Uniform01& rnd) { for (int num_its=5; num_its >= 0; --num_its) { - resample_prior_b(); - resample_prior_a(); + resample_prior_b(rnd); + resample_prior_a(rnd); } - resample_prior_b(); + resample_prior_b(rnd); } template <typename Dish, typename Hash> + template <typename Uniform01> void -PYP<Dish,Hash>::resample_prior_b() { +PYP<Dish,Hash>::resample_prior_b(Uniform01& rnd) { if (_total_tables == 0) return; - int niterations = 10; // number of resampling iterations + //int niterations = 10; // number of resampling iterations + int niterations = 5; // number of resampling iterations //std::cerr << "\n## resample_prior_b(), initial a = " << _a << ", b = " << _b << std::endl; resample_b_type b_log_prob(_total_customers, _total_tables, _a, _b_gamma_c, _b_gamma_s); - //_b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits<double>::infinity(), - _b = slice_sampler1d(b_log_prob, _b, mt_genrand_res53, (double) 0.0, std::numeric_limits<double>::infinity(), + _b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits<double>::infinity(), + //_b = slice_sampler1d(b_log_prob, _b, mt_genrand_res53, (double) 0.0, std::numeric_limits<double>::infinity(), (double) 0.0, niterations, 100*niterations); //std::cerr << "\n## resample_prior_b(), final a = " << _a << ", b = " << _b << std::endl; } template <typename Dish, typename Hash> + template <typename Uniform01> void -PYP<Dish,Hash>::resample_prior_a() { +PYP<Dish,Hash>::resample_prior_a(Uniform01& rnd) { if (_total_tables == 0) return; - int niterations = 10; + //int niterations = 10; + int niterations = 5; //std::cerr << "\n## Initial a = " << _a << ", b = " << _b << std::endl; resample_a_type a_log_prob(_total_customers, _total_tables, _b, _a_beta_a, _a_beta_b, _dish_tables); - //_a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits<double>::min(), - _a = slice_sampler1d(a_log_prob, _a, mt_genrand_res53, std::numeric_limits<double>::min(), + _a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits<double>::min(), + //_a = slice_sampler1d(a_log_prob, _a, mt_genrand_res53, std::numeric_limits<double>::min(), (double) 1.0, (double) 0.0, niterations, 100*niterations); } |