diff options
Diffstat (limited to 'utils')
| -rw-r--r-- | utils/ccrp.h | 48 | ||||
| -rw-r--r-- | utils/mfcr.h | 68 | ||||
| -rw-r--r-- | utils/mfcr_test.cc | 10 | 
3 files changed, 43 insertions, 83 deletions
| diff --git a/utils/ccrp.h b/utils/ccrp.h index 5f9db7a6..e24130ac 100644 --- a/utils/ccrp.h +++ b/utils/ccrp.h @@ -93,41 +93,8 @@ class CCRP {    }    // returns +1 or 0 indicating whether a new table was opened -  int increment(const Dish& dish, const double& p0, MT19937* rng) { -    DishLocations& loc = dish_locs_[dish]; -    bool share_table = false; -    if (loc.total_dish_count_) { -      const double p_empty = (strength_ + num_tables_ * discount_) * p0; -      const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_); -      share_table = rng->SelectSample(p_empty, p_share); -    } -    if (share_table) { -      double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); -      for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); -           ti != loc.table_counts_.end(); ++ti) { -        r -= (*ti - discount_); -        if (r <= 0.0) { -          ++(*ti); -          break; -        } -      } -      if (r > 0.0) { -        std::cerr << "Serious error: r=" << r << std::endl; -        Print(&std::cerr); -        assert(r <= 0.0); -      } -    } else { -      loc.table_counts_.push_back(1u); -      ++num_tables_; -    } -    ++loc.total_dish_count_; -    ++num_customers_; -    return (share_table ? 0 : 1); -  } - -  // returns +1 or 0 indicating whether a new table was opened    template <typename T> -  int incrementT(const Dish& dish, const T& p0, MT19937* rng) { +  int increment(const Dish& dish, const T& p0, MT19937* rng) {      DishLocations& loc = dish_locs_[dish];      bool share_table = false;      if (loc.total_dish_count_) { @@ -196,19 +163,8 @@ class CCRP {      }    } -  double prob(const Dish& dish, const double& p0) const { -    const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); -    const double r = num_tables_ * discount_ + strength_; -    if (it == dish_locs_.end()) { -      return r * p0 / (num_customers_ + strength_); -    } else { -      return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) / -               (num_customers_ + strength_); -    } -  } -    template <typename T> -  T probT(const Dish& dish, const T& p0) const { +  T prob(const Dish& dish, const T& p0) const {      const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);      const T r = T(num_tables_ * discount_ + strength_);      if (it == dish_locs_.end()) { diff --git a/utils/mfcr.h b/utils/mfcr.h index aeaf599d..6cc0ebf1 100644 --- a/utils/mfcr.h +++ b/utils/mfcr.h @@ -8,6 +8,7 @@  #include <list>  #include <iostream>  #include <vector> +#include <iterator>  #include <tr1/unordered_map>  #include <boost/functional/hash.hpp>  #include "sampler.h" @@ -35,12 +36,11 @@ std::ostream& operator<<(std::ostream& o, const TableCount& tc) {  // referenced therein.  // http://www.aclweb.org/anthology/P/P09/P09-2085.pdf  // -template <typename Dish, typename DishHash = boost::hash<Dish> > +template <unsigned Floors, typename Dish, typename DishHash = boost::hash<Dish> >  class MFCR {   public: -  MFCR(unsigned num_floors, double d, double strength) : -    num_floors_(num_floors), +  MFCR(double d, double strength) :      num_tables_(),      num_customers_(),      discount_(d), @@ -50,8 +50,7 @@ class MFCR {      strength_prior_shape_(std::numeric_limits<double>::quiet_NaN()),      strength_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {} -  MFCR(unsigned num_floors, double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) : -    num_floors_(num_floors), +  MFCR(double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) :      num_tables_(),      num_customers_(),      discount_(d), @@ -111,22 +110,22 @@ class MFCR {    }    // returns (delta, floor) indicating whether a new table (delta) was opened and on which floor -  TableCount increment(const Dish& dish, const std::vector<double>& p0s, const std::vector<double>& lambdas, MT19937* rng) { -    assert(p0s.size() == num_floors_); -    assert(lambdas.size() == num_floors_); - +  template <class InputIterator, class InputIterator2> +  TableCount increment(const Dish& dish, InputIterator p0s, InputIterator2 lambdas, MT19937* rng) {      DishLocations& loc = dish_locs_[dish];      // marg_p0 = marginal probability of opening a new table on any floor with label dish -    const double marg_p0 = std::inner_product(p0s.begin(), p0s.end(), lambdas.begin(), 0.0); -    assert(marg_p0 <= 1.0); +    typedef typename std::iterator_traits<InputIterator>::value_type F; +    const F marg_p0 = std::inner_product(p0s, p0s + Floors, lambdas, F(0.0)); +    assert(marg_p0 <= F(1.0001));      int floor = -1;      bool share_table = false;      if (loc.total_dish_count_) { -      const double p_empty = (strength_ + num_tables_ * discount_) * marg_p0; -      const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_); +      const F p_empty = F(strength_ + num_tables_ * discount_) * marg_p0; +      const F p_share = F(loc.total_dish_count_ - loc.table_counts_.size() * discount_);        share_table = rng->SelectSample(p_empty, p_share);      }      if (share_table) { +      // this can be done with doubles since P0 (which may be tiny) is not involved        double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);        for (typename std::list<TableCount>::iterator ti = loc.table_counts_.begin();             ti != loc.table_counts_.end(); ++ti) { @@ -143,12 +142,18 @@ class MFCR {          assert(r <= 0.0);        }      } else { // sit at currently empty table -- must sample what floor -      double r = rng->next() * marg_p0; -      for (unsigned i = 0; i < p0s.size(); ++i) { -        r -= p0s[i] * lambdas[i]; -        if (r <= 0.0) { -          floor = i; -          break; +      if (Floors == 1) { +        floor = 0; +      } else { +        F r = F(rng->next()) * marg_p0; +        for (unsigned i = 0; i < Floors; ++i) { +          r -= (*p0s) * (*lambdas); +          ++p0s; +          ++lambdas; +          if (r <= F(0.0)) { +            floor = i; +            break; +          }          }        }        assert(floor >= 0); @@ -200,18 +205,18 @@ class MFCR {      return TableCount(delta, floor);    } -  double prob(const Dish& dish, const std::vector<double>& p0s, const std::vector<double>& lambdas) const { -    assert(p0s.size() == num_floors_); -    assert(lambdas.size() == num_floors_); -    const double marg_p0 = std::inner_product(p0s.begin(), p0s.end(), lambdas.begin(), 0.0); -    assert(marg_p0 <= 1.0); +  template <class InputIterator, class InputIterator2> +  typename std::iterator_traits<InputIterator>::value_type prob(const Dish& dish, InputIterator p0s, InputIterator2 lambdas) const { +    typedef typename std::iterator_traits<InputIterator>::value_type F; +    const F marg_p0 = std::inner_product(p0s, p0s + Floors, lambdas, F(0.0)); +    assert(marg_p0 <= F(1.0001));      const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); -    const double r = num_tables_ * discount_ + strength_; +    const F r = F(num_tables_ * discount_ + strength_);      if (it == dish_locs_.end()) { -      return r * marg_p0 / (num_customers_ + strength_); +      return r * marg_p0 / F(num_customers_ + strength_);      } else { -      return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * marg_p0) / -               (num_customers_ + strength_); +      return (F(it->second.total_dish_count_ - discount_ * it->second.table_counts_.size()) + F(r * marg_p0)) / +               F(num_customers_ + strength_);      }    } @@ -303,7 +308,7 @@ class MFCR {    };    void Print(std::ostream* out) const { -    (*out) << "MFCR(d=" << discount_ << ",strength=" << strength_ << ") customers=" << num_customers_ << std::endl; +    (*out) << "MFCR<" << Floors << ">(d=" << discount_ << ",strength=" << strength_ << ") customers=" << num_customers_ << std::endl;      for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();           it != dish_locs_.end(); ++it) {        (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; @@ -323,7 +328,6 @@ class MFCR {      return dish_locs_.end();    } -  unsigned num_floors_;    unsigned num_tables_;    unsigned num_customers_;    std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_; @@ -340,8 +344,8 @@ class MFCR {    double strength_prior_rate_;  }; -template <typename T,typename H> -std::ostream& operator<<(std::ostream& o, const MFCR<T,H>& c) { +template <unsigned N,typename T,typename H> +std::ostream& operator<<(std::ostream& o, const MFCR<N,T,H>& c) {    c.Print(&o);    return o;  } diff --git a/utils/mfcr_test.cc b/utils/mfcr_test.cc index 7c45a37c..cc886335 100644 --- a/utils/mfcr_test.cc +++ b/utils/mfcr_test.cc @@ -9,7 +9,7 @@  using namespace std;  void test_exch(MT19937* rng) { -  MFCR<int> crp(2, 0.5, 3.0); +  MFCR<2, int> crp(0.5, 3.0);    vector<double> lambdas(2);    vector<double> p0s(2);    lambdas[0] = 0.2; @@ -22,23 +22,23 @@ void test_exch(MT19937* rng) {    double xt = 0;    int cust = 10;    vector<int> hist(cust + 1, 0), hist2(cust + 1, 0); -  for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); } +  for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }    const int samples = 100000;    const bool simulate = true;    for (int k = 0; k < samples; ++k) {      if (!simulate) {        crp.clear(); -      for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); } +      for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }      } else {        int da = rng->next() * cust;        bool a = rng->next() < 0.45;        if (a) { -        for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); } +        for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }          for (int i = 0; i < da; ++i) { crp.decrement(1, rng); }          xt += 1.0;        } else {          for (int i = 0; i < da; ++i) { crp.decrement(1, rng); } -        for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); } +        for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }        }      }      int c = crp.num_tables(1); | 
