summary | refs | log | tree | commit | diff
path: root/gi/clda
diff options
context:
space:
mode:
author redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-08-25 02:14:41 +0000
committer redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-08-25 02:14:41 +0000
commit 0ef9ee63a44bf2659756eb2f1d34f5fddfb458a4 (patch)
tree bd79d681a5919da99575a0af194ea5a35b05bbdd /gi/clda
parent 1090a065dc48211dd71f4980cf8ff34e47333ad0 (diff)
crp with explicit table tracking
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@618 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/clda')
-rw-r--r-- gi/clda/src/ccrp.h 139
-rw-r--r-- gi/clda/src/clda.cc 87
2 files changed, 192 insertions, 34 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h
new file mode 100644
index 00000000..e978225d
--- /dev/null
+++ b/gi/clda/src/ccrp.h
@@ -0,0 +1,139 @@
+#ifndef _CCRP_H_
+#define _CCRP_H_
+
+#include <cassert>
+#include <cmath>
+#include <list>
+#include <iostream>
+#include <vector>
+#include <tr1/unordered_map>
+#include <boost/functional/hash.hpp>
+#include "sampler.h"
+
+// Chinese restaurant process (Pitman-Yor parameters) with explicit table
+// tracking.
+//
+// For every distinct dish the restaurant stores the number of customers eating
+// it and the list of tables serving it (one customer count per table), so that
+// customers can be added to and removed from individual tables.
+//
+// Template parameters:
+//   Dish     - type of the generated items; used as the hash-map key and
+//              streamed with operator<< by Print().
+//   DishHash - hash functor for Dish (defaults to boost::hash<Dish>).
+template <typename Dish, typename DishHash = boost::hash<Dish> >
+class CCRP {
+ public:
+  // disc: discount parameter, conc: concentration parameter.
+  CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {}
+
+  // Seats one customer eating `dish`; p0 is the base probability of the dish.
+  // Returns 1 if a new table was opened for the customer, 0 if the customer
+  // joined an existing table.
+  int increment(const Dish& dish, const double& p0, MT19937* rng) {
+    DishLocations& loc = dish_locs_[dish];
+    // Total mass of the existing tables serving this dish: sum_t (c_t - d).
+    // It is 0 for a dish nobody eats yet (no customers, no tables).
+    const double p_share = (loc.dish_count_ - loc.table_counts_.size() * discount_);
+    bool share_table = false;
+    if (loc.dish_count_) {
+      const double p_empty = (concentration_ + num_tables_ * discount_) * p0;
+      // NOTE(review): presumably SelectSample(a, b) returns true with
+      // probability b / (a + b) -- confirm against sampler.h.
+      share_table = rng->SelectSample(p_empty, p_share);
+    }
+    if (share_table) {
+      // Pick one of the existing tables with weight (c_t - d).
+      double r = rng->next() * p_share;
+      for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
+           ti != loc.table_counts_.end(); ++ti) {
+        r -= (*ti - discount_);
+        if (r <= 0.0) {
+          ++(*ti);
+          break;
+        }
+      }
+      if (r > 0.0) {
+        std::cerr << "Serious error: r=" << r << std::endl;
+        Print(&std::cerr);
+        assert(r <= 0.0);
+      }
+    } else {
+      loc.table_counts_.push_back(1u);
+      ++num_tables_;
+    }
+    ++loc.dish_count_;
+    ++num_customers_;
+    return (share_table ? 0 : 1);
+  }
+
+  // Removes one customer eating `dish`.
+  // Returns -1 if the customer's table became empty and was removed, 0
+  // otherwise. The dish must currently have at least one customer; this is
+  // asserted, and in builds without assertions the call is a no-op returning 0.
+  int decrement(const Dish& dish, MT19937* rng) {
+    // Look the dish up without inserting: operator[] would create an empty
+    // entry for an unknown dish and the unsigned counters would then underflow.
+    const typename DishMap::iterator it = dish_locs_.find(dish);
+    const bool seated = (it != dish_locs_.end()) && it->second.dish_count_ > 0;
+    assert(seated);
+    if (!seated) return 0;
+
+    DishLocations& loc = it->second;
+    if (loc.dish_count_ == 1) {
+      // Last customer of this dish: its only table disappears with it.
+      dish_locs_.erase(it);
+      --num_tables_;
+      --num_customers_;
+      return -1;
+    }
+
+    int delta = 0;
+    // Choose the departing customer UNIFORMLY among the customers eating this
+    // dish, i.e. a table is hit with probability proportional to its plain
+    // count. The discount must not enter here: weighting tables by (c_t - d)
+    // would bias removals toward large tables.
+    double r = rng->next() * loc.dish_count_;
+    --loc.dish_count_;
+    for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
+         ti != loc.table_counts_.end(); ++ti) {
+      r -= *ti;
+      if (r <= 0.0) {
+        if ((--(*ti)) == 0) {
+          --num_tables_;
+          delta = -1;
+          loc.table_counts_.erase(ti);  // ti is not used again: we break below
+        }
+        break;
+      }
+    }
+    if (r > 0.0) {
+      std::cerr << "Serious error: r=" << r << std::endl;
+      Print(&std::cerr);
+      assert(r <= 0.0);
+    }
+    --num_customers_;
+    return delta;
+  }
+
+  // Predictive probability of `dish` for the next customer, given the base
+  // probability p0 of that dish.
+  double prob(const Dish& dish, const double& p0) const {
+    const typename DishMap::const_iterator it = dish_locs_.find(dish);
+    const double r = num_tables_ * discount_ + concentration_;
+    if (it == dish_locs_.end()) {
+      return r * p0 / (num_customers_ + concentration_);
+    } else {
+      return (it->second.dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) /
+             (num_customers_ + concentration_);
+    }
+  }
+
+  // Log-likelihood placeholder: not implemented. Always returns 0.0 and
+  // complains on stderr when the restaurant is non-empty.
+  double llh() const {
+    if (num_customers_) {
+      std::cerr << "not implemented\n";
+      return 0.0;
+    } else {
+      return 0.0;
+    }
+  }
+
+  // Seating information for a single dish.
+  struct DishLocations {
+    DishLocations() : dish_count_() {}
+    unsigned dish_count_;               // customers eating this dish
+    std::list<unsigned> table_counts_;  // list<> gives O(1) deletion and insertion, which we want
+  };
+
+  typedef std::tr1::unordered_map<Dish, DishLocations, DishHash> DishMap;
+
+  // Writes one line per dish: the dish, its customer and table counts, and the
+  // customer count of every table serving it.
+  void Print(std::ostream* out) const {
+    for (typename DishMap::const_iterator it = dish_locs_.begin();
+         it != dish_locs_.end(); ++it) {
+      (*out) << it->first << " (" << it->second.dish_count_ << " on " << it->second.table_counts_.size() << " tables): ";
+      for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin();
+           i != it->second.table_counts_.end(); ++i) {
+        (*out) << " " << *i;
+      }
+      (*out) << std::endl;
+    }
+  }
+
+  unsigned num_tables_;     // tables over all dishes
+  unsigned num_customers_;  // customers over all dishes
+  DishMap dish_locs_;       // dish -> seating information
+
+  double discount_;
+  double concentration_;
+};
+
+// Stream insertion for a CCRP: delegates to CCRP::Print and returns the stream
+// so that insertions can be chained.
+template <typename Dish, typename DishHash>
+std::ostream& operator<<(std::ostream& out, const CCRP<Dish, DishHash>& restaurant) {
+  restaurant.Print(&out);
+  return out;
+}
+
+#endif
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index 757a4691..ea224a27 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -1,9 +1,11 @@
#include <iostream>
#include <vector>
#include <map>
+#include <string>
#include "timer.h"
#include "crp.h"
+#include "ccrp.h"
#include "sampler.h"
#include "tdict.h"
const size_t MAX_DOC_LEN_CHARS = 10000000;
@@ -18,12 +20,44 @@ void ShowTopWordsForTopic(const map<WordID, int>& counts) {
for (multimap<int, WordID>::reverse_iterator it = ms.rbegin(); it != ms.rend(); ++it) {
cerr << it->first << ':' << TD::Convert(it->second) << " ";
++cc;
- if (cc==12) break;
+ if (cc==20) break;
}
cerr << endl;
}
+// Ad-hoc smoke test for CCRP<string>: seats a handful of customers, prints the
+// restaurant and a few predictive probabilities, then removes some customers
+// again. The values returned by increment()/decrement() are accumulated and
+// printed as "tt".
+void tc() {
+  MT19937 rng;
+  CCRP<string> crp(0.1, 5);
+  const double base_p = 0.25;  // base probability handed to every call
+  int table_delta = 0;
+
+  table_delta += crp.increment("hi", base_p, &rng);
+  table_delta += crp.increment("foo", base_p, &rng);
+  // Seven customers ordering "bar".
+  for (int rep = 0; rep < 7; ++rep) {
+    table_delta += crp.increment("bar", base_p, &rng);
+  }
+  cout << "tt=" << table_delta << endl;
+  cout << crp << endl;
+
+  // prob() is const, so the four values can be computed once and reused for
+  // the total below.
+  const double p_bar = crp.prob("bar", base_p);
+  const double p_hi = crp.prob("hi", base_p);
+  const double p_baz = crp.prob("baz", base_p);
+  const double p_foo = crp.prob("foo", base_p);
+  cout << " P(bar)=" << p_bar << endl;
+  cout << " P(hi)=" << p_hi << endl;
+  cout << " P(baz)=" << p_baz << endl;
+  cout << " P(foo)=" << p_foo << endl;
+  double x = p_bar + p_hi + p_baz + p_foo;
+  cout << " tot=" << x << endl;
+
+  table_delta += crp.decrement("hi", &rng);
+  table_delta += crp.decrement("bar", &rng);
+  cout << crp << endl;
+  table_delta += crp.decrement("bar", &rng);
+  cout << crp << endl;
+  cout << "tt=" << table_delta << endl;
+  cout << crp.llh() << endl;
+}
+
int main(int argc, char** argv) {
+ tc();
if (argc != 3) {
cerr << "Usage: " << argv[0] << " num-classes num-samples\n";
return 1;
@@ -54,10 +88,13 @@ int main(int argc, char** argv) {
MT19937 rng;
cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n";
zji.resize(wji.size());
- double beta = 0.1;
- double alpha = 50.0 / num_classes;
- vector<CRP<int> > dr(zji.size(), CRP<int>(beta)); // dr[i] describes the probability of using a topic in document i
- vector<CRP<int> > wr(num_classes, CRP<int>(alpha)); // wr[k] describes the probability of generating a word in topic k
+ double disc = 0.05;
+ double beta = 10.0;
+ double alpha = 50.0;
+ double uniform_topic = 1.0 / num_classes;
+ double uniform_word = 1.0 / TD::NumWords();
+ vector<CCRP<int> > dr(zji.size(), CCRP<int>(disc, beta)); // dr[i] describes the probability of using a topic in document i
+ vector<CCRP<int> > wr(num_classes, CCRP<int>(disc, alpha)); // wr[k] describes the probability of generating a word in topic k
for (int j = 0; j < zji.size(); ++j) {
const size_t num_words = wji[j].size();
vector<int>& zj = zji[j];
@@ -68,19 +105,15 @@ int main(int argc, char** argv) {
if (random_topic == num_classes) { --random_topic; }
zj[i] = random_topic;
const int word = wj[i];
- dr[j].increment(random_topic);
- wr[random_topic].increment(word);
+ dr[j].increment(random_topic, uniform_topic, &rng);
+ wr[random_topic].increment(word, uniform_word, &rng);
}
}
cerr << "SAMPLING\n";
vector<map<WordID, int> > t2w(num_classes);
Timer timer;
- SampleSet ss;
+ SampleSet<double> ss;
const int num_types = TD::NumWords();
- const prob_t class_p0(1.0 / num_classes);
- const prob_t word_p0(1.0 / num_types);
- cerr << "CLASS PRIOR PROB: " << class_p0 << endl;
- cerr << " WORD PRIOR LOGPROB: " << log(word_p0) << endl;
ss.resize(num_classes);
double total_time = 0;
for (int iter = 0; iter < num_iterations; ++iter) {
@@ -88,23 +121,6 @@ int main(int argc, char** argv) {
if (iter && iter % 10 == 0) {
total_time += timer.Elapsed();
timer.Reset();
- prob_t lh = prob_t::One();
- for (int j = 0; j < zji.size(); ++j) {
- const size_t num_words = wji[j].size();
- vector<int>& zj = zji[j];
- const vector<int>& wj = wji[j];
- for (int i = 0; i < num_words; ++i) {
- const int word = wj[i];
- const int cur_topic = zj[i];
- lh *= dr[j].prob(cur_topic, class_p0);
- lh *= wr[cur_topic].prob(word, word_p0);
- if (iter > burnin_size) {
- ++t2w[cur_topic][word];
- }
- }
- }
- if (iter && iter % 40 == 0) { cerr << " [ITER=" << iter << " SEC/SAMPLE=" << (total_time / 40) << " LLH=" << log(lh) << "]\n"; total_time=0; }
- //cerr << "ITERATION " << iter << " LOG LIKELIHOOD: " << log(lh) << endl;
}
for (int j = 0; j < zji.size(); ++j) {
const size_t num_words = wji[j].size();
@@ -113,16 +129,19 @@ int main(int argc, char** argv) {
for (int i = 0; i < num_words; ++i) {
const int word = wj[i];
const int cur_topic = zj[i];
- dr[j].decrement(cur_topic);
- wr[cur_topic].decrement(word);
+ dr[j].decrement(cur_topic, &rng);
+ wr[cur_topic].decrement(word, &rng);
for (int k = 0; k < num_classes; ++k) {
- ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0);
+ ss[k]= dr[j].prob(k, uniform_topic) * wr[k].prob(word, uniform_word);
}
const int new_topic = rng.SelectSample(ss);
- dr[j].increment(new_topic);
- wr[new_topic].increment(word);
+ dr[j].increment(new_topic, uniform_topic, &rng);
+ wr[new_topic].increment(word, uniform_word, &rng);
zj[i] = new_topic;
+ if (iter > burnin_size) {
+ ++t2w[cur_topic][word];
+ }
}
}
}