summaryrefslogtreecommitdiff
path: root/gi/pf/tied_resampler.h
blob: 208fb9c7a62cc0cc933a6cbca543a2f860f5eca7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#ifndef _TIED_RESAMPLER_H_
#define _TIED_RESAMPLER_H_

#include <cassert>
#include <iostream>
#include <limits>
#include <set>

#include "sampler.h"
#include "slice_sampler.h"
#include "m.h"

/// Jointly resamples a single (discount, strength) pair that is shared
/// ("tied") by every registered CRP, and pushes the sampled values back into
/// each of them.
///
/// The joint log density (see LogLikelihood) is
///   Md::log_beta_density(discount, d_alpha, d_beta)
/// + Md::log_gamma_density(discount + strength, s_shape, s_rate)
/// + sum over registered CRPs of log_crp_prob(discount, strength),
/// restricted to strength > -discount.
///
/// CRP must provide set_discount, set_strength, has_discount_prior,
/// has_strength_prior and log_crp_prob(double, double).
///
/// CRPs are stored by raw, non-owning pointer and are dereferenced in
/// LogLikelihood / ResampleHyperparameters, so a CRP must be Remove()d before
/// it is destroyed.
template <class CRP>
struct TiedResampler {
  /// @param da, db  Parameters passed to Md::log_beta_density for the discount prior.
  /// @param ss, sr  Parameters passed to Md::log_gamma_density for the prior on
  ///                discount + strength (presumably shape and rate, per the names).
  /// @param d       Initial tied discount.
  /// @param s       Initial tied strength.
  explicit TiedResampler(double da, double db, double ss, double sr, double d=0.5, double s=1.0) :
      d_alpha(da),
      d_beta(db),
      s_shape(ss),
      s_rate(sr),
      discount(d),
      strength(s) {}

  /// Registers crp and immediately sets its discount/strength to the current
  /// tied values. The CRP must not have its own discount or strength prior
  /// (checked with assert), since this resampler's priors are meant to govern
  /// those parameters.
  void Add(CRP* crp) {
    crps.insert(crp);
    crp->set_discount(discount);
    crp->set_strength(strength);
    assert(!crp->has_discount_prior());
    assert(!crp->has_strength_prior());
  }

  /// Unregisters crp. Its current discount/strength values are left as they are.
  void Remove(CRP* crp) {
    crps.erase(crp);
  }

  /// Joint log density of (d, s) under the priors plus every registered CRP.
  /// Returns -infinity when s <= -d; this is the only place the strict
  /// constraint s > -d is actually enforced during sampling.
  double LogLikelihood(double d, double s) const {
    if (s <= -d) return -std::numeric_limits<double>::infinity();
    double llh = Md::log_beta_density(d, d_alpha, d_beta) +
                 Md::log_gamma_density(d + s, s_shape, s_rate);
    // This is a const member, so crps.begin() yields a const_iterator. Its
    // conversion to std::set<...>::iterator is not guaranteed by the standard.
    for (typename std::set<CRP*>::const_iterator it = crps.begin(); it != crps.end(); ++it)
      llh += (*it)->log_crp_prob(d, s);
    return llh;
  }

  /// Slice-sampler target for the discount: evaluates LogLikelihood with the
  /// resampler's *current* strength (read through the reference at call time).
  struct DiscountResampler {
    DiscountResampler(const TiedResampler& m) : m_(m) {}
    const TiedResampler& m_;
    double operator()(const double& proposed_discount) const {
      return m_.LogLikelihood(proposed_discount, m_.strength);
    }
  };

  /// Slice-sampler target for the strength (historically named "alpha"):
  /// evaluates LogLikelihood with the resampler's *current* discount.
  struct AlphaResampler {
    AlphaResampler(const TiedResampler& m) : m_(m) {}
    const TiedResampler& m_;
    double operator()(const double& proposed_strength) const {
      return m_.LogLikelihood(m_.discount, proposed_strength);
    }
  };

  /// Runs nloop alternating sweeps that slice-sample strength and then
  /// discount. It then draws strength once more so that it is consistent with
  /// the final discount. Finally it logs the result to stderr and copies the
  /// new values into every registered CRP.
  /// @param rng          Random source forwarded to slice_sampler1d.
  /// @param nloop        Number of (strength, discount) sweeps.
  /// @param niterations  Forwarded to slice_sampler1d (presumably its sample
  ///                     count). The evaluation cap passed alongside it is
  ///                     100 * niterations.
  void ResampleHyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) {
    const DiscountResampler dr(*this);
    const AlphaResampler ar(*this);
    for (unsigned iter = 0; iter < nloop; ++iter) {
      // Strength must satisfy s > -discount. Adding numeric_limits<double>::min()
      // to -discount is absorbed by rounding for any discount of ordinary
      // magnitude, so the bound passed here is effectively -discount itself.
      // LogLikelihood's -infinity guard is what rejects the boundary point.
      strength = slice_sampler1d(ar, strength, *rng, -discount + std::numeric_limits<double>::min(),
                                 std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
      // Discount is sampled up to 1.0. It must be positive and, when strength
      // is negative, also satisfy discount > -strength.
      double min_discount = std::numeric_limits<double>::min();
      if (strength < 0.0) min_discount -= strength;
      discount = slice_sampler1d(dr, discount, *rng, min_discount,
                                 1.0, 0.0, niterations, 100*niterations);
    }
    // Final strength draw so the returned pair respects s > -d for the last discount.
    strength = slice_sampler1d(ar, strength, *rng, -discount + std::numeric_limits<double>::min(),
                               std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
    std::cerr << "TiedCRPs(d=" << discount << ",s="
              << strength << ") = " << LogLikelihood(discount, strength) << std::endl;
    for (typename std::set<CRP*>::iterator it = crps.begin(); it != crps.end(); ++it) {
      (*it)->set_discount(discount);
      (*it)->set_strength(strength);
    }
  }
 private:
  std::set<CRP*> crps;                             // registered CRPs (non-owning)
  const double d_alpha, d_beta, s_shape, s_rate;   // prior hyperparameters
  double discount, strength;                       // current tied values
};

#endif