diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-23 14:49:22 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-23 14:49:22 +0000 |
commit | 70ad159e22fc6ea12a5e7b468ab38a93c3ed111f (patch) | |
tree | 75a9064e0eae0de81055632320461ef3e3efdb64 | |
parent | e011e33a5ba5f1c7ef213009d2eec8a30cacd679 (diff) |
fix bugs in CRP
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@10 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | gi/clda/src/clda.cc | 9 | ||||
-rw-r--r-- | gi/clda/src/crp.h | 3 |
2 files changed, 6 insertions, 6 deletions
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc index 976b020f..482a1c4c 100644 --- a/gi/clda/src/clda.cc +++ b/gi/clda/src/clda.cc @@ -79,6 +79,9 @@ int main(int argc, char** argv) { vector<map<WordID, int> > t2w(num_classes); Timer timer; SampleSet ss; + const int num_types = TD::dict_.max(); + const prob_t class_p0(1.0 / num_classes); + const prob_t word_p0(1.0 / num_types); ss.resize(num_classes); double total_time = 0; for (int iter = 0; iter < num_iterations; ++iter) { @@ -94,8 +97,8 @@ int main(int argc, char** argv) { for (int i = 0; i < num_words; ++i) { const int word = wj[i]; const int cur_topic = zj[i]; - lh *= dr[j].prob(cur_topic); - lh *= wr[cur_topic].prob(word); + lh *= dr[j].prob(cur_topic, class_p0); + lh *= wr[cur_topic].prob(word, word_p0); if (iter > burnin_size) { ++t2w[cur_topic][word]; } @@ -115,7 +118,7 @@ int main(int argc, char** argv) { wr[cur_topic].decrement(word); for (int k = 0; k < num_classes; ++k) { - ss[k]= dr[j].prob(k) * wr[k].prob(word); + ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0); } const int new_topic = rng.SelectSample(ss); dr[j].increment(new_topic); diff --git a/gi/clda/src/crp.h b/gi/clda/src/crp.h index 865c41ac..b01a7f47 100644 --- a/gi/clda/src/crp.h +++ b/gi/clda/src/crp.h @@ -22,9 +22,6 @@ class CRP { const typename MapType::const_iterator i = counts_.find(dish); if (i == counts_.end()) return 0; else return i->second; } - inline prob_t prob(const DishType& dish) const { - return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_); - } inline prob_t prob(const DishType& dish, const prob_t& p0) const { return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_); } |