summaryrefslogtreecommitdiff
path: root/gi/pf/tied_resampler.h
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pf/tied_resampler.h')
-rw-r--r--gi/pf/tied_resampler.h82
1 files changed, 82 insertions, 0 deletions
diff --git a/gi/pf/tied_resampler.h b/gi/pf/tied_resampler.h
new file mode 100644
index 00000000..208fb9c7
--- /dev/null
+++ b/gi/pf/tied_resampler.h
@@ -0,0 +1,82 @@
+#ifndef _TIED_RESAMPLER_H_
+#define _TIED_RESAMPLER_H_
+
+#include <set>
+#include "sampler.h"
+#include "slice_sampler.h"
+#include "m.h"
+
+template <class CRP>
+struct TiedResampler {
+ explicit TiedResampler(double da, double db, double ss, double sr, double d=0.5, double s=1.0) :
+ d_alpha(da),
+ d_beta(db),
+ s_shape(ss),
+ s_rate(sr),
+ discount(d),
+ strength(s) {}
+
+ void Add(CRP* crp) {
+ crps.insert(crp);
+ crp->set_discount(discount);
+ crp->set_strength(strength);
+ assert(!crp->has_discount_prior());
+ assert(!crp->has_strength_prior());
+ }
+
+ void Remove(CRP* crp) {
+ crps.erase(crp);
+ }
+
+ double LogLikelihood(double d, double s) const {
+ if (s <= -d) return -std::numeric_limits<double>::infinity();
+ double llh = Md::log_beta_density(d, d_alpha, d_beta) +
+ Md::log_gamma_density(d + s, s_shape, s_rate);
+ for (typename std::set<CRP*>::iterator it = crps.begin(); it != crps.end(); ++it)
+ llh += (*it)->log_crp_prob(d, s);
+ return llh;
+ }
+
+ struct DiscountResampler {
+ DiscountResampler(const TiedResampler& m) : m_(m) {}
+ const TiedResampler& m_;
+ double operator()(const double& proposed_discount) const {
+ return m_.LogLikelihood(proposed_discount, m_.strength);
+ }
+ };
+
+ struct AlphaResampler {
+ AlphaResampler(const TiedResampler& m) : m_(m) {}
+ const TiedResampler& m_;
+ double operator()(const double& proposed_strength) const {
+ return m_.LogLikelihood(m_.discount, proposed_strength);
+ }
+ };
+
+ void ResampleHyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) {
+ const DiscountResampler dr(*this);
+ const AlphaResampler ar(*this);
+ for (int iter = 0; iter < nloop; ++iter) {
+ strength = slice_sampler1d(ar, strength, *rng, -discount + std::numeric_limits<double>::min(),
+ std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
+ double min_discount = std::numeric_limits<double>::min();
+ if (strength < 0.0) min_discount -= strength;
+ discount = slice_sampler1d(dr, discount, *rng, min_discount,
+ 1.0, 0.0, niterations, 100*niterations);
+ }
+ strength = slice_sampler1d(ar, strength, *rng, -discount + std::numeric_limits<double>::min(),
+ std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
+ std::cerr << "TiedCRPs(d=" << discount << ",s="
+ << strength << ") = " << LogLikelihood(discount, strength) << std::endl;
+ for (typename std::set<CRP*>::iterator it = crps.begin(); it != crps.end(); ++it) {
+ (*it)->set_discount(discount);
+ (*it)->set_strength(strength);
+ }
+ }
+ private:
+ std::set<CRP*> crps;
+ const double d_alpha, d_beta, s_shape, s_rate;
+ double discount, strength;
+};
+
+#endif