diff options
Diffstat (limited to 'gi/clda')
-rw-r--r-- | gi/clda/src/Makefile.am | 11 | ||||
-rw-r--r-- | gi/clda/src/ccrp.h | 21 | ||||
-rw-r--r-- | gi/clda/src/clda.cc | 38 | ||||
-rw-r--r-- | gi/clda/src/crp_test.cc | 95 |
4 files changed, 126 insertions, 39 deletions
diff --git a/gi/clda/src/Makefile.am b/gi/clda/src/Makefile.am index 2b1393ac..6a76fc93 100644 --- a/gi/clda/src/Makefile.am +++ b/gi/clda/src/Makefile.am @@ -1,3 +1,14 @@ +if HAVE_GTEST +noinst_PROGRAMS = \ + crp_test + +TESTS = crp_test + +crp_test_SOURCES = crp_test.cc +crp_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) + +endif + bin_PROGRAMS = clda clda_SOURCES = clda.cc diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h index 47f364f2..9b1c5284 100644 --- a/gi/clda/src/ccrp.h +++ b/gi/clda/src/ccrp.h @@ -18,12 +18,19 @@ class CCRP { public: CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {} + void clear() { + num_tables_ = 0; + num_customers_ = 0; + dish_locs_.clear(); + } + unsigned num_tables(const Dish& dish) const { const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); if (it == dish_locs_.end()) return 0; return it->second.table_counts_.size(); } + // returns +1 or 0 indicating whether a new table was opened int increment(const Dish& dish, const double& p0, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; bool share_table = false; @@ -56,6 +63,7 @@ class CCRP { return (share_table ? 0 : 1); } + // returns -1 or 0, indicating whether a table was closed int decrement(const Dish& dish, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; assert(loc.total_dish_count_); @@ -66,11 +74,13 @@ class CCRP { return -1; } else { int delta = 0; - double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); + // sample customer to remove UNIFORMLY. that is, do NOT use the discount + // here. if you do, it will introduce (unwanted) bias! + double r = rng->next() * loc.total_dish_count_; --loc.total_dish_count_; for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); ti != loc.table_counts_.end(); ++ti) { - r -= (*ti - discount_); + r -= *ti; if (r <= 0.0) { if ((--(*ti)) == 0) { --num_tables_; @@ -107,7 +117,9 @@ class CCRP { double lp = 0.0; if (num_customers_) { const double r = lgamma(1.0 - discount_); - lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) - lgamma(concentration_ / discount_); + lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) + - lgamma(concentration_ / discount_); for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { const DishLocations& cur = it->second; @@ -119,8 +131,9 @@ class CCRP { struct DishLocations { DishLocations() : total_dish_count_() {} - unsigned total_dish_count_; + unsigned total_dish_count_; // customers at all tables with this dish std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want + // .size() is the number of tables for this dish }; void Print(std::ostream* out) const { diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc index 0232331b..10056bc9 100644 --- a/gi/clda/src/clda.cc +++ b/gi/clda/src/clda.cc @@ -25,39 +25,7 @@ void ShowTopWordsForTopic(const map<WordID, int>& counts) { cerr << endl; } -void tc() { - MT19937 rng; - CCRP<string> crp(0.1, 5); - double un = 0.25; - int tt = 0; - tt += crp.increment("hi", un, &rng); - tt += crp.increment("foo", un, &rng); - tt += crp.increment("bar", un, &rng); - tt += crp.increment("bar", un, &rng); - tt += crp.increment("bar", un, &rng); - tt += crp.increment("bar", un, &rng); - tt += crp.increment("bar", un, &rng); - tt += crp.increment("bar", un, &rng); - tt += crp.increment("bar", un, &rng); - cout << "tt=" << tt << endl; - cout << crp << endl; - cout << " P(bar)=" << crp.prob("bar", un) << endl; - cout << " P(hi)=" << crp.prob("hi", un) << endl; - cout << " P(baz)=" << crp.prob("baz", un) << endl; - cout << " P(foo)=" << crp.prob("foo", un) << endl; - double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un); - cout << " tot=" << x << endl; - tt += crp.decrement("hi", &rng); - tt += crp.decrement("bar", &rng); - cout << crp << endl; - tt += crp.decrement("bar", &rng); - cout << crp << endl; - cout << "tt=" << tt << endl; - cout << crp.log_crp_prob() << endl; -} - int main(int argc, char** argv) { - tc(); if (argc != 3) { cerr << "Usage: " << argv[0] << " num-classes num-samples\n"; return 1; @@ -88,11 +56,11 @@ int main(int argc, char** argv) { MT19937 rng; cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n"; zji.resize(wji.size()); - double disc = 0.05; + double disc = 0.1; double beta = 10.0; double alpha = 50.0; - double uniform_topic = 1.0 / num_classes; - double uniform_word = 1.0 / TD::NumWords(); + const double uniform_topic = 1.0 / num_classes; + const double uniform_word = 1.0 / TD::NumWords(); vector<CCRP<int> > dr(zji.size(), CCRP<int>(disc, beta)); // dr[i] describes the probability of using a topic in document i vector<CCRP<int> > wr(num_classes, CCRP<int>(disc, alpha)); // wr[k] describes the probability of generating a word in topic k for (int j = 0; j < zji.size(); ++j) { diff --git a/gi/clda/src/crp_test.cc b/gi/clda/src/crp_test.cc new file mode 100644 index 00000000..750d80a7 --- /dev/null +++ b/gi/clda/src/crp_test.cc @@ -0,0 +1,95 @@ +#include <iostream> +#include <vector> +#include <string> + +#include <gtest/gtest.h> + +#include "ccrp.h" +#include "sampler.h" + +const size_t MAX_DOC_LEN_CHARS = 10000000; + +using namespace std; + +class CRPTest : public testing::Test { + public: + CRPTest() {} + protected: + virtual void SetUp() { } + virtual void TearDown() { } + MT19937 rng; +}; + +TEST_F(CRPTest, Dist) { + CCRP<string> crp(0.1, 5); + double un = 0.25; + int tt = 0; + tt += crp.increment("hi", un, &rng); + tt += crp.increment("foo", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + cout << "tt=" << tt << endl; + cout << crp << endl; + cout << " P(bar)=" << crp.prob("bar", un) << endl; + cout << " P(hi)=" << crp.prob("hi", un) << endl; + cout << " P(baz)=" << crp.prob("baz", un) << endl; + cout << " P(foo)=" << crp.prob("foo", un) << endl; + double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un); + cout << " tot=" << x << endl; + EXPECT_FLOAT_EQ(1.0, x); + tt += crp.decrement("hi", &rng); + tt += crp.decrement("bar", &rng); + cout << crp << endl; + tt += crp.decrement("bar", &rng); + cout << crp << endl; + cout << "tt=" << tt << endl; +} + +TEST_F(CRPTest, Exchangability) { + double tot = 0; + double xt = 0; + CCRP<int> crp(0.5, 1.0); + int cust = 10; + vector<int> hist(cust + 1, 0); + for (int i = 0; i < cust; ++i) { crp.increment(1, 1.0, &rng); } + const int samples = 100000; + const bool simulate = true; + for (int k = 0; k < samples; ++k) { + if (!simulate) { + crp.clear(); + for (int i = 0; i < cust; ++i) { crp.increment(1, 1.0, &rng); } + } else { + int da = rng.next() * cust; + bool a = rng.next() < 0.5; + if (a) { + for (int i = 0; i < da; ++i) { crp.increment(1, 1.0, &rng); } + for (int i = 0; i < da; ++i) { crp.decrement(1, &rng); } + xt += 1.0; + } else { + for (int i = 0; i < da; ++i) { crp.decrement(1, &rng); } + for (int i = 0; i < da; ++i) { crp.increment(1, 1.0, &rng); } + } + } + int c = crp.num_tables(1); + ++hist[c]; + tot += c; + } + cerr << "P(a) = " << (xt / samples) << endl; + cerr << "E[num tables] = " << (tot / samples) << endl; + double error = fabs((tot / samples) - 5.4); + cerr << " error = " << error << endl; + EXPECT_LT(error, 0.1); // it's possible for this to fail, but + // very, very unlikely + for (int i = 1; i <= cust; ++i) + cerr << i << ' ' << (hist[i]) << endl; +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} |