From 2048ac9943e2695a75b5f0303ca869e66ee32202 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 5 Mar 2012 16:06:45 -0500 Subject: use template parameter inference to figure out what type to use for probability computations, templatatize number of floors in MFCR rather than compile-time set --- utils/ccrp.h | 48 ++------------------------------------ utils/mfcr.h | 68 +++++++++++++++++++++++++++++------------------------- utils/mfcr_test.cc | 10 ++++---- 3 files changed, 43 insertions(+), 83 deletions(-) (limited to 'utils') diff --git a/utils/ccrp.h b/utils/ccrp.h index 5f9db7a6..e24130ac 100644 --- a/utils/ccrp.h +++ b/utils/ccrp.h @@ -92,42 +92,9 @@ class CCRP { return it->total_dish_count_; } - // returns +1 or 0 indicating whether a new table was opened - int increment(const Dish& dish, const double& p0, MT19937* rng) { - DishLocations& loc = dish_locs_[dish]; - bool share_table = false; - if (loc.total_dish_count_) { - const double p_empty = (strength_ + num_tables_ * discount_) * p0; - const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_); - share_table = rng->SelectSample(p_empty, p_share); - } - if (share_table) { - double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); - for (typename std::list::iterator ti = loc.table_counts_.begin(); - ti != loc.table_counts_.end(); ++ti) { - r -= (*ti - discount_); - if (r <= 0.0) { - ++(*ti); - break; - } - } - if (r > 0.0) { - std::cerr << "Serious error: r=" << r << std::endl; - Print(&std::cerr); - assert(r <= 0.0); - } - } else { - loc.table_counts_.push_back(1u); - ++num_tables_; - } - ++loc.total_dish_count_; - ++num_customers_; - return (share_table ? 0 : 1); - } - // returns +1 or 0 indicating whether a new table was opened template - int incrementT(const Dish& dish, const T& p0, MT19937* rng) { + int increment(const Dish& dish, const T& p0, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; bool share_table = false; if (loc.total_dish_count_) { @@ -196,19 +163,8 @@ class CCRP { } } - double prob(const Dish& dish, const double& p0) const { - const typename std::tr1::unordered_map::const_iterator it = dish_locs_.find(dish); - const double r = num_tables_ * discount_ + strength_; - if (it == dish_locs_.end()) { - return r * p0 / (num_customers_ + strength_); - } else { - return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) / - (num_customers_ + strength_); - } - } - template - T probT(const Dish& dish, const T& p0) const { + T prob(const Dish& dish, const T& p0) const { const typename std::tr1::unordered_map::const_iterator it = dish_locs_.find(dish); const T r = T(num_tables_ * discount_ + strength_); if (it == dish_locs_.end()) { diff --git a/utils/mfcr.h b/utils/mfcr.h index aeaf599d..6cc0ebf1 100644 --- a/utils/mfcr.h +++ b/utils/mfcr.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include "sampler.h" @@ -35,12 +36,11 @@ std::ostream& operator<<(std::ostream& o, const TableCount& tc) { // referenced therein. // http://www.aclweb.org/anthology/P/P09/P09-2085.pdf // -template > +template > class MFCR { public: - MFCR(unsigned num_floors, double d, double strength) : - num_floors_(num_floors), + MFCR(double d, double strength) : num_tables_(), num_customers_(), discount_(d), @@ -50,8 +50,7 @@ class MFCR { strength_prior_shape_(std::numeric_limits::quiet_NaN()), strength_prior_rate_(std::numeric_limits::quiet_NaN()) {} - MFCR(unsigned num_floors, double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) : - num_floors_(num_floors), + MFCR(double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) : num_tables_(), num_customers_(), discount_(d), @@ -111,22 +110,22 @@ class MFCR { } // returns (delta, floor) indicating whether a new table (delta) was opened and on which floor - TableCount increment(const Dish& dish, const std::vector& p0s, const std::vector& lambdas, MT19937* rng) { - assert(p0s.size() == num_floors_); - assert(lambdas.size() == num_floors_); - + template + TableCount increment(const Dish& dish, InputIterator p0s, InputIterator2 lambdas, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; // marg_p0 = marginal probability of opening a new table on any floor with label dish - const double marg_p0 = std::inner_product(p0s.begin(), p0s.end(), lambdas.begin(), 0.0); - assert(marg_p0 <= 1.0); + typedef typename std::iterator_traits::value_type F; + const F marg_p0 = std::inner_product(p0s, p0s + Floors, lambdas, F(0.0)); + assert(marg_p0 <= F(1.0001)); int floor = -1; bool share_table = false; if (loc.total_dish_count_) { - const double p_empty = (strength_ + num_tables_ * discount_) * marg_p0; - const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_); + const F p_empty = F(strength_ + num_tables_ * discount_) * marg_p0; + const F p_share = F(loc.total_dish_count_ - loc.table_counts_.size() * discount_); share_table = rng->SelectSample(p_empty, p_share); } if (share_table) { + // this can be done with doubles since P0 (which may be tiny) is not involved double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); for (typename std::list::iterator ti = loc.table_counts_.begin(); ti != loc.table_counts_.end(); ++ti) { @@ -143,12 +142,18 @@ class MFCR { assert(r <= 0.0); } } else { // sit at currently empty table -- must sample what floor - double r = rng->next() * marg_p0; - for (unsigned i = 0; i < p0s.size(); ++i) { - r -= p0s[i] * lambdas[i]; - if (r <= 0.0) { - floor = i; - break; + if (Floors == 1) { + floor = 0; + } else { + F r = F(rng->next()) * marg_p0; + for (unsigned i = 0; i < Floors; ++i) { + r -= (*p0s) * (*lambdas); + ++p0s; + ++lambdas; + if (r <= F(0.0)) { + floor = i; + break; + } } } assert(floor >= 0); @@ -200,18 +205,18 @@ class MFCR { return TableCount(delta, floor); } - double prob(const Dish& dish, const std::vector& p0s, const std::vector& lambdas) const { - assert(p0s.size() == num_floors_); - assert(lambdas.size() == num_floors_); - const double marg_p0 = std::inner_product(p0s.begin(), p0s.end(), lambdas.begin(), 0.0); - assert(marg_p0 <= 1.0); + template + typename std::iterator_traits::value_type prob(const Dish& dish, InputIterator p0s, InputIterator2 lambdas) const { + typedef typename std::iterator_traits::value_type F; + const F marg_p0 = std::inner_product(p0s, p0s + Floors, lambdas, F(0.0)); + assert(marg_p0 <= F(1.0001)); const typename std::tr1::unordered_map::const_iterator it = dish_locs_.find(dish); - const double r = num_tables_ * discount_ + strength_; + const F r = F(num_tables_ * discount_ + strength_); if (it == dish_locs_.end()) { - return r * marg_p0 / (num_customers_ + strength_); + return r * marg_p0 / F(num_customers_ + strength_); } else { - return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * marg_p0) / - (num_customers_ + strength_); + return (F(it->second.total_dish_count_ - discount_ * it->second.table_counts_.size()) + F(r * marg_p0)) / + F(num_customers_ + strength_); } } @@ -303,7 +308,7 @@ class MFCR { }; void Print(std::ostream* out) const { - (*out) << "MFCR(d=" << discount_ << ",strength=" << strength_ << ") customers=" << num_customers_ << std::endl; + (*out) << "MFCR<" << Floors << ">(d=" << discount_ << ",strength=" << strength_ << ") customers=" << num_customers_ << std::endl; for (typename std::tr1::unordered_map::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; @@ -323,7 +328,6 @@ class MFCR { return dish_locs_.end(); } - unsigned num_floors_; unsigned num_tables_; unsigned num_customers_; std::tr1::unordered_map dish_locs_; @@ -340,8 +344,8 @@ class MFCR { double strength_prior_rate_; }; -template -std::ostream& operator<<(std::ostream& o, const MFCR& c) { +template +std::ostream& operator<<(std::ostream& o, const MFCR& c) { c.Print(&o); return o; } diff --git a/utils/mfcr_test.cc b/utils/mfcr_test.cc index 7c45a37c..cc886335 100644 --- a/utils/mfcr_test.cc +++ b/utils/mfcr_test.cc @@ -9,7 +9,7 @@ using namespace std; void test_exch(MT19937* rng) { - MFCR crp(2, 0.5, 3.0); + MFCR<2, int> crp(0.5, 3.0); vector lambdas(2); vector p0s(2); lambdas[0] = 0.2; @@ -22,23 +22,23 @@ void test_exch(MT19937* rng) { double xt = 0; int cust = 10; vector hist(cust + 1, 0), hist2(cust + 1, 0); - for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); } + for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); } const int samples = 100000; const bool simulate = true; for (int k = 0; k < samples; ++k) { if (!simulate) { crp.clear(); - for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); } + for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); } } else { int da = rng->next() * cust; bool a = rng->next() < 0.45; if (a) { - for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); } + for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); } for (int i = 0; i < da; ++i) { crp.decrement(1, rng); } xt += 1.0; } else { for (int i = 0; i < da; ++i) { crp.decrement(1, rng); } - for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); } + for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); } } } int c = crp.num_tables(1); -- cgit v1.2.3