// path: root/gi/pf/pf_test.cc
// blob: 296e72852f1207bd729af64a1495ff05925c328d (plain)
#include "ccrp.h"

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

#include "tdict.h"
#include "transliterations.h"

using namespace std;

// Process-wide random number generator. Its address is handed to every
// CCRP increment/decrement call in Model, and it is queried directly via
// rng.next() in main and rng.SelectSample() in Model::increment_is.
MT19937 rng;

// When true, Model::increment_is traces each particle's weight to stderr
// (and terminates the process after its tenth call). Set from main when
// any command-line argument is supplied.
static bool verbose = false;

struct Model {

  Model() : bp(), base(0.2, 0.6) , ccrps(5, CCRP<int>(0.8, 0.5)) {}

  double p0(int x) const {
    assert(x > 0);
    assert(x < 5);
    return 1.0/4.0;
  }

  double llh() const {
    double lh = bp + base.log_crp_prob();
    for (int ctx = 1; ctx < 5; ++ctx)
      lh += ccrps[ctx].log_crp_prob();
    return lh;
  }

  double prob(int ctx, int x) const {
    assert(ctx > 0 && ctx < 5);
    return ccrps[ctx].prob(x, base.prob(x, p0(x)));
  }

  void increment(int ctx, int x) {
    assert(ctx > 0 && ctx < 5);
    if (ccrps[ctx].increment(x, base.prob(x, p0(x)), &rng)) {
      if (base.increment(x, p0(x), &rng)) {
        bp += log(1.0 / 4.0);
      }
    }
  }

  // this is just a biased estimate
  double est_base_prob(int x) {
    return (x + 1) * x / 40.0;
  }

  void increment_is(int ctx, int x) {
    assert(ctx > 0 && ctx < 5);
    SampleSet<double> ss;
    const int PARTICLES = 25;
    vector<CCRP<int> > s1s(PARTICLES, CCRP<int>(0.5,0.5));
    vector<CCRP<int> > sbs(PARTICLES, CCRP<int>(0.5,0.5));
    vector<double> sp0s(PARTICLES);

    CCRP<int> s1 = ccrps[ctx];
    CCRP<int> sb = base;
    double sp0 = bp;
    for (int pp = 0; pp < PARTICLES; ++pp) {
      if (pp > 0) {
        ccrps[ctx] = s1;
        base = sb;
        bp = sp0;
      }

      double q = 1;
      double gamma = 1;
      double est_p = est_base_prob(x);
      //base.prob(x, p0(x)) + rng.next() * 0.1;
      if (ccrps[ctx].increment(x, est_p, &rng, &q)) {
        gamma = q * base.prob(x, p0(x));
        q *= est_p;
        if (verbose) cerr << "(DP-base draw) ";
        double qq = -1;
        if (base.increment(x, p0(x), &rng, &qq)) {
          if (verbose) cerr << "(G0 draw) ";
          bp += log(p0(x));
          qq *= p0(x);
        }
      } else { gamma = q; }
      double w = gamma / q;
      if (verbose)
        cerr << "gamma=" << gamma << " q=" << q << "\tw=" << w << endl;
      ss.add(w);
      s1s[pp] = ccrps[ctx];
      sbs[pp] = base;
      sp0s[pp] = bp;
    }
    int ps = rng.SelectSample(ss);
    ccrps[ctx] = s1s[ps];
    base = sbs[ps];
    bp = sp0s[ps];
    if (verbose) {
      cerr << "SELECTED: " << ps << endl;
      static int cc = 0; cc++; if (cc ==10) exit(1);
    }
  }

  void decrement(int ctx, int x) {
    assert(ctx > 0 && ctx < 5);
    if (ccrps[ctx].decrement(x, &rng)) {
      if (base.decrement(x, &rng)) {
        bp -= log(p0(x));
      }
    }
  }

  double bp;
  CCRP<int> base;
  vector<CCRP<int> > ccrps;

};

int main(int argc, char** argv) {
  if (argc > 1) { verbose = true; }
  vector<int> counts(15, 0);
  vector<int> tcounts(15, 0);
  int points[] = {1,2, 2,2, 3,2, 4,1, 3, 4, 3, 3, 2, 3, 4, 1, 4, 1, 3, 2, 1, 3, 1, 4, 0, 0};
  double tlh = 0;
  double tt = 0;
  for (int n = 0; n < 1000; ++n) {
    if (n % 10 == 0) cerr << '.';
    if ((n+1) % 400 == 0) cerr << " [" << (n+1) << "]\n";
    Model m;
    for (int *x = points; *x; x += 2)
      m.increment(x[0], x[1]);

    for (int j = 0; j < 24; ++j) {
      for (int *x = points; *x; x += 2) {
        if (rng.next() < 0.8) {
          m.decrement(x[0], x[1]);
          m.increment_is(x[0], x[1]);
        }
      }
    }
    counts[m.base.num_customers()]++;
    tcounts[m.base.num_tables()]++;
    tlh += m.llh();
    tt += 1.0;
  }
  cerr << "mean LLH = " << (tlh / tt) << endl;
  for (int i = 0; i < 15; ++i)
    cerr << i << ": " << (counts[i] / tt) << "\t" << (tcounts[i] / tt) << endl;
}