summaryrefslogtreecommitdiff
path: root/gi/clda/src/clda.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/clda/src/clda.cc')
-rw-r--r--gi/clda/src/clda.cc87
1 files changed, 53 insertions, 34 deletions
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index 757a4691..ea224a27 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -1,9 +1,11 @@
#include <iostream>
#include <vector>
#include <map>
+#include <string>
#include "timer.h"
#include "crp.h"
+#include "ccrp.h"
#include "sampler.h"
#include "tdict.h"
const size_t MAX_DOC_LEN_CHARS = 10000000;
@@ -18,12 +20,44 @@ void ShowTopWordsForTopic(const map<WordID, int>& counts) {
for (multimap<int, WordID>::reverse_iterator it = ms.rbegin(); it != ms.rend(); ++it) {
cerr << it->first << ':' << TD::Convert(it->second) << " ";
++cc;
- if (cc==12) break;
+ if (cc==20) break;
}
cerr << endl;
}
+void tc() {
+ MT19937 rng;
+ CCRP<string> crp(0.1, 5);
+ double un = 0.25;
+ int tt = 0;
+ tt += crp.increment("hi", un, &rng);
+ tt += crp.increment("foo", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ cout << "tt=" << tt << endl;
+ cout << crp << endl;
+ cout << " P(bar)=" << crp.prob("bar", un) << endl;
+ cout << " P(hi)=" << crp.prob("hi", un) << endl;
+ cout << " P(baz)=" << crp.prob("baz", un) << endl;
+ cout << " P(foo)=" << crp.prob("foo", un) << endl;
+ double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un);
+ cout << " tot=" << x << endl;
+ tt += crp.decrement("hi", &rng);
+ tt += crp.decrement("bar", &rng);
+ cout << crp << endl;
+ tt += crp.decrement("bar", &rng);
+ cout << crp << endl;
+ cout << "tt=" << tt << endl;
+ cout << crp.llh() << endl;
+}
+
int main(int argc, char** argv) {
+ tc();
if (argc != 3) {
cerr << "Usage: " << argv[0] << " num-classes num-samples\n";
return 1;
@@ -54,10 +88,13 @@ int main(int argc, char** argv) {
MT19937 rng;
cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n";
zji.resize(wji.size());
- double beta = 0.1;
- double alpha = 50.0 / num_classes;
- vector<CRP<int> > dr(zji.size(), CRP<int>(beta)); // dr[i] describes the probability of using a topic in document i
- vector<CRP<int> > wr(num_classes, CRP<int>(alpha)); // wr[k] describes the probability of generating a word in topic k
+ double disc = 0.05;
+ double beta = 10.0;
+ double alpha = 50.0;
+ double uniform_topic = 1.0 / num_classes;
+ double uniform_word = 1.0 / TD::NumWords();
+ vector<CCRP<int> > dr(zji.size(), CCRP<int>(disc, beta)); // dr[i] describes the probability of using a topic in document i
+ vector<CCRP<int> > wr(num_classes, CCRP<int>(disc, alpha)); // wr[k] describes the probability of generating a word in topic k
for (int j = 0; j < zji.size(); ++j) {
const size_t num_words = wji[j].size();
vector<int>& zj = zji[j];
@@ -68,19 +105,15 @@ int main(int argc, char** argv) {
if (random_topic == num_classes) { --random_topic; }
zj[i] = random_topic;
const int word = wj[i];
- dr[j].increment(random_topic);
- wr[random_topic].increment(word);
+ dr[j].increment(random_topic, uniform_topic, &rng);
+ wr[random_topic].increment(word, uniform_word, &rng);
}
}
cerr << "SAMPLING\n";
vector<map<WordID, int> > t2w(num_classes);
Timer timer;
- SampleSet ss;
+ SampleSet<double> ss;
const int num_types = TD::NumWords();
- const prob_t class_p0(1.0 / num_classes);
- const prob_t word_p0(1.0 / num_types);
- cerr << "CLASS PRIOR PROB: " << class_p0 << endl;
- cerr << " WORD PRIOR LOGPROB: " << log(word_p0) << endl;
ss.resize(num_classes);
double total_time = 0;
for (int iter = 0; iter < num_iterations; ++iter) {
@@ -88,23 +121,6 @@ int main(int argc, char** argv) {
if (iter && iter % 10 == 0) {
total_time += timer.Elapsed();
timer.Reset();
- prob_t lh = prob_t::One();
- for (int j = 0; j < zji.size(); ++j) {
- const size_t num_words = wji[j].size();
- vector<int>& zj = zji[j];
- const vector<int>& wj = wji[j];
- for (int i = 0; i < num_words; ++i) {
- const int word = wj[i];
- const int cur_topic = zj[i];
- lh *= dr[j].prob(cur_topic, class_p0);
- lh *= wr[cur_topic].prob(word, word_p0);
- if (iter > burnin_size) {
- ++t2w[cur_topic][word];
- }
- }
- }
- if (iter && iter % 40 == 0) { cerr << " [ITER=" << iter << " SEC/SAMPLE=" << (total_time / 40) << " LLH=" << log(lh) << "]\n"; total_time=0; }
- //cerr << "ITERATION " << iter << " LOG LIKELIHOOD: " << log(lh) << endl;
}
for (int j = 0; j < zji.size(); ++j) {
const size_t num_words = wji[j].size();
@@ -113,16 +129,19 @@ int main(int argc, char** argv) {
for (int i = 0; i < num_words; ++i) {
const int word = wj[i];
const int cur_topic = zj[i];
- dr[j].decrement(cur_topic);
- wr[cur_topic].decrement(word);
+ dr[j].decrement(cur_topic, &rng);
+ wr[cur_topic].decrement(word, &rng);
for (int k = 0; k < num_classes; ++k) {
- ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0);
+ ss[k]= dr[j].prob(k, uniform_topic) * wr[k].prob(word, uniform_word);
}
const int new_topic = rng.SelectSample(ss);
- dr[j].increment(new_topic);
- wr[new_topic].increment(word);
+ dr[j].increment(new_topic, uniform_topic, &rng);
+ wr[new_topic].increment(word, uniform_word, &rng);
zj[i] = new_topic;
+ if (iter > burnin_size) {
+ ++t2w[cur_topic][word];
+ }
}
}
}