summaryrefslogtreecommitdiff
path: root/gi/pf/mh_test.cc
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-10-22 12:07:20 +0100
committerKenneth Heafield <github@kheafield.com>2012-10-22 12:07:20 +0100
commit5f98fe5c4f2a2090eeb9d30c030305a70a8347d1 (patch)
tree9b6002f850e6dea1e3400c6b19bb31a9cdf3067f /gi/pf/mh_test.cc
parentcf9994131993b40be62e90e213b1e11e6b550143 (diff)
parent21825a09d97c2e0afd20512f306fb25fed55e529 (diff)
Merge remote branch 'upstream/master'
Conflicts: Jamroot bjam decoder/Jamfile decoder/cdec.cc dpmert/Jamfile jam-files/sanity.jam klm/lm/Jamfile klm/util/Jamfile mira/Jamfile
Diffstat (limited to 'gi/pf/mh_test.cc')
-rw-r--r--gi/pf/mh_test.cc148
1 files changed, 0 insertions, 148 deletions
diff --git a/gi/pf/mh_test.cc b/gi/pf/mh_test.cc
deleted file mode 100644
index 296e7285..00000000
--- a/gi/pf/mh_test.cc
+++ /dev/null
@@ -1,148 +0,0 @@
-#include "ccrp.h"
-
-#include <vector>
-#include <iostream>
-
-#include "tdict.h"
-#include "transliterations.h"
-
-using namespace std;
-
-MT19937 rng;
-
-static bool verbose = false;
-
-struct Model {
-
- Model() : bp(), base(0.2, 0.6) , ccrps(5, CCRP<int>(0.8, 0.5)) {}
-
- double p0(int x) const {
- assert(x > 0);
- assert(x < 5);
- return 1.0/4.0;
- }
-
- double llh() const {
- double lh = bp + base.log_crp_prob();
- for (int ctx = 1; ctx < 5; ++ctx)
- lh += ccrps[ctx].log_crp_prob();
- return lh;
- }
-
- double prob(int ctx, int x) const {
- assert(ctx > 0 && ctx < 5);
- return ccrps[ctx].prob(x, base.prob(x, p0(x)));
- }
-
- void increment(int ctx, int x) {
- assert(ctx > 0 && ctx < 5);
- if (ccrps[ctx].increment(x, base.prob(x, p0(x)), &rng)) {
- if (base.increment(x, p0(x), &rng)) {
- bp += log(1.0 / 4.0);
- }
- }
- }
-
- // this is just a biased estimate
- double est_base_prob(int x) {
- return (x + 1) * x / 40.0;
- }
-
- void increment_is(int ctx, int x) {
- assert(ctx > 0 && ctx < 5);
- SampleSet<double> ss;
- const int PARTICLES = 25;
- vector<CCRP<int> > s1s(PARTICLES, CCRP<int>(0.5,0.5));
- vector<CCRP<int> > sbs(PARTICLES, CCRP<int>(0.5,0.5));
- vector<double> sp0s(PARTICLES);
-
- CCRP<int> s1 = ccrps[ctx];
- CCRP<int> sb = base;
- double sp0 = bp;
- for (int pp = 0; pp < PARTICLES; ++pp) {
- if (pp > 0) {
- ccrps[ctx] = s1;
- base = sb;
- bp = sp0;
- }
-
- double q = 1;
- double gamma = 1;
- double est_p = est_base_prob(x);
- //base.prob(x, p0(x)) + rng.next() * 0.1;
- if (ccrps[ctx].increment(x, est_p, &rng, &q)) {
- gamma = q * base.prob(x, p0(x));
- q *= est_p;
- if (verbose) cerr << "(DP-base draw) ";
- double qq = -1;
- if (base.increment(x, p0(x), &rng, &qq)) {
- if (verbose) cerr << "(G0 draw) ";
- bp += log(p0(x));
- qq *= p0(x);
- }
- } else { gamma = q; }
- double w = gamma / q;
- if (verbose)
- cerr << "gamma=" << gamma << " q=" << q << "\tw=" << w << endl;
- ss.add(w);
- s1s[pp] = ccrps[ctx];
- sbs[pp] = base;
- sp0s[pp] = bp;
- }
- int ps = rng.SelectSample(ss);
- ccrps[ctx] = s1s[ps];
- base = sbs[ps];
- bp = sp0s[ps];
- if (verbose) {
- cerr << "SELECTED: " << ps << endl;
- static int cc = 0; cc++; if (cc ==10) exit(1);
- }
- }
-
- void decrement(int ctx, int x) {
- assert(ctx > 0 && ctx < 5);
- if (ccrps[ctx].decrement(x, &rng)) {
- if (base.decrement(x, &rng)) {
- bp -= log(p0(x));
- }
- }
- }
-
- double bp;
- CCRP<int> base;
- vector<CCRP<int> > ccrps;
-
-};
-
-int main(int argc, char** argv) {
- if (argc > 1) { verbose = true; }
- vector<int> counts(15, 0);
- vector<int> tcounts(15, 0);
- int points[] = {1,2, 2,2, 3,2, 4,1, 3, 4, 3, 3, 2, 3, 4, 1, 4, 1, 3, 2, 1, 3, 1, 4, 0, 0};
- double tlh = 0;
- double tt = 0;
- for (int n = 0; n < 1000; ++n) {
- if (n % 10 == 0) cerr << '.';
- if ((n+1) % 400 == 0) cerr << " [" << (n+1) << "]\n";
- Model m;
- for (int *x = points; *x; x += 2)
- m.increment(x[0], x[1]);
-
- for (int j = 0; j < 24; ++j) {
- for (int *x = points; *x; x += 2) {
- if (rng.next() < 0.8) {
- m.decrement(x[0], x[1]);
- m.increment_is(x[0], x[1]);
- }
- }
- }
- counts[m.base.num_customers()]++;
- tcounts[m.base.num_tables()]++;
- tlh += m.llh();
- tt += 1.0;
- }
- cerr << "mean LLH = " << (tlh / tt) << endl;
- for (int i = 0; i < 15; ++i)
- cerr << i << ": " << (counts[i] / tt) << "\t" << (tcounts[i] / tt) << endl;
-}
-