diff options
Diffstat (limited to 'gi/pyp-topics/src/mpi-pyp.hh')
-rw-r--r-- | gi/pyp-topics/src/mpi-pyp.hh | 273 |
1 files changed, 259 insertions, 14 deletions
diff --git a/gi/pyp-topics/src/mpi-pyp.hh b/gi/pyp-topics/src/mpi-pyp.hh index 58be7c5c..65358d20 100644 --- a/gi/pyp-topics/src/mpi-pyp.hh +++ b/gi/pyp-topics/src/mpi-pyp.hh @@ -1,5 +1,5 @@ -#ifndef _pyp_hh -#define _pyp_hh +#ifndef _mpipyp_hh +#define _mpipyp_hh #include <math.h> #include <map> @@ -9,11 +9,15 @@ #include <boost/random/uniform_real.hpp> #include <boost/random/variate_generator.hpp> #include <boost/random/mersenne_twister.hpp> +#include <boost/tuple/tuple.hpp> +#include <boost/serialization/map.hpp> +#include <boost/mpi.hpp> +#include <boost/mpi/environment.hpp> +#include <boost/mpi/communicator.hpp> +#include <boost/mpi/operations.hpp> -#include "pyp.h" -#include "log_add.h" -#include "slice-sampler.h" -#include "mt19937ar.h" + +#include "pyp.hh" // // Pitman-Yor process with customer and table tracking @@ -28,25 +32,104 @@ public: virtual int decrement(Dish d); void clear(); + void reset_deltas(); - void reset_deltas() { m_count_delta.clear(); } + void synchronise(); private: typedef std::map<Dish, int> dish_delta_type; - typedef std::map<Dish, TableCounter> table_delta_type; + typedef std::map<Dish, typename PYP<Dish,Hash>::TableCounter> table_delta_type; dish_delta_type m_count_delta; table_delta_type m_table_delta; }; template <typename Dish, typename Hash> -MPIPYP<Dish,Hash>::MPIPYP(double a, double b, Hash) -: PYP(a, b, Hash) {} +MPIPYP<Dish,Hash>::MPIPYP(double a, double b, Hash h) +: PYP<Dish,Hash>(a, b, 0, h) {} template <typename Dish, typename Hash> int MPIPYP<Dish,Hash>::increment(Dish dish, double p0) { - int delta = PYP<Dish,Hash>::increment(dish, p0); + int delta = 0; + int table_joined=-1; + typename PYP<Dish,Hash>::TableCounter &tc = PYP<Dish,Hash>::_dish_tables[dish]; + + // seated on a new or existing table? + int c = PYP<Dish,Hash>::count(dish); + int t = PYP<Dish,Hash>::num_tables(dish); + int T = PYP<Dish,Hash>::num_tables(); + double& a = PYP<Dish,Hash>::_a; + double& b = PYP<Dish,Hash>::_b; + double pshare = (c > 0) ? (c - a*t) : 0.0; + double pnew = (b + a*T) * p0; + assert (pshare >= 0.0); + + if (mt_genrand_res53() < pnew / (pshare + pnew)) { + // assign to a new table + tc.tables += 1; + tc.table_histogram[1] += 1; + PYP<Dish,Hash>::_total_tables += 1; + delta = 1; + } + else { + // randomly assign to an existing table + // remove constant denominator from inner loop + double r = mt_genrand_res53() * (c - a*t); + for (std::map<int,int>::iterator + hit = tc.table_histogram.begin(); + hit != tc.table_histogram.end(); ++hit) { + r -= ((hit->first - a) * hit->second); + if (r <= 0) { + tc.table_histogram[hit->first+1] += 1; + hit->second -= 1; + if (hit->second == 0) + tc.table_histogram.erase(hit); + table_joined = hit->first+1; + break; + } + } + if (r > 0) { + std::cerr << r << " " << c << " " << a << " " << t << std::endl; + assert(false); + } + delta = 0; + } + + std::tr1::unordered_map<Dish,int,Hash>::operator[](dish) += 1; + //google::sparse_hash_map<Dish,int,Hash>::operator[](dish) += 1; + PYP<Dish,Hash>::_total_customers += 1; + + // MPI Delta handling + // track the customer entering + typename dish_delta_type::iterator customer_it; + bool customer_insert_result; + boost::tie(customer_it, customer_insert_result) + = m_count_delta.insert(std::make_pair(dish,0)); + + customer_it->second += 1; + if (customer_it->second == 0) + m_count_delta.erase(customer_it); + + // increment the histogram bar for the table joined + if (!delta) { + assert (table_joined >= 0); + std::map<int,int> &histogram = m_table_delta[dish].table_histogram; + typename std::map<int,int>::iterator table_it; bool table_insert_result; + boost::tie(table_it, table_insert_result) = histogram.insert(std::make_pair(table_joined,0)); + table_it->second += 1; + if (table_it->second == 0) histogram.erase(table_it); + + // decrement the histogram bar for the table left + boost::tie(table_it, table_insert_result) = histogram.insert(std::make_pair(table_joined-1,0)); + table_it->second -= 1; + if (table_it->second == 0) histogram.erase(table_it); + } + else { + typename PYP<Dish,Hash>::TableCounter &delta_tc = m_table_delta[dish]; + delta_tc.tables += 1; + delta_tc.table_histogram[1] += 1; + } return delta; } @@ -55,15 +138,177 @@ template <typename Dish, typename Hash> int MPIPYP<Dish,Hash>::decrement(Dish dish) { - int delta = PYP<Dish,Hash>::decrement(dish); + typename std::tr1::unordered_map<Dish, int>::iterator dcit = find(dish); + //typename google::sparse_hash_map<Dish, int>::iterator dcit = find(dish); + if (dcit == PYP<Dish,Hash>::end()) { + std::cerr << dish << std::endl; + assert(false); + } + + int delta = 0, table_left=-1; + + typename std::tr1::unordered_map<Dish, typename PYP<Dish,Hash>::TableCounter>::iterator dtit + = PYP<Dish,Hash>::_dish_tables.find(dish); + //typename google::sparse_hash_map<Dish, TableCounter>::iterator dtit = _dish_tables.find(dish); + if (dtit == PYP<Dish,Hash>::_dish_tables.end()) { + std::cerr << dish << std::endl; + assert(false); + } + typename PYP<Dish,Hash>::TableCounter &tc = dtit->second; + + double r = mt_genrand_res53() * PYP<Dish,Hash>::count(dish); + for (std::map<int,int>::iterator hit = tc.table_histogram.begin(); + hit != tc.table_histogram.end(); ++hit) { + r -= (hit->first * hit->second); + if (r <= 0) { + table_left = hit->first; + if (hit->first > 1) { + tc.table_histogram[hit->first-1] += 1; + } + else { + delta = -1; + tc.tables -= 1; + PYP<Dish,Hash>::_total_tables -= 1; + } + + hit->second -= 1; + if (hit->second == 0) tc.table_histogram.erase(hit); + break; + } + } + if (r > 0) { + std::cerr << r << " " << PYP<Dish,Hash>::count(dish) << " " << PYP<Dish,Hash>::_a << " " + << PYP<Dish,Hash>::num_tables(dish) << std::endl; + assert(false); + } + + // remove the customer + dcit->second -= 1; + PYP<Dish,Hash>::_total_customers -= 1; + assert(dcit->second >= 0); + if (dcit->second == 0) { + PYP<Dish,Hash>::erase(dcit); + PYP<Dish,Hash>::_dish_tables.erase(dtit); + } + + typename dish_delta_type::iterator it; + bool insert_result; + boost::tie(it, insert_result) = m_count_delta.insert(std::make_pair(dish,0)); + + it->second -= 1; + + if (it->second == 0) + m_count_delta.erase(it); + + assert (table_left >= 0); + typename PYP<Dish,Hash>::TableCounter& delta_tc = m_table_delta[dish]; + if (table_left > 1) + delta_tc.table_histogram[table_left-1] += 1; + else delta_tc.tables -= 1; + + std::map<int,int>::iterator tit = delta_tc.table_histogram.find(table_left); + //assert (tit != delta_tc.table_histogram.end()); + tit->second -= 1; + if (tit->second == 0) delta_tc.table_histogram.erase(tit); + return delta; } template <typename Dish, typename Hash> void -MPIPYP<Dish,Hash>::clear() -{ +MPIPYP<Dish,Hash>::clear() { PYP<Dish,Hash>::clear(); + reset_deltas(); +} + +template <typename Dish, typename Hash> +void +MPIPYP<Dish,Hash>::reset_deltas() { + m_count_delta.clear(); + m_table_delta.clear(); +} + +template <typename Dish> +struct sum_maps { + typedef std::map<Dish,int> map_type; + map_type& operator() (map_type& l, map_type const & r) const { + for (typename map_type::const_iterator it=r.begin(); it != r.end(); it++) + l[it->first] += it->second; + return l; + } +}; + +// Needed Boost definitions +namespace boost { + namespace mpi { + template <> + struct is_commutative< sum_maps<int>, std::map<int,int> > : mpl::true_ {}; + } + + namespace serialization { + template<class Archive> + void serialize(Archive & ar, PYP<int>::TableCounter& t, const unsigned int version) { + ar & t.table_histogram; + ar & t.tables; + } + + } // namespace serialization +} // namespace boost + + +template <typename Dish, typename Hash> +void +MPIPYP<Dish,Hash>::synchronise() { + boost::mpi::communicator world; + int rank = world.rank(), size = world.size(); + + // communicate the customer count deltas + dish_delta_type global_dish_delta; // the “merged” map + boost::mpi::all_reduce(world, m_count_delta, global_dish_delta, sum_maps<Dish>()); + + // update this restaurant + for (typename dish_delta_type::const_iterator it=global_dish_delta.begin(); + it != global_dish_delta.end(); ++it) { + std::tr1::unordered_map<Dish,int,Hash>::operator[](it->first) += (it->second - m_count_delta[it->first]); + PYP<Dish,Hash>::_total_customers += (it->second - m_count_delta[it->first]); + //std::cerr << "Process " << rank << " adding " << (it->second - m_count_delta[it->first]) << " customers." << std::endl; + } + + // communicate the table count deltas +// for (int process = 0; process < size; ++process) { +// if (rank == process) { +// // broadcast deltas +// std::cerr << " -- Rank " << rank << " broadcasting -- " << std::endl; +// +// boost::mpi::broadcast(world, m_table_delta, process); +// +// std::cerr << " -- Rank " << rank << " done broadcasting -- " << std::endl; +// } +// else { +// std::cerr << " -- Rank " << rank << " receiving -- " << std::endl; +// // receive deltas +// table_delta_type recv_table_delta; +// +// boost::mpi::broadcast(world, recv_table_delta, process); +// +// std::cerr << " -- Rank " << rank << " done receiving -- " << std::endl; +// +// for (typename table_delta_type::const_iterator dish_it=recv_table_delta.begin(); +// dish_it != recv_table_delta.end(); ++dish_it) { +// typename PYP<Dish,Hash>::TableCounter &tc = PYP<Dish,Hash>::_dish_tables[dish_it->first]; +// +// for (std::map<int,int>::const_iterator it=dish_it->second.table_histogram.begin(); +// it != dish_it->second.table_histogram.end(); ++it) { +// tc.table_histogram[it->first] += it->second; +// } +// tc.tables += dish_it->second.tables; +// PYP<Dish,Hash>::_total_tables += dish_it->second.tables; +// } +// } +// } +// std::cerr << " -- Done Reducing -- " << std::endl; + + reset_deltas(); } #endif |