diff options
Diffstat (limited to 'gi/clda/src/clda.cc')
-rw-r--r-- | gi/clda/src/clda.cc | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc index 976b020f..482a1c4c 100644 --- a/gi/clda/src/clda.cc +++ b/gi/clda/src/clda.cc @@ -79,6 +79,9 @@ int main(int argc, char** argv) { vector<map<WordID, int> > t2w(num_classes); Timer timer; SampleSet ss; + const int num_types = TD::dict_.max(); + const prob_t class_p0(1.0 / num_classes); + const prob_t word_p0(1.0 / num_types); ss.resize(num_classes); double total_time = 0; for (int iter = 0; iter < num_iterations; ++iter) { @@ -94,8 +97,8 @@ int main(int argc, char** argv) { for (int i = 0; i < num_words; ++i) { const int word = wj[i]; const int cur_topic = zj[i]; - lh *= dr[j].prob(cur_topic); - lh *= wr[cur_topic].prob(word); + lh *= dr[j].prob(cur_topic, class_p0); + lh *= wr[cur_topic].prob(word, word_p0); if (iter > burnin_size) { ++t2w[cur_topic][word]; } @@ -115,7 +118,7 @@ int main(int argc, char** argv) { wr[cur_topic].decrement(word); for (int k = 0; k < num_classes; ++k) { - ss[k]= dr[j].prob(k) * wr[k].prob(word); + ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0); } const int new_topic = rng.SelectSample(ss); dr[j].increment(new_topic); |