summaryrefslogtreecommitdiff
path: root/gi/clda
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-23 14:49:22 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-23 14:49:22 +0000
commit70ad159e22fc6ea12a5e7b468ab38a93c3ed111f (patch)
tree75a9064e0eae0de81055632320461ef3e3efdb64 /gi/clda
parente011e33a5ba5f1c7ef213009d2eec8a30cacd679 (diff)
fix bugs in CRP
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@10 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/clda')
-rw-r--r--gi/clda/src/clda.cc9
-rw-r--r--gi/clda/src/crp.h3
2 files changed, 6 insertions, 6 deletions
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index 976b020f..482a1c4c 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -79,6 +79,9 @@ int main(int argc, char** argv) {
vector<map<WordID, int> > t2w(num_classes);
Timer timer;
SampleSet ss;
+ const int num_types = TD::dict_.max();
+ const prob_t class_p0(1.0 / num_classes);
+ const prob_t word_p0(1.0 / num_types);
ss.resize(num_classes);
double total_time = 0;
for (int iter = 0; iter < num_iterations; ++iter) {
@@ -94,8 +97,8 @@ int main(int argc, char** argv) {
for (int i = 0; i < num_words; ++i) {
const int word = wj[i];
const int cur_topic = zj[i];
- lh *= dr[j].prob(cur_topic);
- lh *= wr[cur_topic].prob(word);
+ lh *= dr[j].prob(cur_topic, class_p0);
+ lh *= wr[cur_topic].prob(word, word_p0);
if (iter > burnin_size) {
++t2w[cur_topic][word];
}
@@ -115,7 +118,7 @@ int main(int argc, char** argv) {
wr[cur_topic].decrement(word);
for (int k = 0; k < num_classes; ++k) {
- ss[k]= dr[j].prob(k) * wr[k].prob(word);
+ ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0);
}
const int new_topic = rng.SelectSample(ss);
dr[j].increment(new_topic);
diff --git a/gi/clda/src/crp.h b/gi/clda/src/crp.h
index 865c41ac..b01a7f47 100644
--- a/gi/clda/src/crp.h
+++ b/gi/clda/src/crp.h
@@ -22,9 +22,6 @@ class CRP {
const typename MapType::const_iterator i = counts_.find(dish);
if (i == counts_.end()) return 0; else return i->second;
}
- inline prob_t prob(const DishType& dish) const {
- return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_);
- }
inline prob_t prob(const DishType& dish, const prob_t& p0) const {
return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_);
}