summaryrefslogtreecommitdiff
path: root/gi/clda/src
diff options
context:
space:
mode:
Diffstat (limited to 'gi/clda/src')
-rw-r--r--gi/clda/src/ccrp.h119
-rw-r--r--gi/clda/src/clda.cc12
-rw-r--r--gi/clda/src/crp_test.cc6
-rw-r--r--gi/clda/src/slice_sampler.h191
4 files changed, 319 insertions, 9 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h
index eeccce1a..74d5be29 100644
--- a/gi/clda/src/ccrp.h
+++ b/gi/clda/src/ccrp.h
@@ -1,6 +1,7 @@
#ifndef _CCRP_H_
#define _CCRP_H_
+#include <numeric>
#include <cassert>
#include <cmath>
#include <list>
@@ -9,6 +10,7 @@
#include <tr1/unordered_map>
#include <boost/functional/hash.hpp>
#include "sampler.h"
+#include "slice_sampler.h"
// Chinese restaurant process (Pitman-Yor parameters) with explicit table
// tracking.
@@ -16,7 +18,36 @@
template <typename Dish, typename DishHash = boost::hash<Dish> >
class CCRP {
public:
- CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {}
+ CCRP(double disc, double conc) :
+ num_tables_(),
+ num_customers_(),
+ discount_(disc),
+ concentration_(conc),
+ discount_prior_alpha_(std::numeric_limits<double>::quiet_NaN()),
+ discount_prior_beta_(std::numeric_limits<double>::quiet_NaN()),
+ concentration_prior_shape_(std::numeric_limits<double>::quiet_NaN()),
+ concentration_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {}
+
+ CCRP(double d_alpha, double d_beta, double c_shape, double c_rate, double d = 0.1, double c = 10.0) :
+ num_tables_(),
+ num_customers_(),
+ discount_(d),
+ concentration_(c),
+ discount_prior_alpha_(d_alpha),
+ discount_prior_beta_(d_beta),
+ concentration_prior_shape_(c_shape),
+ concentration_prior_rate_(c_rate) {}
+
+ double discount() const { return discount_; }
+ double concentration() const { return concentration_; }
+
+ bool has_discount_prior() const {
+ return !std::isnan(discount_prior_alpha_);
+ }
+
+ bool has_concentration_prior() const {
+ return !std::isnan(concentration_prior_shape_);
+ }
void clear() {
num_tables_ = 0;
@@ -115,24 +146,90 @@ class CCRP {
}
}
+ double log_crp_prob() const {
+ return log_crp_prob(discount_, concentration_);
+ }
+
+ static double log_beta_density(const double& x, const double& alpha, const double& beta) {
+ assert(x > 0.0);
+ assert(x < 1.0);
+ assert(alpha > 0.0);
+ assert(beta > 0.0);
+ const double lp = (alpha-1)*log(x)+(beta-1)*log(1-x)+lgamma(alpha+beta)-lgamma(alpha)-lgamma(beta);
+ return lp;
+ }
+
+ static double log_gamma_density(const double& x, const double& shape, const double& rate) {
+ assert(x >= 0.0);
+ assert(shape > 0.0);
+ assert(rate > 0.0);
+ const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape);
+ return lp;
+ }
+
// taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process
// does not include P_0's
- double log_crp_prob() const {
+ double log_crp_prob(const double& discount, const double& concentration) const {
double lp = 0.0;
+ if (has_discount_prior())
+ lp = log_beta_density(discount, discount_prior_alpha_, discount_prior_beta_);
+ if (has_concentration_prior())
+ lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_);
+ assert(lp <= 0.0);
if (num_customers_) {
- const double r = lgamma(1.0 - discount_);
- lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_)
- + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_)
- - lgamma(concentration_ / discount_);
+ const double r = lgamma(1.0 - discount);
+ lp += lgamma(concentration) - lgamma(concentration + num_customers_)
+ + num_tables_ * discount + lgamma(concentration / discount + num_tables_)
+ - lgamma(concentration / discount);
+ assert(std::isfinite(lp));
for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
it != dish_locs_.end(); ++it) {
const DishLocations& cur = it->second;
- lp += lgamma(cur.total_dish_count_ - discount_) - r;
+ for (std::list<unsigned>::const_iterator ti = cur.table_counts_.begin(); ti != cur.table_counts_.end(); ++ti) {
+ lp += lgamma(*ti - discount) - r;
+ }
}
}
+ assert(std::isfinite(lp));
return lp;
}
+ void resample_hyperparameters(MT19937* rng) {
+ assert(has_discount_prior() || has_concentration_prior());
+ DiscountResampler dr(*this);
+ ConcentrationResampler cr(*this);
+ const int niterations = 10;
+ double gamma_upper = std::numeric_limits<double>::infinity();
+ for (int iter = 0; iter < 5; ++iter) {
+ if (has_concentration_prior()) {
+ concentration_ = slice_sampler1d(cr, concentration_, *rng, 0.0,
+ gamma_upper, 0.0, niterations, 100*niterations);
+ }
+ if (has_discount_prior()) {
+ discount_ = slice_sampler1d(dr, discount_, *rng, std::numeric_limits<double>::min(),
+ 1.0, 0.0, niterations, 100*niterations);
+ }
+ }
+ concentration_ = slice_sampler1d(cr, concentration_, *rng, 0.0,
+ gamma_upper, 0.0, niterations, 100*niterations);
+ }
+
+ struct DiscountResampler {
+ DiscountResampler(const CCRP& crp) : crp_(crp) {}
+ const CCRP& crp_;
+ double operator()(const double& proposed_discount) const {
+ return crp_.log_crp_prob(proposed_discount, crp_.concentration_);
+ }
+ };
+
+ struct ConcentrationResampler {
+ ConcentrationResampler(const CCRP& crp) : crp_(crp) {}
+ const CCRP& crp_;
+ double operator()(const double& proposed_concentration) const {
+ return crp_.log_crp_prob(crp_.discount_, proposed_concentration);
+ }
+ };
+
struct DishLocations {
DishLocations() : total_dish_count_() {}
unsigned total_dish_count_; // customers at all tables with this dish
@@ -166,6 +263,14 @@ class CCRP {
double discount_;
double concentration_;
+
+ // optional beta prior on discount_ (NaN if no prior)
+ double discount_prior_alpha_;
+ double discount_prior_beta_;
+
+ // optional gamma prior on concentration_ (NaN if no prior)
+ double concentration_prior_shape_;
+ double concentration_prior_rate_;
};
template <typename T,typename H>
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index 10056bc9..f548997f 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -61,8 +61,8 @@ int main(int argc, char** argv) {
double alpha = 50.0;
const double uniform_topic = 1.0 / num_classes;
const double uniform_word = 1.0 / TD::NumWords();
- vector<CCRP<int> > dr(zji.size(), CCRP<int>(disc, beta)); // dr[i] describes the probability of using a topic in document i
- vector<CCRP<int> > wr(num_classes, CCRP<int>(disc, alpha)); // wr[k] describes the probability of generating a word in topic k
+ vector<CCRP<int> > dr(zji.size(), CCRP<int>(1,1,1,1,disc, beta)); // dr[i] describes the probability of using a topic in document i
+ vector<CCRP<int> > wr(num_classes, CCRP<int>(1,1,1,1,disc, alpha)); // wr[k] describes the probability of generating a word in topic k
for (int j = 0; j < zji.size(); ++j) {
const size_t num_words = wji[j].size();
vector<int>& zj = zji[j];
@@ -89,6 +89,13 @@ int main(int argc, char** argv) {
total_time += timer.Elapsed();
timer.Reset();
double llh = 0;
+#if 1
+ for (int j = 0; j < dr.size(); ++j)
+ dr[j].resample_hyperparameters(&rng);
+ for (int j = 0; j < wr.size(); ++j)
+ wr[j].resample_hyperparameters(&rng);
+#endif
+
for (int j = 0; j < dr.size(); ++j)
llh += dr[j].log_crp_prob();
for (int j = 0; j < wr.size(); ++j)
@@ -120,6 +127,7 @@ int main(int argc, char** argv) {
}
for (int i = 0; i < num_classes; ++i) {
cerr << "---------------------------------\n";
+ cerr << " final PYP(" << wr[i].discount() << "," << wr[i].concentration() << ")\n";
ShowTopWordsForTopic(t2w[i]);
}
cerr << "-------------\n";
diff --git a/gi/clda/src/crp_test.cc b/gi/clda/src/crp_test.cc
index ed384f81..561cd4dd 100644
--- a/gi/clda/src/crp_test.cc
+++ b/gi/clda/src/crp_test.cc
@@ -90,6 +90,12 @@ TEST_F(CRPTest, Exchangability) {
cerr << i << ' ' << (hist[i]) << endl;
}
+TEST_F(CRPTest, LP) {
+ CCRP<string> crp(1,1,1,1,0.1,50.0);
+ crp.increment("foo", 1.0, &rng);
+ cerr << crp.log_crp_prob() << endl;
+}
+
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
diff --git a/gi/clda/src/slice_sampler.h b/gi/clda/src/slice_sampler.h
new file mode 100644
index 00000000..aa48a169
--- /dev/null
+++ b/gi/clda/src/slice_sampler.h
@@ -0,0 +1,191 @@
+//! slice-sampler.h is an MCMC slice sampler
+//!
+//! Mark Johnson, 1st August 2008
+
+#ifndef SLICE_SAMPLER_H
+#define SLICE_SAMPLER_H
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include <limits>
+
+//! slice_sampler_rfc_type{} returns the value of a user-specified
+//! function if the argument is within range, or - infinity otherwise
+//
+template <typename F, typename Fn, typename U>
+struct slice_sampler_rfc_type {
+ F min_x, max_x;
+ const Fn& f;
+ U max_nfeval, nfeval;
+ slice_sampler_rfc_type(F min_x, F max_x, const Fn& f, U max_nfeval)
+ : min_x(min_x), max_x(max_x), f(f), max_nfeval(max_nfeval), nfeval(0) { }
+
+ F operator() (F x) {
+ if (min_x < x && x < max_x) {
+ assert(++nfeval <= max_nfeval);
+ F fx = f(x);
+ assert(std::isfinite(fx));
+ return fx;
+ }
+ return -std::numeric_limits<F>::infinity();
+ }
+}; // slice_sampler_rfc_type{}
+
+//! slice_sampler1d() implements the univariate "range doubling" slice sampler
+//! described in Neal (2003) "Slice Sampling", The Annals of Statistics 31(3), 705-767.
+//
+template <typename F, typename LogF, typename Uniform01>
+F slice_sampler1d(const LogF& logF0, //!< log of function to sample
+ F x, //!< starting point
+ Uniform01& u01, //!< uniform [0,1) random number generator
+ F min_x = -std::numeric_limits<F>::infinity(), //!< minimum value of support
+ F max_x = std::numeric_limits<F>::infinity(), //!< maximum value of support
+ F w = 0.0, //!< guess at initial width
+ unsigned nsamples=1, //!< number of samples to draw
+ unsigned max_nfeval=200) //!< max number of function evaluations
+{
+ typedef unsigned U;
+ slice_sampler_rfc_type<F,LogF,U> logF(min_x, max_x, logF0, max_nfeval);
+
+ assert(std::isfinite(x));
+
+ if (w <= 0.0) { // set w to a default width
+ if (min_x > -std::numeric_limits<F>::infinity() && max_x < std::numeric_limits<F>::infinity())
+ w = (max_x - min_x)/4;
+ else
+ w = std::max(((x < 0.0) ? -x : x)/4, (F) 0.1);
+ }
+ assert(std::isfinite(w));
+
+ F logFx = logF(x);
+ for (U sample = 0; sample < nsamples; ++sample) {
+ F logY = logFx + log(u01()+1e-100); //! slice logFx at this value
+ assert(std::isfinite(logY));
+
+ F xl = x - w*u01(); //! lower bound on slice interval
+ F logFxl = logF(xl);
+ F xr = xl + w; //! upper bound on slice interval
+ F logFxr = logF(xr);
+
+ while (logY < logFxl || logY < logFxr) // doubling procedure
+ if (u01() < 0.5)
+ logFxl = logF(xl -= xr - xl);
+ else
+ logFxr = logF(xr += xr - xl);
+
+ F xl1 = xl;
+ F xr1 = xr;
+ while (true) { // shrinking procedure
+ F x1 = xl1 + u01()*(xr1 - xl1);
+ if (logY < logF(x1)) {
+ F xl2 = xl; // acceptance procedure
+ F xr2 = xr;
+ bool d = false;
+ while (xr2 - xl2 > 1.1*w) {
+ F xm = (xl2 + xr2)/2;
+ if ((x < xm && x1 >= xm) || (x >= xm && x1 < xm))
+ d = true;
+ if (x1 < xm)
+ xr2 = xm;
+ else
+ xl2 = xm;
+ if (d && logY >= logF(xl2) && logY >= logF(xr2))
+ goto unacceptable;
+ }
+ x = x1;
+ goto acceptable;
+ }
+ goto acceptable;
+ unacceptable:
+ if (x1 < x) // rest of shrinking procedure
+ xl1 = x1;
+ else
+ xr1 = x1;
+ }
+ acceptable:
+ w = (4*w + (xr1 - xl1))/5; // update width estimate
+ }
+ return x;
+}
+
+/*
+//! slice_sampler1d() implements a 1-d MCMC slice sampler.
+//! It should be correct for unimodal distributions, but
+//! not for multimodal ones.
+//
+template <typename F, typename LogP, typename Uniform01>
+F slice_sampler1d(const LogP& logP, //!< log of distribution to sample
+ F x, //!< initial sample
+ Uniform01& u01, //!< uniform random number generator
+ F min_x = -std::numeric_limits<F>::infinity(), //!< minimum value of support
+ F max_x = std::numeric_limits<F>::infinity(), //!< maximum value of support
+ F w = 0.0, //!< guess at initial width
+ unsigned nsamples=1, //!< number of samples to draw
+ unsigned max_nfeval=200) //!< max number of function evaluations
+{
+ typedef unsigned U;
+ assert(std::isfinite(x));
+ if (w <= 0.0) {
+ if (min_x > -std::numeric_limits<F>::infinity() && max_x < std::numeric_limits<F>::infinity())
+ w = (max_x - min_x)/4;
+ else
+ w = std::max(((x < 0.0) ? -x : x)/4, 0.1);
+ }
+ // TRACE4(x, min_x, max_x, w);
+ F logPx = logP(x);
+ assert(std::isfinite(logPx));
+ U nfeval = 1;
+ for (U sample = 0; sample < nsamples; ++sample) {
+ F x0 = x;
+ F logU = logPx + log(u01()+1e-100);
+ assert(std::isfinite(logU));
+ F r = u01();
+ F xl = std::max(min_x, x - r*w);
+ F xr = std::min(max_x, x + (1-r)*w);
+ // TRACE3(x, logPx, logU);
+ while (xl > min_x && logP(xl) > logU) {
+ xl -= w;
+ w *= 2;
+ ++nfeval;
+ if (nfeval >= max_nfeval)
+ std::cerr << "## Error: nfeval = " << nfeval << ", max_nfeval = " << max_nfeval << ", sample = " << sample << ", nsamples = " << nsamples << ", r = " << r << ", w = " << w << ", xl = " << xl << std::endl;
+ assert(nfeval < max_nfeval);
+ }
+ xl = std::max(xl, min_x);
+ while (xr < max_x && logP(xr) > logU) {
+ xr += w;
+ w *= 2;
+ ++nfeval;
+ if (nfeval >= max_nfeval)
+ std::cerr << "## Error: nfeval = " << nfeval << ", max_nfeval = " << max_nfeval << ", sample = " << sample << ", nsamples = " << nsamples << ", r = " << r << ", w = " << w << ", xr = " << xr << std::endl;
+ assert(nfeval < max_nfeval);
+ }
+ xr = std::min(xr, max_x);
+ while (true) {
+ r = u01();
+ x = r*xl + (1-r)*xr;
+ assert(std::isfinite(x));
+ logPx = logP(x);
+ // TRACE4(logPx, x, xl, xr);
+ assert(std::isfinite(logPx));
+ ++nfeval;
+ if (nfeval >= max_nfeval)
+ std::cerr << "## Error: nfeval = " << nfeval << ", max_nfeval = " << max_nfeval << ", sample = " << sample << ", nsamples = " << nsamples << ", r = " << r << ", w = " << w << ", xl = " << xl << ", xr = " << xr << ", x = " << x << std::endl;
+ assert(nfeval < max_nfeval);
+ if (logPx > logU)
+ break;
+ else if (x > x0)
+ xr = x;
+ else
+ xl = x;
+ }
+ // w = (4*w + (xr-xl))/5; // gradually adjust w
+ }
+ // TRACE2(logPx, x);
+ return x;
+} // slice_sampler1d()
+*/
+
+#endif // SLICE_SAMPLER_H