summary | refs | log | tree | commit | diff
path: root/gi/clda/src
diff options
context:
space:
mode:
Diffstat (limited to 'gi/clda/src')
-rw-r--r-- gi/clda/src/Makefile.am 11
-rw-r--r-- gi/clda/src/ccrp.h 21
-rw-r--r-- gi/clda/src/clda.cc 38
-rw-r--r-- gi/clda/src/crp_test.cc 95
4 files changed, 126 insertions, 39 deletions
diff --git a/gi/clda/src/Makefile.am b/gi/clda/src/Makefile.am
index 2b1393ac..6a76fc93 100644
--- a/gi/clda/src/Makefile.am
+++ b/gi/clda/src/Makefile.am
@@ -1,3 +1,14 @@
+if HAVE_GTEST
+noinst_PROGRAMS = \
+ crp_test
+
+TESTS = crp_test
+
+crp_test_SOURCES = crp_test.cc
+crp_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS)
+
+endif
+
bin_PROGRAMS = clda
clda_SOURCES = clda.cc
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h
index 47f364f2..9b1c5284 100644
--- a/gi/clda/src/ccrp.h
+++ b/gi/clda/src/ccrp.h
@@ -18,12 +18,19 @@ class CCRP {
public:
CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {}
+  // Returns the restaurant to its freshly-constructed state: the per-dish
+  // seating map is emptied and both aggregate counters go back to zero.
+  // The discount and concentration parameters are not touched.
+  void clear() {
+    dish_locs_.clear();
+    num_customers_ = 0;
+    num_tables_ = 0;
+  }
+
unsigned num_tables(const Dish& dish) const {
const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
if (it == dish_locs_.end()) return 0;
return it->second.table_counts_.size();
}
+ // returns +1 or 0 indicating whether a new table was opened
int increment(const Dish& dish, const double& p0, MT19937* rng) {
DishLocations& loc = dish_locs_[dish];
bool share_table = false;
@@ -56,6 +63,7 @@ class CCRP {
return (share_table ? 0 : 1);
}
+ // returns -1 or 0, indicating whether a table was closed
int decrement(const Dish& dish, MT19937* rng) {
DishLocations& loc = dish_locs_[dish];
assert(loc.total_dish_count_);
@@ -66,11 +74,13 @@ class CCRP {
return -1;
} else {
int delta = 0;
- double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
+ // sample customer to remove UNIFORMLY. that is, do NOT use the discount
+ // here. if you do, it will introduce (unwanted) bias!
+ double r = rng->next() * loc.total_dish_count_;
--loc.total_dish_count_;
for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
ti != loc.table_counts_.end(); ++ti) {
- r -= (*ti - discount_);
+ r -= *ti;
if (r <= 0.0) {
if ((--(*ti)) == 0) {
--num_tables_;
@@ -107,7 +117,9 @@ class CCRP {
double lp = 0.0;
if (num_customers_) {
const double r = lgamma(1.0 - discount_);
- lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) - lgamma(concentration_ / discount_);
+ lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_)
+ + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_)
+ - lgamma(concentration_ / discount_);
for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
it != dish_locs_.end(); ++it) {
const DishLocations& cur = it->second;
@@ -119,8 +131,9 @@ class CCRP {
struct DishLocations { // seating record kept for a single dish
DishLocations() : total_dish_count_() {} // value-initializes the customer total; the table list starts empty
- unsigned total_dish_count_;
+ unsigned total_dish_count_; // customers at all tables with this dish
std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want
+ // .size() is the number of tables for this dish
};
void Print(std::ostream* out) const {
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index 0232331b..10056bc9 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -25,39 +25,7 @@ void ShowTopWordsForTopic(const map<WordID, int>& counts) {
cerr << endl;
}
-void tc() {
- MT19937 rng;
- CCRP<string> crp(0.1, 5);
- double un = 0.25;
- int tt = 0;
- tt += crp.increment("hi", un, &rng);
- tt += crp.increment("foo", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- cout << "tt=" << tt << endl;
- cout << crp << endl;
- cout << " P(bar)=" << crp.prob("bar", un) << endl;
- cout << " P(hi)=" << crp.prob("hi", un) << endl;
- cout << " P(baz)=" << crp.prob("baz", un) << endl;
- cout << " P(foo)=" << crp.prob("foo", un) << endl;
- double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un);
- cout << " tot=" << x << endl;
- tt += crp.decrement("hi", &rng);
- tt += crp.decrement("bar", &rng);
- cout << crp << endl;
- tt += crp.decrement("bar", &rng);
- cout << crp << endl;
- cout << "tt=" << tt << endl;
- cout << crp.log_crp_prob() << endl;
-}
-
int main(int argc, char** argv) {
- tc();
if (argc != 3) {
cerr << "Usage: " << argv[0] << " num-classes num-samples\n";
return 1;
@@ -88,11 +56,11 @@ int main(int argc, char** argv) {
MT19937 rng;
cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n";
zji.resize(wji.size());
- double disc = 0.05;
+ double disc = 0.1;
double beta = 10.0;
double alpha = 50.0;
- double uniform_topic = 1.0 / num_classes;
- double uniform_word = 1.0 / TD::NumWords();
+ const double uniform_topic = 1.0 / num_classes;
+ const double uniform_word = 1.0 / TD::NumWords();
vector<CCRP<int> > dr(zji.size(), CCRP<int>(disc, beta)); // dr[i] describes the probability of using a topic in document i
vector<CCRP<int> > wr(num_classes, CCRP<int>(disc, alpha)); // wr[k] describes the probability of generating a word in topic k
for (int j = 0; j < zji.size(); ++j) {
diff --git a/gi/clda/src/crp_test.cc b/gi/clda/src/crp_test.cc
new file mode 100644
index 00000000..750d80a7
--- /dev/null
+++ b/gi/clda/src/crp_test.cc
@@ -0,0 +1,95 @@
+#include <cmath>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "ccrp.h"
+#include "sampler.h"
+
+const size_t MAX_DOC_LEN_CHARS = 10000000;
+
+using namespace std;
+
+// Fixture shared by the CRP tests. Its only state is a default-constructed
+// MT19937 generator, which test bodies reach through the `rng` member.
+class CRPTest : public testing::Test {
+ public:
+  CRPTest() {}
+
+ protected:
+  // No per-test preparation or cleanup is required; both hooks are
+  // intentionally empty.
+  virtual void SetUp() {}
+  virtual void TearDown() {}
+
+  MT19937 rng;
+};
+
+// Smoke test for CCRP<string>: seats one "hi", one "foo" and seven "bar"
+// customers, checks that prob() summed over four dishes (base probability
+// 0.25 each) equals one, then removes three customers. `tt` accumulates the
+// +1/0/-1 values returned by increment()/decrement(), i.e. the net number of
+// tables opened; it and the restaurant state are printed for inspection only.
+TEST_F(CRPTest, Dist) {
+  const double un = 0.25;
+  const int bar_customers = 7;
+  CCRP<string> crp(0.1, 5);
+
+  int tt = crp.increment("hi", un, &rng);
+  tt += crp.increment("foo", un, &rng);
+  for (int i = 0; i < bar_customers; ++i) {
+    tt += crp.increment("bar", un, &rng);
+  }
+
+  cout << "tt=" << tt << endl;
+  cout << crp << endl;
+  cout << " P(bar)=" << crp.prob("bar", un) << endl;
+  cout << " P(hi)=" << crp.prob("hi", un) << endl;
+  cout << " P(baz)=" << crp.prob("baz", un) << endl;
+  cout << " P(foo)=" << crp.prob("foo", un) << endl;
+
+  // "baz" was never seated, so it is the one dish served purely from the base
+  // distribution; together the four dishes must account for all the mass.
+  const double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un);
+  cout << " tot=" << x << endl;
+  EXPECT_FLOAT_EQ(1.0, x);
+
+  tt += crp.decrement("hi", &rng);
+  tt += crp.decrement("bar", &rng);
+  cout << crp << endl;
+  tt += crp.decrement("bar", &rng);
+  cout << crp << endl;
+  cout << "tt=" << tt << endl;
+}
+
+// Monte-Carlo check that seating and unseating customers leaves the
+// distribution over table counts unchanged. A restaurant with discount 0.5
+// and concentration 1.0 is filled with `cust` customers of a single dish.
+// Each sample then either adds `da` customers and removes `da`, or removes
+// `da` and adds `da` back, and records how many tables the dish occupies.
+// The sample mean is compared with 5.4, which matches the analytic
+// Pitman-Yor expectation of the table count for these parameters (about
+// 5.400 for 10 customers). A biased decrement() would pull the mean away.
+TEST_F(CRPTest, Exchangability) {
+  const int cust = 10;
+  const int samples = 100000;
+  // Set to false to rebuild the restaurant from scratch for every sample
+  // instead of perturbing the current seating with add/remove moves.
+  const bool simulate = true;
+  const double expected_tables = 5.4;
+
+  double tot = 0;  // sum of the sampled table counts
+  double xt = 0;   // how many samples took the add-then-remove branch
+  CCRP<int> crp(0.5, 1.0);
+  vector<int> hist(cust + 1, 0);  // hist[c] = samples that ended with c tables
+  for (int i = 0; i < cust; ++i) { crp.increment(1, 1.0, &rng); }
+
+  for (int k = 0; k < samples; ++k) {
+    if (!simulate) {
+      crp.clear();
+      for (int i = 0; i < cust; ++i) { crp.increment(1, 1.0, &rng); }
+    } else {
+      // Truncates toward zero. Presumably rng.next() lies in [0, 1), which
+      // keeps da within [0, cust - 1] -- TODO confirm against sampler.h.
+      const int da = static_cast<int>(rng.next() * cust);
+      const bool a = rng.next() < 0.5;
+      if (a) {
+        for (int i = 0; i < da; ++i) { crp.increment(1, 1.0, &rng); }
+        for (int i = 0; i < da; ++i) { crp.decrement(1, &rng); }
+        xt += 1.0;
+      } else {
+        for (int i = 0; i < da; ++i) { crp.decrement(1, &rng); }
+        for (int i = 0; i < da; ++i) { crp.increment(1, 1.0, &rng); }
+      }
+    }
+    const int c = crp.num_tables(1);
+    ++hist[c];
+    tot += c;
+  }
+
+  const double mean_tables = tot / samples;
+  cerr << "P(a) = " << (xt / samples) << endl;
+  cerr << "E[num tables] = " << mean_tables << endl;
+  // std::fabs is declared in <cmath>, which this file now includes directly
+  // rather than relying on another header to pull it in.
+  const double error = std::fabs(mean_tables - expected_tables);
+  cerr << " error = " << error << endl;
+  EXPECT_LT(error, 0.1);  // it's possible for this to fail, but
+                          // very, very unlikely
+  for (int i = 1; i <= cust; ++i)
+    cerr << i << ' ' << (hist[i]) << endl;
+}
+
+// Entry point of the test binary: Google Test first consumes its own
+// command-line flags, then every registered test runs and the aggregate
+// result becomes the process exit status.
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  const int status = RUN_ALL_TESTS();
+  return status;
+}