summaryrefslogtreecommitdiff
path: root/utils/sampler.h
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-06-18 20:28:42 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-06-18 20:28:42 -0400
commitb89a1d3cb72ac36c137d6ae342f48ab9b8ee6655 (patch)
tree74dbff7519a3f3fe6906fff44128563300fec19b /utils/sampler.h
parent953ec50e659084c13433ea311f6a07e7e1b292f8 (diff)
add non-const iterators to sparse vector, speed up model1 code
Diffstat (limited to 'utils/sampler.h')
-rw-r--r--utils/sampler.h16
1 files changed, 15 insertions, 1 deletions
diff --git a/utils/sampler.h b/utils/sampler.h
index b237c716..3e4a4086 100644
--- a/utils/sampler.h
+++ b/utils/sampler.h
@@ -12,6 +12,7 @@
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_real.hpp>
#include <boost/random/variate_generator.hpp>
+#include <boost/random/gamma_distribution.hpp>
#include <boost/random/normal_distribution.hpp>
#include <boost/random/poisson_distribution.hpp>
#include <boost/random/uniform_int.hpp>
@@ -76,6 +77,18 @@ struct RandomNumberGenerator {
return boost::poisson_distribution<int>(lambda)(m_random);
}
+ double NextGamma(double shape, double scale = 1.0) {
+ boost::gamma_distribution<> gamma(shape);
+ boost::variate_generator<boost::mt19937&,boost::gamma_distribution<> > vg(m_generator, gamma);
+ return vg() * scale;
+ }
+
+ double NextBeta(double alpha, double beta) {
+ double x = NextGamma(alpha);
+ double y = NextGamma(beta);
+ return x / (x + y);
+ }
+
bool AcceptMetropolisHastings(const prob_t& p_cur,
const prob_t& p_prev,
const prob_t& q_cur,
@@ -123,11 +136,12 @@ size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet<F>& ss, double T
const bool anneal = (T != 1.0);
F sum = F(0);
if (anneal) {
- for (int i = 0; i < ss.m_scores.size(); ++i)
+ for (unsigned i = 0; i < ss.m_scores.size(); ++i)
sum += pow(ss.m_scores[i], annealing_factor); // p^(1/T)
} else {
sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), F(0));
}
+ //std::cerr << "SUM: " << sum << std::endl;
//for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ",";
//std::cerr << std::endl;