summaryrefslogtreecommitdiff
path: root/gi/pf/pyp_word_model.h
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pf/pyp_word_model.h')
-rw-r--r--gi/pf/pyp_word_model.h46
1 files changed, 25 insertions, 21 deletions
diff --git a/gi/pf/pyp_word_model.h b/gi/pf/pyp_word_model.h
index ff366865..224a9034 100644
--- a/gi/pf/pyp_word_model.h
+++ b/gi/pf/pyp_word_model.h
@@ -11,48 +11,52 @@
#include "os_phrase.h"
// PYP(d,s,poisson-uniform) represented as a CRP
+template <class Base>
struct PYPWordModel {
- explicit PYPWordModel(const unsigned vocab_e_size, const double mean_len = 5) :
- base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-std::log(vocab_e_size)), mean_length(mean_len) {}
-
- void ResampleHyperparameters(MT19937* rng);
+ explicit PYPWordModel(Base* b) :
+ base(*b),
+ r(1,1,1,1,0.66,50.0)
+ {}
+
+ void ResampleHyperparameters(MT19937* rng) {
+ r.resample_hyperparameters(rng);
+ std::cerr << " PYPWordModel(d=" << r.discount() << ",s=" << r.strength() << ")\n";
+ }
inline prob_t operator()(const std::vector<WordID>& s) const {
- return r.prob(s, p0(s));
+ return r.prob(s, base(s));
}
inline void Increment(const std::vector<WordID>& s, MT19937* rng) {
- if (r.increment(s, p0(s), rng))
- base *= p0(s);
+ if (r.increment(s, base(s), rng))
+ base.Increment(s, rng);
}
inline void Decrement(const std::vector<WordID>& s, MT19937 *rng) {
if (r.decrement(s, rng))
- base /= p0(s);
+ base.Decrement(s, rng);
}
inline prob_t Likelihood() const {
prob_t p; p.logeq(r.log_crp_prob());
- p *= base;
+ p *= base.Likelihood();
return p;
}
- void Summary() const;
-
- private:
- inline double logp0(const std::vector<WordID>& s) const {
- return Md::log_poisson(s.size(), mean_length) + s.size() * u0;
+ void Summary() const {
+ std::cerr << "PYPWordModel: generations=" << r.num_customers()
+ << " PYP(d=" << r.discount() << ",s=" << r.strength() << ')' << std::endl;
+ for (typename CCRP<std::vector<WordID> >::const_iterator it = r.begin(); it != r.end(); ++it) {
+ std::cerr << " " << it->second.total_dish_count_
+ << " (on " << it->second.table_counts_.size() << " tables) "
+ << TD::GetString(it->first) << std::endl;
+ }
}
- inline prob_t p0(const std::vector<WordID>& s) const {
- prob_t p; p.logeq(logp0(s));
- return p;
- }
+ private:
- prob_t base; // keeps track of the draws from the base distribution
+ Base& base; // keeps track of the draws from the base distribution
CCRP<std::vector<WordID> > r;
- const double u0; // uniform log prob of generating a letter
- const double mean_length; // mean length of a word in the base distribution
};
#endif