// Source: root/utils/mfcr_test.cc (blob 29a1a2ce865f22d6aefa8a1c3f6410cfff2bb294)
#include "mfcr.h"

#include <iostream>
#include <cassert>
#include <cmath>

#define BOOST_TEST_MODULE MFCRTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>

#include "sampler.h"

using namespace std;

// Exchangeability check for MFCR seating. Starting from `cust` customers all
// with dish 1, each iteration draws `da` uniformly from [0, cust) and then
// either (with probability 0.45) adds `da` customers and removes `da`, or
// removes `da` and re-adds `da`. The customer count is therefore back at
// `cust` whenever the table counts are sampled. The long-run mean table
// counts are compared against reference values.
//
// NOTE(review): the test name keeps its original spelling ("Exchangability")
// so existing --run_test filters keep matching.
BOOST_AUTO_TEST_CASE(Exchangability) {
  MT19937 r;
  MT19937* rng = &r;
  // NOTE(review): template argument 2 appears to be the number of floors
  // (it matches the two-element lambdas/p0s vectors below) -- confirm in mfcr.h.
  MFCR<2, int> crp(0.5, 3.0);
  // Per-floor arguments passed to every increment() call.
  vector<double> lambdas(2);
  vector<double> p0s(2);
  lambdas[0] = 0.2;
  lambdas[1] = 0.8;
  p0s[0] = 1.0;
  p0s[1] = 1.0;

  const int cust = 10;
  const int samples = 100000;
  // true: perturb the existing seating each iteration;
  // false: clear and reseat all `cust` customers from scratch.
  const bool simulate = true;

  // Reference expectations for the sample means. Presumably exact values for
  // `cust` customers under the parameters above -- TODO confirm their origin.
  const double kExpectedTables = 6.894;
  const double kExpectedFloor0Tables = 1.379;
  const double kTolerance = 0.05;

  double tot = 0;   // running sum of num_tables(1)
  double tot2 = 0;  // running sum of num_tables(1, 0)
  double xt = 0;    // iterations that took the add-then-remove branch
  vector<int> hist(cust + 1, 0), hist2(cust + 1, 0);
  for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }
  for (int k = 0; k < samples; ++k) {
    if (!simulate) {
      crp.clear();
      for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }
    } else {
      const int da = static_cast<int>(rng->next() * cust);
      const bool a = rng->next() < 0.45;
      if (a) {
        for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }
        for (int i = 0; i < da; ++i) { crp.decrement(1, rng); }
        xt += 1.0;
      } else {
        for (int i = 0; i < da; ++i) { crp.decrement(1, rng); }
        for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }
      }
    }
    // at() instead of operator[]: a table count outside [0, cust] throws
    // std::out_of_range, which Boost.Test reports, instead of writing out of bounds.
    const int c = crp.num_tables(1);
    ++hist.at(c);
    tot += c;
    const int c2 = crp.num_tables(1, 0);  // tables on floor 0 with dish 1
    ++hist2.at(c2);
    tot2 += c2;
  }

  cerr << cust << " = " << crp.num_customers() << '\n';
  cerr << "P(a) = " << (xt / samples) << '\n';

  const double mean = tot / samples;
  cerr << "E[num tables] = " << mean << '\n';
  const double error = fabs(mean - kExpectedTables);
  cerr << "   error = " << error << '\n';
  for (int i = 1; i <= cust; ++i)
    cerr << i << ' ' << hist[i] << '\n';

  const double mean2 = tot2 / samples;
  cerr << "E[num tables on floor 0] = " << mean2 << '\n';
  const double error2 = fabs(mean2 - kExpectedFloor0Tables);
  cerr << "  error2 = " << error2 << '\n';
  // Start at 0: num_tables(1, 0) may be 0, and that bucket was previously hidden.
  for (int i = 0; i <= cust; ++i)
    cerr << i << ' ' << hist2[i] << '\n';

  // Statistical checks; these can fail with very low probability.
  // BOOST_CHECK_LT rather than assert(): assert() is compiled out under NDEBUG
  // (silently passing the test) and aborts instead of reporting a failure.
  BOOST_CHECK_LT(error, kTolerance);
  BOOST_CHECK_LT(error2, kTolerance);
}