From 89d63600524bc042b6c2741d7d67db6a3a74dc8c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 9 Mar 2012 22:23:50 -0500 Subject: moar --- gi/pf/pyp_word_model.h | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 gi/pf/pyp_word_model.h (limited to 'gi/pf/pyp_word_model.h') diff --git a/gi/pf/pyp_word_model.h b/gi/pf/pyp_word_model.h new file mode 100644 index 00000000..800a4fd7 --- /dev/null +++ b/gi/pf/pyp_word_model.h @@ -0,0 +1,58 @@ +#ifndef _PYP_WORD_MODEL_H_ +#define _PYP_WORD_MODEL_H_ + +#include +#include +#include +#include "prob.h" +#include "ccrp.h" +#include "m.h" +#include "tdict.h" +#include "os_phrase.h" + +// PYP(d,s,poisson-uniform) represented as a CRP +struct PYPWordModel { + explicit PYPWordModel(const unsigned vocab_e_size, const double mean_len = 7.5) : + base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-std::log(vocab_e_size)), mean_length(mean_len) {} + + void ResampleHyperparameters(MT19937* rng); + + inline prob_t operator()(const std::vector& s) const { + return r.prob(s, p0(s)); + } + + inline void Increment(const std::vector& s, MT19937* rng) { + if (r.increment(s, p0(s), rng)) + base *= p0(s); + } + + inline void Decrement(const std::vector& s, MT19937 *rng) { + if (r.decrement(s, rng)) + base /= p0(s); + } + + inline prob_t Likelihood() const { + prob_t p; p.logeq(r.log_crp_prob()); + p *= base; + return p; + } + + void Summary() const; + + private: + inline double logp0(const std::vector& s) const { + return Md::log_poisson(s.size(), mean_length) + s.size() * u0; + } + + inline prob_t p0(const std::vector& s) const { + prob_t p; p.logeq(logp0(s)); + return p; + } + + prob_t base; // keeps track of the draws from the base distribution + CCRP > r; + const double u0; // uniform log prob of generating a letter + const double mean_length; // mean length of a word in the base distribution +}; + +#endif -- cgit v1.2.3 From 5f9f400f4359bc14f7231d6eabd76b7ceee737aa Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 9 Mar 2012 23:13:09 -0500 Subject: logging after alignment --- gi/pf/align-lexonly-pyp.cc | 1 + gi/pf/pyp_tm.cc | 7 +++++-- gi/pf/pyp_word_model.h | 2 +- utils/ccrp.h | 1 + 4 files changed, 8 insertions(+), 3 deletions(-) (limited to 'gi/pf/pyp_word_model.h') diff --git a/gi/pf/align-lexonly-pyp.cc b/gi/pf/align-lexonly-pyp.cc index d68a4b8f..4a1d1db6 100644 --- a/gi/pf/align-lexonly-pyp.cc +++ b/gi/pf/align-lexonly-pyp.cc @@ -208,6 +208,7 @@ int main(int argc, char** argv) { } for (unsigned i = 0; i < corpus.size(); ++i) WriteAlignments(corpus[i]); + aligner.model.Summary(); return 0; } diff --git a/gi/pf/pyp_tm.cc b/gi/pf/pyp_tm.cc index 94cbe7c3..b5262f47 100644 --- a/gi/pf/pyp_tm.cc +++ b/gi/pf/pyp_tm.cc @@ -54,8 +54,6 @@ struct ConditionalPYPWordModel { assert(it != r.end()); if (it->second.decrement(trglets, rng)) { base.Decrement(trglets, rng); - if (it->second.num_customers() == 0) - r.erase(it); } } @@ -84,6 +82,11 @@ PYPLexicalTranslation::PYPLexicalTranslation(const vector >& lets tmodel(new ConditionalPYPWordModel(up0)), kX(-TD::Convert("X")) {} +void PYPLexicalTranslation::Summary() const { + tmodel->Summary(); + up0->Summary(); +} + prob_t PYPLexicalTranslation::Likelihood() const { prob_t p = up0->Likelihood(); p *= tmodel->Likelihood(); diff --git a/gi/pf/pyp_word_model.h b/gi/pf/pyp_word_model.h index 800a4fd7..ff366865 100644 --- a/gi/pf/pyp_word_model.h +++ b/gi/pf/pyp_word_model.h @@ -12,7 +12,7 @@ // PYP(d,s,poisson-uniform) represented as a CRP struct PYPWordModel { - explicit PYPWordModel(const unsigned vocab_e_size, const double mean_len = 7.5) : + explicit PYPWordModel(const unsigned vocab_e_size, const double mean_len = 5) : base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-std::log(vocab_e_size)), mean_length(mean_len) {} void ResampleHyperparameters(MT19937* rng); diff --git a/utils/ccrp.h b/utils/ccrp.h index 439d7e1e..4a8b80e7 100644 --- a/utils/ccrp.h +++ b/utils/ccrp.h @@ -221,6 +221,7 @@ class CCRP { void resample_hyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) { assert(has_discount_prior() || has_strength_prior()); + if (num_customers() == 0) return; DiscountResampler dr(*this); StrengthResampler sr(*this); for (int iter = 0; iter < nloop; ++iter) { -- cgit v1.2.3