diff options
Diffstat (limited to 'gi/pf/tied_resampler.h')
-rw-r--r-- | gi/pf/tied_resampler.h | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/gi/pf/tied_resampler.h b/gi/pf/tied_resampler.h index 208fb9c7..5a262f9d 100644 --- a/gi/pf/tied_resampler.h +++ b/gi/pf/tied_resampler.h @@ -2,6 +2,7 @@ #define _TIED_RESAMPLER_H_ #include <set> +#include <vector> #include "sampler.h" #include "slice_sampler.h" #include "m.h" @@ -28,6 +29,10 @@ struct TiedResampler { crps.erase(crp); } + size_t size() const { + return crps.size(); + } + double LogLikelihood(double d, double s) const { if (s <= -d) return -std::numeric_limits<double>::infinity(); double llh = Md::log_beta_density(d, d_alpha, d_beta) + @@ -54,6 +59,7 @@ struct TiedResampler { }; void ResampleHyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) { + if (size() == 0) { std::cerr << "EMPTY - not resampling\n"; return; } const DiscountResampler dr(*this); const AlphaResampler ar(*this); for (int iter = 0; iter < nloop; ++iter) { @@ -79,4 +85,29 @@ struct TiedResampler { double discount, strength; }; +// split according to some criterion +template <class CRP> +struct BinTiedResampler { + explicit BinTiedResampler(unsigned nbins) : + resamplers(nbins, TiedResampler<CRP>(1,1,1,1)) {} + + void Add(unsigned bin, CRP* crp) { + resamplers[bin].Add(crp); + } + + void Remove(unsigned bin, CRP* crp) { + resamplers[bin].Remove(crp); + } + + void ResampleHyperparameters(MT19937* rng) { + for (unsigned i = 0; i < resamplers.size(); ++i) { + std::cerr << "BIN " << i << " (" << resamplers[i].size() << " CRPs): " << std::flush; + resamplers[i].ResampleHyperparameters(rng); + } + } + + private: + std::vector<TiedResampler<CRP> > resamplers; +}; + #endif |