summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src/mpi-pyp.hh
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src/mpi-pyp.hh')
-rw-r--r--gi/pyp-topics/src/mpi-pyp.hh273
1 files changed, 259 insertions, 14 deletions
diff --git a/gi/pyp-topics/src/mpi-pyp.hh b/gi/pyp-topics/src/mpi-pyp.hh
index 58be7c5c..65358d20 100644
--- a/gi/pyp-topics/src/mpi-pyp.hh
+++ b/gi/pyp-topics/src/mpi-pyp.hh
@@ -1,5 +1,5 @@
-#ifndef _pyp_hh
-#define _pyp_hh
+#ifndef _mpipyp_hh
+#define _mpipyp_hh
#include <math.h>
#include <map>
@@ -9,11 +9,15 @@
#include <boost/random/uniform_real.hpp>
#include <boost/random/variate_generator.hpp>
#include <boost/random/mersenne_twister.hpp>
+#include <boost/tuple/tuple.hpp>
+#include <boost/serialization/map.hpp>
+#include <boost/mpi.hpp>
+#include <boost/mpi/environment.hpp>
+#include <boost/mpi/communicator.hpp>
+#include <boost/mpi/operations.hpp>
-#include "pyp.h"
-#include "log_add.h"
-#include "slice-sampler.h"
-#include "mt19937ar.h"
+
+#include "pyp.hh"
//
// Pitman-Yor process with customer and table tracking
@@ -28,25 +32,104 @@ public:
virtual int decrement(Dish d);
void clear();
+ void reset_deltas();
- void reset_deltas() { m_count_delta.clear(); }
+ void synchronise();
private:
typedef std::map<Dish, int> dish_delta_type;
- typedef std::map<Dish, TableCounter> table_delta_type;
+ typedef std::map<Dish, typename PYP<Dish,Hash>::TableCounter> table_delta_type;
dish_delta_type m_count_delta;
table_delta_type m_table_delta;
};
template <typename Dish, typename Hash>
-MPIPYP<Dish,Hash>::MPIPYP(double a, double b, Hash)
-: PYP(a, b, Hash) {}
+MPIPYP<Dish,Hash>::MPIPYP(double a, double b, Hash h)
+: PYP<Dish,Hash>(a, b, 0, h) {}
template <typename Dish, typename Hash>
int
MPIPYP<Dish,Hash>::increment(Dish dish, double p0) {
- int delta = PYP<Dish,Hash>::increment(dish, p0);
+ int delta = 0;
+ int table_joined=-1;
+ typename PYP<Dish,Hash>::TableCounter &tc = PYP<Dish,Hash>::_dish_tables[dish];
+
+ // seated on a new or existing table?
+ int c = PYP<Dish,Hash>::count(dish);
+ int t = PYP<Dish,Hash>::num_tables(dish);
+ int T = PYP<Dish,Hash>::num_tables();
+ double& a = PYP<Dish,Hash>::_a;
+ double& b = PYP<Dish,Hash>::_b;
+ double pshare = (c > 0) ? (c - a*t) : 0.0;
+ double pnew = (b + a*T) * p0;
+ assert (pshare >= 0.0);
+
+ if (mt_genrand_res53() < pnew / (pshare + pnew)) {
+ // assign to a new table
+ tc.tables += 1;
+ tc.table_histogram[1] += 1;
+ PYP<Dish,Hash>::_total_tables += 1;
+ delta = 1;
+ }
+ else {
+ // randomly assign to an existing table
+ // remove constant denominator from inner loop
+ double r = mt_genrand_res53() * (c - a*t);
+ for (std::map<int,int>::iterator
+ hit = tc.table_histogram.begin();
+ hit != tc.table_histogram.end(); ++hit) {
+ r -= ((hit->first - a) * hit->second);
+ if (r <= 0) {
+ tc.table_histogram[hit->first+1] += 1;
+ hit->second -= 1;
+ if (hit->second == 0)
+ tc.table_histogram.erase(hit);
+ table_joined = hit->first+1;
+ break;
+ }
+ }
+ if (r > 0) {
+ std::cerr << r << " " << c << " " << a << " " << t << std::endl;
+ assert(false);
+ }
+ delta = 0;
+ }
+
+ std::tr1::unordered_map<Dish,int,Hash>::operator[](dish) += 1;
+ //google::sparse_hash_map<Dish,int,Hash>::operator[](dish) += 1;
+ PYP<Dish,Hash>::_total_customers += 1;
+
+ // MPI Delta handling
+ // track the customer entering
+ typename dish_delta_type::iterator customer_it;
+ bool customer_insert_result;
+ boost::tie(customer_it, customer_insert_result)
+ = m_count_delta.insert(std::make_pair(dish,0));
+
+ customer_it->second += 1;
+ if (customer_it->second == 0)
+ m_count_delta.erase(customer_it);
+
+ // increment the histogram bar for the table joined
+ if (!delta) {
+ assert (table_joined >= 0);
+ std::map<int,int> &histogram = m_table_delta[dish].table_histogram;
+ typename std::map<int,int>::iterator table_it; bool table_insert_result;
+ boost::tie(table_it, table_insert_result) = histogram.insert(std::make_pair(table_joined,0));
+ table_it->second += 1;
+ if (table_it->second == 0) histogram.erase(table_it);
+
+ // decrement the histogram bar for the table left
+ boost::tie(table_it, table_insert_result) = histogram.insert(std::make_pair(table_joined-1,0));
+ table_it->second -= 1;
+ if (table_it->second == 0) histogram.erase(table_it);
+ }
+ else {
+ typename PYP<Dish,Hash>::TableCounter &delta_tc = m_table_delta[dish];
+ delta_tc.tables += 1;
+ delta_tc.table_histogram[1] += 1;
+ }
return delta;
}
@@ -55,15 +138,177 @@ template <typename Dish, typename Hash>
int
MPIPYP<Dish,Hash>::decrement(Dish dish)
{
- int delta = PYP<Dish,Hash>::decrement(dish);
+ typename std::tr1::unordered_map<Dish, int>::iterator dcit = find(dish);
+ //typename google::sparse_hash_map<Dish, int>::iterator dcit = find(dish);
+ if (dcit == PYP<Dish,Hash>::end()) {
+ std::cerr << dish << std::endl;
+ assert(false);
+ }
+
+ int delta = 0, table_left=-1;
+
+ typename std::tr1::unordered_map<Dish, typename PYP<Dish,Hash>::TableCounter>::iterator dtit
+ = PYP<Dish,Hash>::_dish_tables.find(dish);
+ //typename google::sparse_hash_map<Dish, TableCounter>::iterator dtit = _dish_tables.find(dish);
+ if (dtit == PYP<Dish,Hash>::_dish_tables.end()) {
+ std::cerr << dish << std::endl;
+ assert(false);
+ }
+ typename PYP<Dish,Hash>::TableCounter &tc = dtit->second;
+
+ double r = mt_genrand_res53() * PYP<Dish,Hash>::count(dish);
+ for (std::map<int,int>::iterator hit = tc.table_histogram.begin();
+ hit != tc.table_histogram.end(); ++hit) {
+ r -= (hit->first * hit->second);
+ if (r <= 0) {
+ table_left = hit->first;
+ if (hit->first > 1) {
+ tc.table_histogram[hit->first-1] += 1;
+ }
+ else {
+ delta = -1;
+ tc.tables -= 1;
+ PYP<Dish,Hash>::_total_tables -= 1;
+ }
+
+ hit->second -= 1;
+ if (hit->second == 0) tc.table_histogram.erase(hit);
+ break;
+ }
+ }
+ if (r > 0) {
+ std::cerr << r << " " << PYP<Dish,Hash>::count(dish) << " " << PYP<Dish,Hash>::_a << " "
+ << PYP<Dish,Hash>::num_tables(dish) << std::endl;
+ assert(false);
+ }
+
+ // remove the customer
+ dcit->second -= 1;
+ PYP<Dish,Hash>::_total_customers -= 1;
+ assert(dcit->second >= 0);
+ if (dcit->second == 0) {
+ PYP<Dish,Hash>::erase(dcit);
+ PYP<Dish,Hash>::_dish_tables.erase(dtit);
+ }
+
+ typename dish_delta_type::iterator it;
+ bool insert_result;
+ boost::tie(it, insert_result) = m_count_delta.insert(std::make_pair(dish,0));
+
+ it->second -= 1;
+
+ if (it->second == 0)
+ m_count_delta.erase(it);
+
+ assert (table_left >= 0);
+ typename PYP<Dish,Hash>::TableCounter& delta_tc = m_table_delta[dish];
+ if (table_left > 1)
+ delta_tc.table_histogram[table_left-1] += 1;
+ else delta_tc.tables -= 1;
+
+ std::map<int,int>::iterator tit = delta_tc.table_histogram.find(table_left);
+ //assert (tit != delta_tc.table_histogram.end());
+ tit->second -= 1;
+ if (tit->second == 0) delta_tc.table_histogram.erase(tit);
+
return delta;
}
template <typename Dish, typename Hash>
void
-MPIPYP<Dish,Hash>::clear()
-{
+MPIPYP<Dish,Hash>::clear() {
PYP<Dish,Hash>::clear();
+ reset_deltas();
+}
+
+template <typename Dish, typename Hash>
+void
+MPIPYP<Dish,Hash>::reset_deltas() {
+ m_count_delta.clear();
+ m_table_delta.clear();
+}
+
+template <typename Dish>
+struct sum_maps {
+ typedef std::map<Dish,int> map_type;
+ map_type& operator() (map_type& l, map_type const & r) const {
+ for (typename map_type::const_iterator it=r.begin(); it != r.end(); it++)
+ l[it->first] += it->second;
+ return l;
+ }
+};
+
+// Needed Boost definitions
+namespace boost {
+ namespace mpi {
+ template <>
+ struct is_commutative< sum_maps<int>, std::map<int,int> > : mpl::true_ {};
+ }
+
+ namespace serialization {
+ template<class Archive>
+ void serialize(Archive & ar, PYP<int>::TableCounter& t, const unsigned int version) {
+ ar & t.table_histogram;
+ ar & t.tables;
+ }
+
+ } // namespace serialization
+} // namespace boost
+
+
+template <typename Dish, typename Hash>
+void
+MPIPYP<Dish,Hash>::synchronise() {
+ boost::mpi::communicator world;
+ int rank = world.rank(), size = world.size();
+
+ // communicate the customer count deltas
+ dish_delta_type global_dish_delta; // the “merged” map
+ boost::mpi::all_reduce(world, m_count_delta, global_dish_delta, sum_maps<Dish>());
+
+ // update this restaurant
+ for (typename dish_delta_type::const_iterator it=global_dish_delta.begin();
+ it != global_dish_delta.end(); ++it) {
+ std::tr1::unordered_map<Dish,int,Hash>::operator[](it->first) += (it->second - m_count_delta[it->first]);
+ PYP<Dish,Hash>::_total_customers += (it->second - m_count_delta[it->first]);
+ //std::cerr << "Process " << rank << " adding " << (it->second - m_count_delta[it->first]) << " customers." << std::endl;
+ }
+
+ // communicate the table count deltas
+// for (int process = 0; process < size; ++process) {
+// if (rank == process) {
+// // broadcast deltas
+// std::cerr << " -- Rank " << rank << " broadcasting -- " << std::endl;
+//
+// boost::mpi::broadcast(world, m_table_delta, process);
+//
+// std::cerr << " -- Rank " << rank << " done broadcasting -- " << std::endl;
+// }
+// else {
+// std::cerr << " -- Rank " << rank << " receiving -- " << std::endl;
+// // receive deltas
+// table_delta_type recv_table_delta;
+//
+// boost::mpi::broadcast(world, recv_table_delta, process);
+//
+// std::cerr << " -- Rank " << rank << " done receiving -- " << std::endl;
+//
+// for (typename table_delta_type::const_iterator dish_it=recv_table_delta.begin();
+// dish_it != recv_table_delta.end(); ++dish_it) {
+// typename PYP<Dish,Hash>::TableCounter &tc = PYP<Dish,Hash>::_dish_tables[dish_it->first];
+//
+// for (std::map<int,int>::const_iterator it=dish_it->second.table_histogram.begin();
+// it != dish_it->second.table_histogram.end(); ++it) {
+// tc.table_histogram[it->first] += it->second;
+// }
+// tc.tables += dish_it->second.tables;
+// PYP<Dish,Hash>::_total_tables += dish_it->second.tables;
+// }
+// }
+// }
+// std::cerr << " -- Done Reducing -- " << std::endl;
+
+ reset_deltas();
}
#endif