summaryrefslogtreecommitdiff
path: root/gi/pf/pyp_word_model.h
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-03-09 22:23:50 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-03-09 22:23:50 -0500
commit89d63600524bc042b6c2741d7d67db6a3a74dc8c (patch)
treeccec74b3e311084d6d2013e4b3b101e29108ded0 /gi/pf/pyp_word_model.h
parent0ab9e175f86b3fd02a4a94f350282210aba054e3 (diff)
moar
Diffstat (limited to 'gi/pf/pyp_word_model.h')
-rw-r--r--gi/pf/pyp_word_model.h58
1 files changed, 58 insertions, 0 deletions
diff --git a/gi/pf/pyp_word_model.h b/gi/pf/pyp_word_model.h
new file mode 100644
index 00000000..800a4fd7
--- /dev/null
+++ b/gi/pf/pyp_word_model.h
@@ -0,0 +1,58 @@
+#ifndef _PYP_WORD_MODEL_H_
+#define _PYP_WORD_MODEL_H_
+
+#include <iostream>
+#include <cmath>
+#include <vector>
+#include "prob.h"
+#include "ccrp.h"
+#include "m.h"
+#include "tdict.h"
+#include "os_phrase.h"
+
+// PYP(d,s,poisson-uniform) represented as a CRP
+struct PYPWordModel {
+ explicit PYPWordModel(const unsigned vocab_e_size, const double mean_len = 7.5) :
+ base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-std::log(vocab_e_size)), mean_length(mean_len) {}
+
+ void ResampleHyperparameters(MT19937* rng);
+
+ inline prob_t operator()(const std::vector<WordID>& s) const {
+ return r.prob(s, p0(s));
+ }
+
+ inline void Increment(const std::vector<WordID>& s, MT19937* rng) {
+ if (r.increment(s, p0(s), rng))
+ base *= p0(s);
+ }
+
+ inline void Decrement(const std::vector<WordID>& s, MT19937 *rng) {
+ if (r.decrement(s, rng))
+ base /= p0(s);
+ }
+
+ inline prob_t Likelihood() const {
+ prob_t p; p.logeq(r.log_crp_prob());
+ p *= base;
+ return p;
+ }
+
+ void Summary() const;
+
+ private:
+ inline double logp0(const std::vector<WordID>& s) const {
+ return Md::log_poisson(s.size(), mean_length) + s.size() * u0;
+ }
+
+ inline prob_t p0(const std::vector<WordID>& s) const {
+ prob_t p; p.logeq(logp0(s));
+ return p;
+ }
+
+ prob_t base; // keeps track of the draws from the base distribution
+ CCRP<std::vector<WordID> > r;
+ const double u0; // uniform log prob of generating a letter
+ const double mean_length; // mean length of a word in the base distribution
+};
+
+#endif