summaryrefslogtreecommitdiff
path: root/gi/pf/tied_resampler.h
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pf/tied_resampler.h')
-rw-r--r--gi/pf/tied_resampler.h31
1 files changed, 31 insertions, 0 deletions
diff --git a/gi/pf/tied_resampler.h b/gi/pf/tied_resampler.h
index 208fb9c7..5a262f9d 100644
--- a/gi/pf/tied_resampler.h
+++ b/gi/pf/tied_resampler.h
@@ -2,6 +2,7 @@
#define _TIED_RESAMPLER_H_
#include <set>
+#include <vector>
#include "sampler.h"
#include "slice_sampler.h"
#include "m.h"
@@ -28,6 +29,10 @@ struct TiedResampler {
crps.erase(crp);
}
+ size_t size() const {
+ return crps.size();
+ }
+
double LogLikelihood(double d, double s) const {
if (s <= -d) return -std::numeric_limits<double>::infinity();
double llh = Md::log_beta_density(d, d_alpha, d_beta) +
@@ -54,6 +59,7 @@ struct TiedResampler {
};
void ResampleHyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) {
+ if (size() == 0) { std::cerr << "EMPTY - not resampling\n"; return; }
const DiscountResampler dr(*this);
const AlphaResampler ar(*this);
for (int iter = 0; iter < nloop; ++iter) {
@@ -79,4 +85,29 @@ struct TiedResampler {
double discount, strength;
};
+// split according to some criterion
+template <class CRP>
+struct BinTiedResampler {
+ explicit BinTiedResampler(unsigned nbins) :
+ resamplers(nbins, TiedResampler<CRP>(1,1,1,1)) {}
+
+ void Add(unsigned bin, CRP* crp) {
+ resamplers[bin].Add(crp);
+ }
+
+ void Remove(unsigned bin, CRP* crp) {
+ resamplers[bin].Remove(crp);
+ }
+
+ void ResampleHyperparameters(MT19937* rng) {
+ for (unsigned i = 0; i < resamplers.size(); ++i) {
+ std::cerr << "BIN " << i << " (" << resamplers[i].size() << " CRPs): " << std::flush;
+ resamplers[i].ResampleHyperparameters(rng);
+ }
+ }
+
+ private:
+ std::vector<TiedResampler<CRP> > resamplers;
+};
+
#endif