From aa2a66378777d3825cc31640dff2e206491ccee3 Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 25 Aug 2010 16:16:27 +0000 Subject: unit tests for PYP sampler git-svn-id: https://ws10smt.googlecode.com/svn/trunk@620 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/clda/src/crp_test.cc | 95 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 gi/clda/src/crp_test.cc (limited to 'gi/clda/src/crp_test.cc') diff --git a/gi/clda/src/crp_test.cc b/gi/clda/src/crp_test.cc new file mode 100644 index 00000000..750d80a7 --- /dev/null +++ b/gi/clda/src/crp_test.cc @@ -0,0 +1,95 @@ +#include +#include +#include + +#include + +#include "ccrp.h" +#include "sampler.h" + +const size_t MAX_DOC_LEN_CHARS = 10000000; + +using namespace std; + +class CRPTest : public testing::Test { + public: + CRPTest() {} + protected: + virtual void SetUp() { } + virtual void TearDown() { } + MT19937 rng; +}; + +TEST_F(CRPTest, Dist) { + CCRP crp(0.1, 5); + double un = 0.25; + int tt = 0; + tt += crp.increment("hi", un, &rng); + tt += crp.increment("foo", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + cout << "tt=" << tt << endl; + cout << crp << endl; + cout << " P(bar)=" << crp.prob("bar", un) << endl; + cout << " P(hi)=" << crp.prob("hi", un) << endl; + cout << " P(baz)=" << crp.prob("baz", un) << endl; + cout << " P(foo)=" << crp.prob("foo", un) << endl; + double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un); + cout << " tot=" << x << endl; + EXPECT_FLOAT_EQ(1.0, x); + tt += crp.decrement("hi", &rng); + tt += crp.decrement("bar", &rng); + cout << crp << endl; + tt += crp.decrement("bar", &rng); + cout << crp << endl; + cout << "tt=" << tt << endl; +} + +TEST_F(CRPTest, Exchangability) { + double tot = 0; + double xt = 0; + CCRP crp(0.5, 1.0); + int cust = 10; + vector hist(cust + 1, 0); + for (int i = 0; i < cust; ++i) { crp.increment(1, 1.0, &rng); } + const int samples = 100000; + const bool simulate = true; + for (int k = 0; k < samples; ++k) { + if (!simulate) { + crp.clear(); + for (int i = 0; i < cust; ++i) { crp.increment(1, 1.0, &rng); } + } else { + int da = rng.next() * cust; + bool a = rng.next() < 0.5; + if (a) { + for (int i = 0; i < da; ++i) { crp.increment(1, 1.0, &rng); } + for (int i = 0; i < da; ++i) { crp.decrement(1, &rng); } + xt += 1.0; + } else { + for (int i = 0; i < da; ++i) { crp.decrement(1, &rng); } + for (int i = 0; i < da; ++i) { crp.increment(1, 1.0, &rng); } + } + } + int c = crp.num_tables(1); + ++hist[c]; + tot += c; + } + cerr << "P(a) = " << (xt / samples) << endl; + cerr << "E[num tables] = " << (tot / samples) << endl; + double error = fabs((tot / samples) - 5.4); + cerr << " error = " << error << endl; + EXPECT_LT(error, 0.1); // it's possible for this to fail, but + // very, very unlikely + for (int i = 1; i <= cust; ++i) + cerr << i << ' ' << (hist[i]) << endl; +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} -- cgit v1.2.3