From 0ef9ee63a44bf2659756eb2f1d34f5fddfb458a4 Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 25 Aug 2010 02:14:41 +0000 Subject: crp with explicit table tracking git-svn-id: https://ws10smt.googlecode.com/svn/trunk@618 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/clda/src/ccrp.h | 139 ++++++++++++++++++++++++++++++++++++++++++++++++++++ gi/clda/src/clda.cc | 87 +++++++++++++++++++------------- 2 files changed, 192 insertions(+), 34 deletions(-) create mode 100644 gi/clda/src/ccrp.h (limited to 'gi/clda') diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h new file mode 100644 index 00000000..e978225d --- /dev/null +++ b/gi/clda/src/ccrp.h @@ -0,0 +1,139 @@ +#ifndef _CCRP_H_ +#define _CCRP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "sampler.h" + +// Chinese restaurant process (Pitman-Yor parameters) with explicit table +// tracking. + +template > +class CCRP { + public: + CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {} + + int increment(const Dish& dish, const double& p0, MT19937* rng) { + DishLocations& loc = dish_locs_[dish]; + bool share_table = false; + if (loc.dish_count_) { + const double p_empty = (concentration_ + num_tables_ * discount_) * p0; + const double p_share = (loc.dish_count_ - loc.table_counts_.size() * discount_); + share_table = rng->SelectSample(p_empty, p_share); + } + if (share_table) { + double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_); + for (typename std::list::iterator ti = loc.table_counts_.begin(); + ti != loc.table_counts_.end(); ++ti) { + r -= (*ti - discount_); + if (r <= 0.0) { + ++(*ti); + break; + } + } + if (r > 0.0) { + std::cerr << "Serious error: r=" << r << std::endl; + Print(&std::cerr); + assert(r <= 0.0); + } + } else { + loc.table_counts_.push_back(1u); + ++num_tables_; + } + ++loc.dish_count_; + ++num_customers_; + return (share_table ? 0 : 1); + } + + int decrement(const Dish& dish, MT19937* rng) { + DishLocations& loc = dish_locs_[dish]; + assert(loc.dish_count_); + if (loc.dish_count_ == 1) { + dish_locs_.erase(dish); + --num_tables_; + --num_customers_; + return -1; + } else { + int delta = 0; + double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_); + --loc.dish_count_; + for (typename std::list::iterator ti = loc.table_counts_.begin(); + ti != loc.table_counts_.end(); ++ti) { + r -= (*ti - discount_); + if (r <= 0.0) { + if ((--(*ti)) == 0) { + --num_tables_; + delta = -1; + loc.table_counts_.erase(ti); + } + break; + } + } + if (r > 0.0) { + std::cerr << "Serious error: r=" << r << std::endl; + Print(&std::cerr); + assert(r <= 0.0); + } + --num_customers_; + return delta; + } + } + + double prob(const Dish& dish, const double& p0) const { + const typename std::tr1::unordered_map::const_iterator it = dish_locs_.find(dish); + const double r = num_tables_ * discount_ + concentration_; + if (it == dish_locs_.end()) { + return r * p0 / (num_customers_ + concentration_); + } else { + return (it->second.dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) / + (num_customers_ + concentration_); + } + } + + double llh() const { + if (num_customers_) { + std::cerr << "not implemented\n"; + return 0.0; + } else { + return 0.0; + } + } + + struct DishLocations { + DishLocations() : dish_count_() {} + unsigned dish_count_; + std::list table_counts_; // list<> gives O(1) deletion and insertion, which we want + }; + + void Print(std::ostream* out) const { + for (typename std::tr1::unordered_map::const_iterator it = dish_locs_.begin(); + it != dish_locs_.end(); ++it) { + (*out) << it->first << " (" << it->second.dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; + for (typename std::list::const_iterator i = it->second.table_counts_.begin(); + i != it->second.table_counts_.end(); ++i) { + (*out) << " " << *i; + } + (*out) << std::endl; + } + } + + unsigned num_tables_; + unsigned num_customers_; + std::tr1::unordered_map dish_locs_; + + double discount_; + double concentration_; +}; + +template +std::ostream& operator<<(std::ostream& o, const CCRP& c) { + c.Print(&o); + return o; +} + +#endif diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc index 757a4691..ea224a27 100644 --- a/gi/clda/src/clda.cc +++ b/gi/clda/src/clda.cc @@ -1,9 +1,11 @@ #include #include #include +#include #include "timer.h" #include "crp.h" +#include "ccrp.h" #include "sampler.h" #include "tdict.h" const size_t MAX_DOC_LEN_CHARS = 10000000; @@ -18,12 +20,44 @@ void ShowTopWordsForTopic(const map& counts) { for (multimap::reverse_iterator it = ms.rbegin(); it != ms.rend(); ++it) { cerr << it->first << ':' << TD::Convert(it->second) << " "; ++cc; - if (cc==12) break; + if (cc==20) break; } cerr << endl; } +void tc() { + MT19937 rng; + CCRP crp(0.1, 5); + double un = 0.25; + int tt = 0; + tt += crp.increment("hi", un, &rng); + tt += crp.increment("foo", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + tt += crp.increment("bar", un, &rng); + cout << "tt=" << tt << endl; + cout << crp << endl; + cout << " P(bar)=" << crp.prob("bar", un) << endl; + cout << " P(hi)=" << crp.prob("hi", un) << endl; + cout << " P(baz)=" << crp.prob("baz", un) << endl; + cout << " P(foo)=" << crp.prob("foo", un) << endl; + double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un); + cout << " tot=" << x << endl; + tt += crp.decrement("hi", &rng); + tt += crp.decrement("bar", &rng); + cout << crp << endl; + tt += crp.decrement("bar", &rng); + cout << crp << endl; + cout << "tt=" << tt << endl; + cout << crp.llh() << endl; +} + int main(int argc, char** argv) { + tc(); if (argc != 3) { cerr << "Usage: " << argv[0] << " num-classes num-samples\n"; return 1; @@ -54,10 +88,13 @@ int main(int argc, char** argv) { MT19937 rng; cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n"; zji.resize(wji.size()); - double beta = 0.1; - double alpha = 50.0 / num_classes; - vector > dr(zji.size(), CRP(beta)); // dr[i] describes the probability of using a topic in document i - vector > wr(num_classes, CRP(alpha)); // wr[k] describes the probability of generating a word in topic k + double disc = 0.05; + double beta = 10.0; + double alpha = 50.0; + double uniform_topic = 1.0 / num_classes; + double uniform_word = 1.0 / TD::NumWords(); + vector > dr(zji.size(), CCRP(disc, beta)); // dr[i] describes the probability of using a topic in document i + vector > wr(num_classes, CCRP(disc, alpha)); // wr[k] describes the probability of generating a word in topic k for (int j = 0; j < zji.size(); ++j) { const size_t num_words = wji[j].size(); vector& zj = zji[j]; @@ -68,19 +105,15 @@ int main(int argc, char** argv) { if (random_topic == num_classes) { --random_topic; } zj[i] = random_topic; const int word = wj[i]; - dr[j].increment(random_topic); - wr[random_topic].increment(word); + dr[j].increment(random_topic, uniform_topic, &rng); + wr[random_topic].increment(word, uniform_word, &rng); } } cerr << "SAMPLING\n"; vector > t2w(num_classes); Timer timer; - SampleSet ss; + SampleSet ss; const int num_types = TD::NumWords(); - const prob_t class_p0(1.0 / num_classes); - const prob_t word_p0(1.0 / num_types); - cerr << "CLASS PRIOR PROB: " << class_p0 << endl; - cerr << " WORD PRIOR LOGPROB: " << log(word_p0) << endl; ss.resize(num_classes); double total_time = 0; for (int iter = 0; iter < num_iterations; ++iter) { @@ -88,23 +121,6 @@ int main(int argc, char** argv) { if (iter && iter % 10 == 0) { total_time += timer.Elapsed(); timer.Reset(); - prob_t lh = prob_t::One(); - for (int j = 0; j < zji.size(); ++j) { - const size_t num_words = wji[j].size(); - vector& zj = zji[j]; - const vector& wj = wji[j]; - for (int i = 0; i < num_words; ++i) { - const int word = wj[i]; - const int cur_topic = zj[i]; - lh *= dr[j].prob(cur_topic, class_p0); - lh *= wr[cur_topic].prob(word, word_p0); - if (iter > burnin_size) { - ++t2w[cur_topic][word]; - } - } - } - if (iter && iter % 40 == 0) { cerr << " [ITER=" << iter << " SEC/SAMPLE=" << (total_time / 40) << " LLH=" << log(lh) << "]\n"; total_time=0; } - //cerr << "ITERATION " << iter << " LOG LIKELIHOOD: " << log(lh) << endl; } for (int j = 0; j < zji.size(); ++j) { const size_t num_words = wji[j].size(); @@ -113,16 +129,19 @@ int main(int argc, char** argv) { for (int i = 0; i < num_words; ++i) { const int word = wj[i]; const int cur_topic = zj[i]; - dr[j].decrement(cur_topic); - wr[cur_topic].decrement(word); + dr[j].decrement(cur_topic, &rng); + wr[cur_topic].decrement(word, &rng); for (int k = 0; k < num_classes; ++k) { - ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0); + ss[k]= dr[j].prob(k, uniform_topic) * wr[k].prob(word, uniform_word); } const int new_topic = rng.SelectSample(ss); - dr[j].increment(new_topic); - wr[new_topic].increment(word); + dr[j].increment(new_topic, uniform_topic, &rng); + wr[new_topic].increment(word, uniform_word, &rng); zj[i] = new_topic; + if (iter > burnin_size) { + ++t2w[cur_topic][word]; + } } } } -- cgit v1.2.3