From eb7dd39a10474a58fda0ab237e92b30f9b352d55 Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 25 Aug 2010 04:03:35 +0000 Subject: compute prob of crp git-svn-id: https://ws10smt.googlecode.com/svn/trunk@619 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/clda/src/clda.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'gi/clda/src/clda.cc') diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc index ea224a27..0232331b 100644 --- a/gi/clda/src/clda.cc +++ b/gi/clda/src/clda.cc @@ -53,7 +53,7 @@ void tc() { tt += crp.decrement("bar", &rng); cout << crp << endl; cout << "tt=" << tt << endl; - cout << crp.llh() << endl; + cout << crp.log_crp_prob() << endl; } int main(int argc, char** argv) { @@ -64,7 +64,7 @@ int main(int argc, char** argv) { } const int num_classes = atoi(argv[1]); const int num_iterations = atoi(argv[2]); - const int burnin_size = num_iterations * 0.666; + const int burnin_size = num_iterations * 0.9; if (num_classes < 2) { cerr << "Must request more than 1 class\n"; return 1; @@ -113,7 +113,6 @@ int main(int argc, char** argv) { vector > t2w(num_classes); Timer timer; SampleSet ss; - const int num_types = TD::NumWords(); ss.resize(num_classes); double total_time = 0; for (int iter = 0; iter < num_iterations; ++iter) { @@ -121,6 +120,12 @@ int main(int argc, char** argv) { if (iter && iter % 10 == 0) { total_time += timer.Elapsed(); timer.Reset(); + double llh = 0; + for (int j = 0; j < dr.size(); ++j) + llh += dr[j].log_crp_prob(); + for (int j = 0; j < wr.size(); ++j) + llh += wr[j].log_crp_prob(); + cerr << " [LLH=" << llh << " I=" << iter << "]\n"; } for (int j = 0; j < zji.size(); ++j) { const size_t num_words = wji[j].size(); -- cgit v1.2.3