summaryrefslogtreecommitdiff
path: root/gi/clda/src/clda.cc
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-25 16:16:27 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-25 16:16:27 +0000
commitaa2a66378777d3825cc31640dff2e206491ccee3 (patch)
treedeed44f7528e70ed7710627e31f58ce50adb978a /gi/clda/src/clda.cc
parenteb7dd39a10474a58fda0ab237e92b30f9b352d55 (diff)
unit tests for PYP sampler
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@620 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/clda/src/clda.cc')
-rw-r--r--gi/clda/src/clda.cc38
1 files changed, 3 insertions, 35 deletions
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index 0232331b..10056bc9 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -25,39 +25,7 @@ void ShowTopWordsForTopic(const map<WordID, int>& counts) {
cerr << endl;
}
-void tc() {
- MT19937 rng;
- CCRP<string> crp(0.1, 5);
- double un = 0.25;
- int tt = 0;
- tt += crp.increment("hi", un, &rng);
- tt += crp.increment("foo", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- tt += crp.increment("bar", un, &rng);
- cout << "tt=" << tt << endl;
- cout << crp << endl;
- cout << " P(bar)=" << crp.prob("bar", un) << endl;
- cout << " P(hi)=" << crp.prob("hi", un) << endl;
- cout << " P(baz)=" << crp.prob("baz", un) << endl;
- cout << " P(foo)=" << crp.prob("foo", un) << endl;
- double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un);
- cout << " tot=" << x << endl;
- tt += crp.decrement("hi", &rng);
- tt += crp.decrement("bar", &rng);
- cout << crp << endl;
- tt += crp.decrement("bar", &rng);
- cout << crp << endl;
- cout << "tt=" << tt << endl;
- cout << crp.log_crp_prob() << endl;
-}
-
int main(int argc, char** argv) {
- tc();
if (argc != 3) {
cerr << "Usage: " << argv[0] << " num-classes num-samples\n";
return 1;
@@ -88,11 +56,11 @@ int main(int argc, char** argv) {
MT19937 rng;
cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n";
zji.resize(wji.size());
- double disc = 0.05;
+ double disc = 0.1;
double beta = 10.0;
double alpha = 50.0;
- double uniform_topic = 1.0 / num_classes;
- double uniform_word = 1.0 / TD::NumWords();
+ const double uniform_topic = 1.0 / num_classes;
+ const double uniform_word = 1.0 / TD::NumWords();
vector<CCRP<int> > dr(zji.size(), CCRP<int>(disc, beta)); // dr[i] describes the probability of using a topic in document i
vector<CCRP<int> > wr(num_classes, CCRP<int>(disc, alpha)); // wr[k] describes the probability of generating a word in topic k
for (int j = 0; j < zji.size(); ++j) {