summaryrefslogtreecommitdiff
path: root/gi/clda/src/clda.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/clda/src/clda.cc')
-rw-r--r--gi/clda/src/clda.cc9
1 files changed, 6 insertions, 3 deletions
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index 976b020f..482a1c4c 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -79,6 +79,9 @@ int main(int argc, char** argv) {
vector<map<WordID, int> > t2w(num_classes);
Timer timer;
SampleSet ss;
+ const int num_types = TD::dict_.max();
+ const prob_t class_p0(1.0 / num_classes);
+ const prob_t word_p0(1.0 / num_types);
ss.resize(num_classes);
double total_time = 0;
for (int iter = 0; iter < num_iterations; ++iter) {
@@ -94,8 +97,8 @@ int main(int argc, char** argv) {
for (int i = 0; i < num_words; ++i) {
const int word = wj[i];
const int cur_topic = zj[i];
- lh *= dr[j].prob(cur_topic);
- lh *= wr[cur_topic].prob(word);
+ lh *= dr[j].prob(cur_topic, class_p0);
+ lh *= wr[cur_topic].prob(word, word_p0);
if (iter > burnin_size) {
++t2w[cur_topic][word];
}
@@ -115,7 +118,7 @@ int main(int argc, char** argv) {
wr[cur_topic].decrement(word);
for (int k = 0; k < num_classes; ++k) {
- ss[k]= dr[j].prob(k) * wr[k].prob(word);
+ ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0);
}
const int new_topic = rng.SelectSample(ss);
dr[j].increment(new_topic);