summaryrefslogtreecommitdiff
path: root/utils/sampler.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils/sampler.h')
-rw-r--r--utils/sampler.h34
1 files changed, 19 insertions, 15 deletions
diff --git a/utils/sampler.h b/utils/sampler.h
index 5fef45d0..f75d96b6 100644
--- a/utils/sampler.h
+++ b/utils/sampler.h
@@ -18,7 +18,7 @@
#include "prob.h"
-struct SampleSet;
+template <typename F> struct SampleSet;
template <typename RNG>
struct RandomNumberGenerator {
@@ -45,7 +45,8 @@ struct RandomNumberGenerator {
m_generator.seed(seed);
}
- size_t SelectSample(const prob_t& a, const prob_t& b, double T = 1.0) {
+ template <typename F>
+ size_t SelectSample(const F& a, const F& b, double T = 1.0) {
if (T == 1.0) {
if (this->next() > (a / (a + b))) return 1; else return 0;
} else {
@@ -54,7 +55,8 @@ struct RandomNumberGenerator {
}
// T is the annealing temperature, if desired
- size_t SelectSample(const SampleSet& ss, double T = 1.0);
+ template <typename F>
+ size_t SelectSample(const SampleSet<F>& ss, double T = 1.0);
// draw a value from U(0,1)
double next() {return m_random();}
@@ -94,36 +96,38 @@ struct RandomNumberGenerator {
typedef RandomNumberGenerator<boost::mt19937> MT19937;
+template <typename F = double>
class SampleSet {
public:
- const prob_t& operator[](int i) const { return m_scores[i]; }
- prob_t& operator[](int i) { return m_scores[i]; }
+ const F& operator[](int i) const { return m_scores[i]; }
+ F& operator[](int i) { return m_scores[i]; }
bool empty() const { return m_scores.empty(); }
void add(const prob_t& s) { m_scores.push_back(s); }
void clear() { m_scores.clear(); }
size_t size() const { return m_scores.size(); }
void resize(int size) { m_scores.resize(size); }
- std::vector<prob_t> m_scores;
+ std::vector<F> m_scores;
};
template <typename RNG>
-size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T) {
+template <typename F>
+size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet<F>& ss, double T) {
assert(T > 0.0);
assert(ss.m_scores.size() > 0);
if (ss.m_scores.size() == 1) return 0;
- const prob_t annealing_factor(1.0 / T);
- const bool anneal = (annealing_factor != prob_t::One());
- prob_t sum = prob_t::Zero();
+ const double annealing_factor = 1.0 / T;
+ const bool anneal = (T != 1.0);
+ F sum = F(0);
if (anneal) {
for (int i = 0; i < ss.m_scores.size(); ++i)
- sum += ss.m_scores[i].pow(annealing_factor); // p^(1/T)
+ sum += pow(ss.m_scores[i], annealing_factor); // p^(1/T)
} else {
- sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), prob_t::Zero());
+ sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), F(0));
}
//for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ",";
//std::cerr << std::endl;
- prob_t random(this->next()); // random number between 0 and 1
+ F random(this->next()); // random number between 0 and 1
random *= sum; // scale with normalization factor
//std::cerr << "Random number " << random << std::endl;
@@ -131,9 +135,9 @@ size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T) {
size_t position = 1;
sum = ss.m_scores[0];
if (anneal) {
- sum.poweq(annealing_factor);
+ sum = pow(sum, annealing_factor);
for (; position < ss.m_scores.size() && sum < random; ++position)
- sum += ss.m_scores[position].pow(annealing_factor);
+ sum += pow(ss.m_scores[position], annealing_factor);
} else {
for (; position < ss.m_scores.size() && sum < random; ++position)
sum += ss.m_scores[position];