From 73dbb0343a895345a80d49da9d48edac8858e87a Mon Sep 17 00:00:00 2001 From: philblunsom Date: Mon, 19 Jul 2010 18:33:29 +0000 Subject: Vaguely working distributed implementation. Hierarchical topics doesn't yet work correctly. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@317 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/pyp-topics/src/pyp.hh | 45 +++++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 16 deletions(-) (limited to 'gi/pyp-topics/src/pyp.hh') diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh index 84decb0f..19cd6be8 100644 --- a/gi/pyp-topics/src/pyp.hh +++ b/gi/pyp-topics/src/pyp.hh @@ -71,9 +71,12 @@ public: static double log_prior_a(double a, double beta_a, double beta_b); static double log_prior_b(double b, double gamma_c, double gamma_s); - void resample_prior(); - void resample_prior_a(); - void resample_prior_b(); + template + void resample_prior(Uniform01& rnd); + template + void resample_prior_a(Uniform01& rnd); + template + void resample_prior_b(Uniform01& rnd); protected: double _a, _b; // parameters of the Pitman-Yor distribution @@ -402,13 +405,18 @@ PYP::debug_info(std::ostream& os) const hists += dtit->second.table_histogram.size(); tables += dtit->second.tables; +// if (dtit->second.tables <= 0) +// std::cerr << dtit->first << " " << count(dtit->first) << std::endl; assert(dtit->second.tables > 0); assert(!dtit->second.table_histogram.empty()); +// os << "Dish " << dtit->first << " has " << count(dtit->first) << " customers, and is sitting at " << dtit->second.tables << " tables.\n"; for (std::map::const_iterator hit = dtit->second.table_histogram.begin(); - hit != dtit->second.table_histogram.end(); ++hit) + hit != dtit->second.table_histogram.end(); ++hit) { +// os << " " << hit->second << " tables with " << hit->first << " customers." << std::endl; assert(hit->second > 0); + } } os << "restaurant has " @@ -510,41 +518,46 @@ long double PYP::lgammadist(long double x, long double alpha, long do template + template void -PYP::resample_prior() { +PYP::resample_prior(Uniform01& rnd) { for (int num_its=5; num_its >= 0; --num_its) { - resample_prior_b(); - resample_prior_a(); + resample_prior_b(rnd); + resample_prior_a(rnd); } - resample_prior_b(); + resample_prior_b(rnd); } template + template void -PYP::resample_prior_b() { +PYP::resample_prior_b(Uniform01& rnd) { if (_total_tables == 0) return; - int niterations = 10; // number of resampling iterations + //int niterations = 10; // number of resampling iterations + int niterations = 5; // number of resampling iterations //std::cerr << "\n## resample_prior_b(), initial a = " << _a << ", b = " << _b << std::endl; resample_b_type b_log_prob(_total_customers, _total_tables, _a, _b_gamma_c, _b_gamma_s); - //_b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits::infinity(), - _b = slice_sampler1d(b_log_prob, _b, mt_genrand_res53, (double) 0.0, std::numeric_limits::infinity(), + _b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits::infinity(), + //_b = slice_sampler1d(b_log_prob, _b, mt_genrand_res53, (double) 0.0, std::numeric_limits::infinity(), (double) 0.0, niterations, 100*niterations); //std::cerr << "\n## resample_prior_b(), final a = " << _a << ", b = " << _b << std::endl; } template + template void -PYP::resample_prior_a() { +PYP::resample_prior_a(Uniform01& rnd) { if (_total_tables == 0) return; - int niterations = 10; + //int niterations = 10; + int niterations = 5; //std::cerr << "\n## Initial a = " << _a << ", b = " << _b << std::endl; resample_a_type a_log_prob(_total_customers, _total_tables, _b, _a_beta_a, _a_beta_b, _dish_tables); - //_a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits::min(), - _a = slice_sampler1d(a_log_prob, _a, mt_genrand_res53, std::numeric_limits::min(), + _a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits::min(), + //_a = slice_sampler1d(a_log_prob, _a, mt_genrand_res53, std::numeric_limits::min(), (double) 1.0, (double) 0.0, niterations, 100*niterations); } -- cgit v1.2.3