From b2171f53c6c597ac4326f63250269aa13df84718 Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Fri, 21 Oct 2011 15:24:32 +0100 Subject: more particle filter stuff --- gi/pf/pf.h | 84 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 gi/pf/pf.h (limited to 'gi/pf/pf.h') diff --git a/gi/pf/pf.h b/gi/pf/pf.h new file mode 100644 index 00000000..ede7cda8 --- /dev/null +++ b/gi/pf/pf.h @@ -0,0 +1,84 @@ +#ifndef _PF_H_ +#define _PF_H_ + +#include +#include +#include "sampler.h" +#include "prob.h" + +template +struct ParticleRenormalizer { + void operator()(std::vector* pv) const { + if (pv->empty()) return; + prob_t z = prob_t::Zero(); + for (unsigned i = 0; i < pv->size(); ++i) + z += (*pv)[i].weight; + assert(z > prob_t::Zero()); + for (unsigned i = 0; i < pv->size(); ++i) + (*pv)[i].weight /= z; + } +}; + +template +struct MultinomialResampleFilter { + explicit MultinomialResampleFilter(MT19937* rng) : rng_(rng) {} + + void operator()(std::vector* pv) { + if (pv->empty()) return; + std::vector& ps = *pv; + SampleSet ss; + for (int i = 0; i < ps.size(); ++i) + ss.add(ps[i].weight); + std::vector nps; nps.reserve(ps.size()); + const prob_t uniform_weight(1.0 / ps.size()); + for (int i = 0; i < ps.size(); ++i) { + nps.push_back(ps[rng_->SelectSample(ss)]); + nps[i].weight = uniform_weight; + } + nps.swap(ps); + } + + private: + MT19937* rng_; +}; + +template +struct SystematicResampleFilter { + explicit SystematicResampleFilter(MT19937* rng) : rng_(rng), renorm_() {} + + void operator()(std::vector* pv) { + if (pv->empty()) return; + renorm_(pv); + std::vector& ps = *pv; + std::vector nps; nps.reserve(ps.size()); + double lower = 0, upper = 0; + const double skip = 1.0 / ps.size(); + double u_j = rng_->next() * skip; + //std::cerr << "u_0: " << u_j << std::endl; + int j = 0; + for (unsigned i = 0; i < ps.size(); ++i) { + upper += ps[i].weight.as_float(); + //std::cerr << "lower: " << lower << " upper: " << upper << std::endl; + // how many children does ps[i] have? + while (u_j < lower) { u_j += skip; ++j; } + while (u_j >= lower && u_j <= upper) { + assert(j < ps.size()); + nps.push_back(ps[i]); + u_j += skip; + //std::cerr << " add u_j=" << u_j << std::endl; + ++j; + } + lower = upper; + } + //std::cerr << ps.size() << " " << nps.size() << "\n"; + assert(ps.size() == nps.size()); + //exit(1); + ps.swap(nps); + } + + private: + MT19937* rng_; + ParticleRenormalizer renorm_; +}; + +#endif -- cgit v1.2.3