summaryrefslogtreecommitdiff
path: root/utils/mfcr_test.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-01-03 16:59:11 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-01-03 16:59:11 -0500
commita144fb07effc59a3aa269d7fd5f3d0ab9dfe5e54 (patch)
tree1f9f6fd35e6540a8bb2208f4afc6f3c38de41e36 /utils/mfcr_test.cc
parent134228d946a3f119e88f23a5315fa7849d498ee4 (diff)
multi-floor chinese restaurant described by wood&teh (2009)
Diffstat (limited to 'utils/mfcr_test.cc')
-rw-r--r--utils/mfcr_test.cc72
1 files changed, 72 insertions, 0 deletions
diff --git a/utils/mfcr_test.cc b/utils/mfcr_test.cc
new file mode 100644
index 00000000..7c45a37c
--- /dev/null
+++ b/utils/mfcr_test.cc
@@ -0,0 +1,72 @@
+#include "mfcr.h"
+
+#include <iostream>
+#include <cassert>
+#include <cmath>
+
+#include "sampler.h"
+
+using namespace std;
+
+void test_exch(MT19937* rng) {
+ MFCR<int> crp(2, 0.5, 3.0);
+ vector<double> lambdas(2);
+ vector<double> p0s(2);
+ lambdas[0] = 0.2;
+ lambdas[1] = 0.8;
+ p0s[0] = 1.0;
+ p0s[1] = 1.0;
+
+ double tot = 0;
+ double tot2 = 0;
+ double xt = 0;
+ int cust = 10;
+ vector<int> hist(cust + 1, 0), hist2(cust + 1, 0);
+ for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); }
+ const int samples = 100000;
+ const bool simulate = true;
+ for (int k = 0; k < samples; ++k) {
+ if (!simulate) {
+ crp.clear();
+ for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); }
+ } else {
+ int da = rng->next() * cust;
+ bool a = rng->next() < 0.45;
+ if (a) {
+ for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); }
+ for (int i = 0; i < da; ++i) { crp.decrement(1, rng); }
+ xt += 1.0;
+ } else {
+ for (int i = 0; i < da; ++i) { crp.decrement(1, rng); }
+ for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); }
+ }
+ }
+ int c = crp.num_tables(1);
+ ++hist[c];
+ tot += c;
+ int c2 = crp.num_tables(1,0); // tables on floor 0 with dish 1
+ ++hist2[c2];
+ tot2 += c2;
+ }
+ cerr << cust << " = " << crp.num_customers() << endl;
+ cerr << "P(a) = " << (xt / samples) << endl;
+ cerr << "E[num tables] = " << (tot / samples) << endl;
+ double error = fabs((tot / samples) - 6.894);
+ cerr << " error = " << error << endl;
+ for (int i = 1; i <= cust; ++i)
+ cerr << i << ' ' << (hist[i]) << endl;
+ cerr << "E[num tables on floor 0] = " << (tot2 / samples) << endl;
+ double error2 = fabs((tot2 / samples) - 1.379);
+ cerr << " error2 = " << error2 << endl;
+ for (int i = 1; i <= cust; ++i)
+ cerr << i << ' ' << (hist2[i]) << endl;
+ assert(error < 0.05); // these can fail with very low probability
+ assert(error2 < 0.05);
+};
+
+int main(int argc, char** argv) {
+ MT19937 rng;
+ test_exch(&rng);
+ return 0;
+}
+