summaryrefslogtreecommitdiff
path: root/utils/slice_sampler.h
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-02-13 21:52:48 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-02-13 21:52:48 -0500
commitd69cee7a46d63b585e319d93429df27717e5043c (patch)
treec94f42539a2d9cad6f6cf9a383b1711a481be490 /utils/slice_sampler.h
parent3ef45aa928719f0631488a1240960ffdeeab384c (diff)
phrasinator training initial
Diffstat (limited to 'utils/slice_sampler.h')
-rw-r--r--utils/slice_sampler.h191
1 files changed, 191 insertions, 0 deletions
diff --git a/utils/slice_sampler.h b/utils/slice_sampler.h
new file mode 100644
index 00000000..aa48a169
--- /dev/null
+++ b/utils/slice_sampler.h
@@ -0,0 +1,191 @@
+//! slice-sampler.h is an MCMC slice sampler
+//!
+//! Mark Johnson, 1st August 2008
+
+#ifndef SLICE_SAMPLER_H
+#define SLICE_SAMPLER_H
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include <limits>
+
+//! slice_sampler_rfc_type{} returns the value of a user-specified
+//! function if the argument is within range, or - infinity otherwise
+//
+template <typename F, typename Fn, typename U>
+struct slice_sampler_rfc_type {
+ F min_x, max_x;
+ const Fn& f;
+ U max_nfeval, nfeval;
+ slice_sampler_rfc_type(F min_x, F max_x, const Fn& f, U max_nfeval)
+ : min_x(min_x), max_x(max_x), f(f), max_nfeval(max_nfeval), nfeval(0) { }
+
+ F operator() (F x) {
+ if (min_x < x && x < max_x) {
+ assert(++nfeval <= max_nfeval);
+ F fx = f(x);
+ assert(std::isfinite(fx));
+ return fx;
+ }
+ return -std::numeric_limits<F>::infinity();
+ }
+}; // slice_sampler_rfc_type{}
+
+//! slice_sampler1d() implements the univariate "range doubling" slice sampler
+//! described in Neal (2003) "Slice Sampling", The Annals of Statistics 31(3), 705-767.
+//
+template <typename F, typename LogF, typename Uniform01>
+F slice_sampler1d(const LogF& logF0, //!< log of function to sample
+ F x, //!< starting point
+ Uniform01& u01, //!< uniform [0,1) random number generator
+ F min_x = -std::numeric_limits<F>::infinity(), //!< minimum value of support
+ F max_x = std::numeric_limits<F>::infinity(), //!< maximum value of support
+ F w = 0.0, //!< guess at initial width
+ unsigned nsamples=1, //!< number of samples to draw
+ unsigned max_nfeval=200) //!< max number of function evaluations
+{
+ typedef unsigned U;
+ slice_sampler_rfc_type<F,LogF,U> logF(min_x, max_x, logF0, max_nfeval);
+
+ assert(std::isfinite(x));
+
+ if (w <= 0.0) { // set w to a default width
+ if (min_x > -std::numeric_limits<F>::infinity() && max_x < std::numeric_limits<F>::infinity())
+ w = (max_x - min_x)/4;
+ else
+ w = std::max(((x < 0.0) ? -x : x)/4, (F) 0.1);
+ }
+ assert(std::isfinite(w));
+
+ F logFx = logF(x);
+ for (U sample = 0; sample < nsamples; ++sample) {
+ F logY = logFx + log(u01()+1e-100); //! slice logFx at this value
+ assert(std::isfinite(logY));
+
+ F xl = x - w*u01(); //! lower bound on slice interval
+ F logFxl = logF(xl);
+ F xr = xl + w; //! upper bound on slice interval
+ F logFxr = logF(xr);
+
+ while (logY < logFxl || logY < logFxr) // doubling procedure
+ if (u01() < 0.5)
+ logFxl = logF(xl -= xr - xl);
+ else
+ logFxr = logF(xr += xr - xl);
+
+ F xl1 = xl;
+ F xr1 = xr;
+ while (true) { // shrinking procedure
+ F x1 = xl1 + u01()*(xr1 - xl1);
+ if (logY < logF(x1)) {
+ F xl2 = xl; // acceptance procedure
+ F xr2 = xr;
+ bool d = false;
+ while (xr2 - xl2 > 1.1*w) {
+ F xm = (xl2 + xr2)/2;
+ if ((x < xm && x1 >= xm) || (x >= xm && x1 < xm))
+ d = true;
+ if (x1 < xm)
+ xr2 = xm;
+ else
+ xl2 = xm;
+ if (d && logY >= logF(xl2) && logY >= logF(xr2))
+ goto unacceptable;
+ }
+ x = x1;
+ goto acceptable;
+ }
+ goto acceptable;
+ unacceptable:
+ if (x1 < x) // rest of shrinking procedure
+ xl1 = x1;
+ else
+ xr1 = x1;
+ }
+ acceptable:
+ w = (4*w + (xr1 - xl1))/5; // update width estimate
+ }
+ return x;
+}
+
+/*
+//! slice_sampler1d() implements a 1-d MCMC slice sampler.
+//! It should be correct for unimodal distributions, but
+//! not for multimodal ones.
+//
+template <typename F, typename LogP, typename Uniform01>
+F slice_sampler1d(const LogP& logP, //!< log of distribution to sample
+ F x, //!< initial sample
+ Uniform01& u01, //!< uniform random number generator
+ F min_x = -std::numeric_limits<F>::infinity(), //!< minimum value of support
+ F max_x = std::numeric_limits<F>::infinity(), //!< maximum value of support
+ F w = 0.0, //!< guess at initial width
+ unsigned nsamples=1, //!< number of samples to draw
+ unsigned max_nfeval=200) //!< max number of function evaluations
+{
+ typedef unsigned U;
+ assert(std::isfinite(x));
+ if (w <= 0.0) {
+ if (min_x > -std::numeric_limits<F>::infinity() && max_x < std::numeric_limits<F>::infinity())
+ w = (max_x - min_x)/4;
+ else
+ w = std::max(((x < 0.0) ? -x : x)/4, 0.1);
+ }
+ // TRACE4(x, min_x, max_x, w);
+ F logPx = logP(x);
+ assert(std::isfinite(logPx));
+ U nfeval = 1;
+ for (U sample = 0; sample < nsamples; ++sample) {
+ F x0 = x;
+ F logU = logPx + log(u01()+1e-100);
+ assert(std::isfinite(logU));
+ F r = u01();
+ F xl = std::max(min_x, x - r*w);
+ F xr = std::min(max_x, x + (1-r)*w);
+ // TRACE3(x, logPx, logU);
+ while (xl > min_x && logP(xl) > logU) {
+ xl -= w;
+ w *= 2;
+ ++nfeval;
+ if (nfeval >= max_nfeval)
+ std::cerr << "## Error: nfeval = " << nfeval << ", max_nfeval = " << max_nfeval << ", sample = " << sample << ", nsamples = " << nsamples << ", r = " << r << ", w = " << w << ", xl = " << xl << std::endl;
+ assert(nfeval < max_nfeval);
+ }
+ xl = std::max(xl, min_x);
+ while (xr < max_x && logP(xr) > logU) {
+ xr += w;
+ w *= 2;
+ ++nfeval;
+ if (nfeval >= max_nfeval)
+ std::cerr << "## Error: nfeval = " << nfeval << ", max_nfeval = " << max_nfeval << ", sample = " << sample << ", nsamples = " << nsamples << ", r = " << r << ", w = " << w << ", xr = " << xr << std::endl;
+ assert(nfeval < max_nfeval);
+ }
+ xr = std::min(xr, max_x);
+ while (true) {
+ r = u01();
+ x = r*xl + (1-r)*xr;
+ assert(std::isfinite(x));
+ logPx = logP(x);
+ // TRACE4(logPx, x, xl, xr);
+ assert(std::isfinite(logPx));
+ ++nfeval;
+ if (nfeval >= max_nfeval)
+ std::cerr << "## Error: nfeval = " << nfeval << ", max_nfeval = " << max_nfeval << ", sample = " << sample << ", nsamples = " << nsamples << ", r = " << r << ", w = " << w << ", xl = " << xl << ", xr = " << xr << ", x = " << x << std::endl;
+ assert(nfeval < max_nfeval);
+ if (logPx > logU)
+ break;
+ else if (x > x0)
+ xr = x;
+ else
+ xl = x;
+ }
+ // w = (4*w + (xr-xl))/5; // gradually adjust w
+ }
+ // TRACE2(logPx, x);
+ return x;
+} // slice_sampler1d()
+*/
+
+#endif // SLICE_SAMPLER_H