From 70ad159e22fc6ea12a5e7b468ab38a93c3ed111f Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 23 Jun 2010 14:49:22 +0000 Subject: fix bugs in CRP git-svn-id: https://ws10smt.googlecode.com/svn/trunk@10 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/clda/src/clda.cc | 9 ++++++--- gi/clda/src/crp.h | 3 --- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'gi/clda/src') diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc index 976b020f..482a1c4c 100644 --- a/gi/clda/src/clda.cc +++ b/gi/clda/src/clda.cc @@ -79,6 +79,9 @@ int main(int argc, char** argv) { vector > t2w(num_classes); Timer timer; SampleSet ss; + const int num_types = TD::dict_.max(); + const prob_t class_p0(1.0 / num_classes); + const prob_t word_p0(1.0 / num_types); ss.resize(num_classes); double total_time = 0; for (int iter = 0; iter < num_iterations; ++iter) { @@ -94,8 +97,8 @@ int main(int argc, char** argv) { for (int i = 0; i < num_words; ++i) { const int word = wj[i]; const int cur_topic = zj[i]; - lh *= dr[j].prob(cur_topic); - lh *= wr[cur_topic].prob(word); + lh *= dr[j].prob(cur_topic, class_p0); + lh *= wr[cur_topic].prob(word, word_p0); if (iter > burnin_size) { ++t2w[cur_topic][word]; } @@ -115,7 +118,7 @@ int main(int argc, char** argv) { wr[cur_topic].decrement(word); for (int k = 0; k < num_classes; ++k) { - ss[k]= dr[j].prob(k) * wr[k].prob(word); + ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0); } const int new_topic = rng.SelectSample(ss); dr[j].increment(new_topic); diff --git a/gi/clda/src/crp.h b/gi/clda/src/crp.h index 865c41ac..b01a7f47 100644 --- a/gi/clda/src/crp.h +++ b/gi/clda/src/crp.h @@ -22,9 +22,6 @@ class CRP { const typename MapType::const_iterator i = counts_.find(dish); if (i == counts_.end()) return 0; else return i->second; } - inline prob_t prob(const DishType& dish) const { - return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_); - } inline prob_t prob(const DishType& dish, const prob_t& p0) const { return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_); } -- cgit v1.2.3