From 0b598b997a7c1d2d9dc255cc2ff1bf9bb2c425a1 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 15 Mar 2012 22:47:04 -0400 Subject: bayes bayes bayes --- gi/pf/pyp_word_model.h | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) (limited to 'gi/pf/pyp_word_model.h') diff --git a/gi/pf/pyp_word_model.h b/gi/pf/pyp_word_model.h index ff366865..224a9034 100644 --- a/gi/pf/pyp_word_model.h +++ b/gi/pf/pyp_word_model.h @@ -11,48 +11,52 @@ #include "os_phrase.h" // PYP(d,s,poisson-uniform) represented as a CRP +template struct PYPWordModel { - explicit PYPWordModel(const unsigned vocab_e_size, const double mean_len = 5) : - base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-std::log(vocab_e_size)), mean_length(mean_len) {} - - void ResampleHyperparameters(MT19937* rng); + explicit PYPWordModel(Base* b) : + base(*b), + r(1,1,1,1,0.66,50.0) + {} + + void ResampleHyperparameters(MT19937* rng) { + r.resample_hyperparameters(rng); + std::cerr << " PYPWordModel(d=" << r.discount() << ",s=" << r.strength() << ")\n"; + } inline prob_t operator()(const std::vector& s) const { - return r.prob(s, p0(s)); + return r.prob(s, base(s)); } inline void Increment(const std::vector& s, MT19937* rng) { - if (r.increment(s, p0(s), rng)) - base *= p0(s); + if (r.increment(s, base(s), rng)) + base.Increment(s, rng); } inline void Decrement(const std::vector& s, MT19937 *rng) { if (r.decrement(s, rng)) - base /= p0(s); + base.Decrement(s, rng); } inline prob_t Likelihood() const { prob_t p; p.logeq(r.log_crp_prob()); - p *= base; + p *= base.Likelihood(); return p; } - void Summary() const; - - private: - inline double logp0(const std::vector& s) const { - return Md::log_poisson(s.size(), mean_length) + s.size() * u0; + void Summary() const { + std::cerr << "PYPWordModel: generations=" << r.num_customers() + << " PYP(d=" << r.discount() << ",s=" << r.strength() << ')' << std::endl; + for (typename CCRP >::const_iterator it = r.begin(); it != r.end(); ++it) { + std::cerr << " " << it->second.total_dish_count_ + << " (on " << it->second.table_counts_.size() << " tables) " + << TD::GetString(it->first) << std::endl; + } } - inline prob_t p0(const std::vector& s) const { - prob_t p; p.logeq(logp0(s)); - return p; - } + private: - prob_t base; // keeps track of the draws from the base distribution + Base& base; // keeps track of the draws from the base distribution CCRP > r; - const double u0; // uniform log prob of generating a letter - const double mean_length; // mean length of a word in the base distribution }; #endif -- cgit v1.2.3