summaryrefslogtreecommitdiff
path: root/utils/crp_table_manager.h
blob: 753e721f39120705e4724934fcb9f5e0e0ebb034 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#ifndef _CRP_TABLE_MANAGER_H_
#define _CRP_TABLE_MANAGER_H_

#include <cassert>
#include <iostream>
#include "sparse_vector.h"
#include "sampler.h"

// these are helper classes for implementing token-based CRP samplers
// basically the data structures recommended by Blunsom et al. in the Note.

// Histogram of unsigned counts keyed by an unsigned bin id, stored in a
// SparseVector. A bin whose count drops to zero is erased from the storage.
struct CRPHistogram {
  //typedef std::map<unsigned, unsigned> MAPTYPE;
  typedef SparseVector<unsigned> MAPTYPE;
  typedef MAPTYPE::const_iterator const_iterator;

  // Adds `delta` (default 1) to the count stored for `bin`.
  void increment(unsigned bin, unsigned delta = 1u) {
    data[bin] += delta;
  }

  // Subtracts `delta` (default 1) from the count stored for `bin` and erases
  // the bin when the resulting count is zero.
  // No underflow check is performed: the caller must not subtract more than
  // the bin currently holds.
  void decrement(unsigned bin, unsigned delta = 1u) {
    const unsigned remaining = (data[bin] -= delta);
    if (remaining == 0u) data.erase(bin);
  }

  // Transfers `delta` counts from `from_bin` to `to_bin`: the source bin is
  // decremented first, then the destination bin is incremented.
  void move(unsigned from_bin, unsigned to_bin, unsigned delta = 1u) {
    decrement(from_bin, delta);
    increment(to_bin, delta);
  }

  // Read-only iteration over the stored (bin, count) entries.
  const_iterator begin() const { return data.begin(); }
  const_iterator end() const { return data.end(); }

 private:
  MAPTYPE data;  // bin id -> count
};

// A CRPTableManager tracks statistics about all customers
// and tables serving some dish in a CRP and can correctly sample what
// table to remove a customer from and what table to join.
//
// The seating arrangement is kept as a histogram `h`: the bin id is a number
// of customers at a table, and the bin's count is the number of tables that
// currently seat exactly that many customers.
struct CRPTableManager {
  CRPTableManager() : customers(), tables() {}

  // Number of tables currently tracked.
  inline unsigned num_tables() const {
    return tables;
  }

  // Number of customers currently seated across all tables.
  inline unsigned num_customers() const {
    return customers;
  }

  // Opens a new table that seats exactly one customer.
  inline void create_table() {
    h.increment(1);
    ++tables;
    ++customers;
  }

  // seat a customer at a table proportional to the number of customers seated at a table, less the discount
  // *new tables are never created by this function!
  //
  // Precondition: at least one table exists (checked with assert). If the
  // histogram is nevertheless empty in a build without assertions, the call
  // leaves the manager unchanged.
  // `rng->next()` is presumably uniform on [0,1) -- confirm against sampler.h.
  inline void share_table(const double discount, MT19937* rng) {
    assert(tables > 0 && "share_table requires at least one existing table");
    const double z = customers - discount * num_tables();
    double r = z * rng->next();
    const CRPHistogram::const_iterator end = h.end();
    unsigned chosen = 0;   // table size (bin) selected so far
    bool found = false;    // true once at least one bin has been visited
    for (CRPHistogram::const_iterator it = h.begin(); it != end; ++it) {
      // it->first = number of customers at table
      // it->second = number of such tables
      //
      // Remember every visited bin: the thresholds are subtracted one by one
      // in floating point, so their sum need not equal z exactly and the scan
      // can run past the final bin. In that case the last bin is used instead
      // of dereferencing the end iterator.
      chosen = it->first;
      found = true;
      const double thresh = (it->first - discount) * it->second;
      if (thresh > r) break;
      r -= thresh;
    }
    if (!found) return;  // no tables to share
    h.move(chosen, chosen + 1);
    ++customers;
  }

  // randomly sample a customer
  // *tables may be removed
  // returns -1 if a table is removed, 0 otherwise
  //
  // Precondition: at least one customer is seated (checked with assert). If
  // the histogram is nevertheless empty in a build without assertions, the
  // call leaves the manager unchanged and returns 0.
  inline int remove_customer(MT19937* rng) {
    assert(customers > 0 && "remove_customer requires at least one customer");
    int r = static_cast<int>(rng->next() * num_customers());
    const CRPHistogram::const_iterator end = h.end();
    unsigned tc = 0;     // size of the table the sampled customer sits at
    bool found = false;  // true once at least one bin has been visited
    for (CRPHistogram::const_iterator it = h.begin(); it != end; ++it) {
      // Track the visited bin so the end iterator is never dereferenced,
      // even if r falls outside the total covered by the bins.
      tc = it->first;
      found = true;
      // total customers seated at tables of this size
      int thresh = it->first * it->second;
      if (thresh > r) break;
      r -= thresh;
    }
    if (!found) return 0;  // nothing to remove
    --customers;
    if (tc == 1) {
      // the customer was alone, so the table disappears
      h.decrement(1);
      --tables;
      return -1;
    } else {
      // the table shrinks by one customer
      h.move(tc, tc - 1);
      return 0;
    }
  }

  // Read-only iteration over the (table size, number of tables) histogram.
  typedef CRPHistogram::const_iterator const_iterator;
  const_iterator begin() const { return h.begin(); }
  const_iterator end() const { return h.end(); }

  unsigned customers;  // total customers seated
  unsigned tables;     // total tables
  CRPHistogram h;      // table size -> number of tables of that size
};

std::ostream& operator<<(std::ostream& os, const CRPTableManager& tm) {
  os << '[' << tm.num_customers() << " total customers at " << tm.num_tables() << " tables ||| ";
  for (CRPHistogram::const_iterator it = tm.begin(); it != tm.end(); ++it) {
    if (it != tm.h.begin()) os << "  --  ";
    os << '(' << it->first << ") x " << it->second;
  }
  return os << ']';
}

#endif