// path: root/gi/pyp-topics/src/mpi-pyp.hh
// blob: 58be7c5c77d9115b5258c45d2c798e693a0ed12f (plain)
// Include guard specific to this header (mpi-pyp.hh). The previous guard name
// was the one a plain PYP header would use, so including both headers in one
// translation unit could silently drop one of them.
#ifndef MPI_PYP_HH_
#define MPI_PYP_HH_

#include <math.h>
#include <map>
#include <tr1/unordered_map>
//#include <google/sparse_hash_map>

#include <boost/random/uniform_real.hpp>
#include <boost/random/variate_generator.hpp>
#include <boost/random/mersenne_twister.hpp>

#include "pyp.h"
#include "log_add.h"
#include "slice-sampler.h"
#include "mt19937ar.h"

//
// Pitman-Yor process with customer and table tracking
//

// MPIPYP extends PYP<Dish, Hash> with two per-dish "delta" maps
// (m_count_delta and m_table_delta). In the code visible in this header the
// maps are only ever cleared; increment/decrement/clear simply forward to the
// PYP base class.
// NOTE(review): the deltas are presumably meant to record local changes to be
// exchanged between MPI processes -- confirm against the callers.
template <typename Dish, typename Hash=std::tr1::hash<Dish> >
class MPIPYP : public PYP<Dish, Hash> {
public:
  // a, b and hash are handed straight to the PYP base-class constructor.
  // NOTE(review): a/b look like the PYP discount/strength parameters --
  // confirm against PYP's declaration.
  MPIPYP(double a, double b, Hash hash=Hash());

  // Forward to PYP<Dish,Hash>::increment / decrement and return its result.
  virtual int increment(Dish d, double p0);
  virtual int decrement(Dish d);

  // Forwards to PYP<Dish,Hash>::clear().
  void clear();

  // Discards all recorded deltas. Both maps are cleared so that the dish
  // counts and the table counts cannot drift out of step with each other.
  void reset_deltas() {
    m_count_delta.clear();
    m_table_delta.clear();
  }

private:
  typedef std::map<Dish, int> dish_delta_type;
  // NOTE(review): TableCounter is not declared in this header. If it is a
  // nested type of PYP<Dish,Hash> it must be spelled
  // `typename PYP<Dish,Hash>::TableCounter` here -- confirm.
  typedef std::map<Dish, TableCounter> table_delta_type;

  dish_delta_type m_count_delta;   // per-dish count deltas
  table_delta_type m_table_delta;  // per-dish table deltas
};

// Constructs an MPIPYP by forwarding a, b and the hash functor to the
// PYP<Dish,Hash> base-class constructor. The delta maps start out empty.
//
// The hash parameter is named so that the *object* (not the type name) is
// passed on, and the base class is spelled with its template arguments because
// a bare template name is not a valid mem-initializer for a dependent base.
template <typename Dish, typename Hash>
MPIPYP<Dish,Hash>::MPIPYP(double a, double b, Hash hash)
: PYP<Dish,Hash>(a, b, hash) {}

// Delegates to the PYP base-class increment for `dish` with base probability
// `p0` and hands its return value back to the caller unchanged.
template <typename Dish, typename Hash>
int
MPIPYP<Dish,Hash>::increment(Dish dish, double p0) {
  typedef PYP<Dish,Hash> Base;
  // Qualified call: always runs the base implementation, never re-dispatches
  // virtually back into this override.
  return Base::increment(dish, p0);
}

// Delegates to the PYP base-class decrement for `dish` and returns whatever
// the base implementation reports.
template <typename Dish, typename Hash>
int
MPIPYP<Dish,Hash>::decrement(Dish dish)
{
  // Explicitly qualified so the base version is invoked directly.
  return this->PYP<Dish,Hash>::decrement(dish);
}

// Clears the underlying PYP state by calling the base-class clear().
// NOTE(review): m_count_delta / m_table_delta are not touched here -- confirm
// whether callers are expected to invoke reset_deltas() separately.
template <typename Dish, typename Hash>
void
MPIPYP<Dish,Hash>::clear()
{
  typedef PYP<Dish,Hash> Base;
  Base::clear();
}

#endif