summaryrefslogtreecommitdiff
path: root/gi/clda/src/clda.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/clda/src/clda.cc')
-rw-r--r--gi/clda/src/clda.cc11
1 files changed, 8 insertions, 3 deletions
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index ea224a27..0232331b 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -53,7 +53,7 @@ void tc() {
tt += crp.decrement("bar", &rng);
cout << crp << endl;
cout << "tt=" << tt << endl;
- cout << crp.llh() << endl;
+ cout << crp.log_crp_prob() << endl;
}
int main(int argc, char** argv) {
@@ -64,7 +64,7 @@ int main(int argc, char** argv) {
}
const int num_classes = atoi(argv[1]);
const int num_iterations = atoi(argv[2]);
- const int burnin_size = num_iterations * 0.666;
+ const int burnin_size = num_iterations * 0.9;
if (num_classes < 2) {
cerr << "Must request more than 1 class\n";
return 1;
@@ -113,7 +113,6 @@ int main(int argc, char** argv) {
vector<map<WordID, int> > t2w(num_classes);
Timer timer;
SampleSet<double> ss;
- const int num_types = TD::NumWords();
ss.resize(num_classes);
double total_time = 0;
for (int iter = 0; iter < num_iterations; ++iter) {
@@ -121,6 +120,12 @@ int main(int argc, char** argv) {
if (iter && iter % 10 == 0) {
total_time += timer.Elapsed();
timer.Reset();
+ double llh = 0;
+ for (int j = 0; j < dr.size(); ++j)
+ llh += dr[j].log_crp_prob();
+ for (int j = 0; j < wr.size(); ++j)
+ llh += wr[j].log_crp_prob();
+ cerr << " [LLH=" << llh << " I=" << iter << "]\n";
}
for (int j = 0; j < zji.size(); ++j) {
const size_t num_words = wji[j].size();