diff options
Diffstat (limited to 'gi/pf/nuisance_test.cc')
-rw-r--r-- | gi/pf/nuisance_test.cc | 161 |
1 files changed, 161 insertions, 0 deletions
diff --git a/gi/pf/nuisance_test.cc b/gi/pf/nuisance_test.cc new file mode 100644 index 00000000..0f44fe95 --- /dev/null +++ b/gi/pf/nuisance_test.cc @@ -0,0 +1,161 @@ +#include "ccrp.h" + +#include <vector> +#include <iostream> + +#include "tdict.h" +#include "transliterations.h" + +using namespace std; + +MT19937 rng; + +ostream& operator<<(ostream&os, const vector<int>& v) { + os << '[' << v[0]; + if (v.size() == 2) os << ' ' << v[1]; + return os << ']'; +} + +struct Base { + Base() : llh(), v(2), v1(1), v2(1), crp(0.25, 0.5) {} + inline double p0(const vector<int>& x) const { + double p = 0.75; + if (x.size() == 2) p = 0.25; + p *= 1.0 / 3.0; + if (x.size() == 2) p *= 1.0 / 3.0; + return p; + } + double est_deriv_prob(int a, int b, int seg) const { + assert(a > 0 && a < 4); // a \in {1,2,3} + assert(b > 0 && b < 4); // b \in {1,2,3} + assert(seg == 0 || seg == 1); // seg \in {0,1} + if (seg == 0) { + v[0] = a; + v[1] = b; + return crp.prob(v, p0(v)); + } else { + v1[0] = a; + v2[0] = b; + return crp.prob(v1, p0(v1)) * crp.prob(v2, p0(v2)); + } + } + double est_marginal_prob(int a, int b) const { + return est_deriv_prob(a,b,0) + est_deriv_prob(a,b,1); + } + int increment(int a, int b, double* pw = NULL) { + double p1 = est_deriv_prob(a, b, 0); + double p2 = est_deriv_prob(a, b, 1); + //p1 = 0.5; p2 = 0.5; + int seg = rng.SelectSample(p1,p2); + double tmp = 0; + if (!pw) pw = &tmp; + double& w = *pw; + if (seg == 0) { + v[0] = a; + v[1] = b; + w = crp.prob(v, p0(v)) / p1; + if (crp.increment(v, p0(v), &rng)) { + llh += log(p0(v)); + } + } else { + v1[0] = a; + w = crp.prob(v1, p0(v1)) / p2; + if (crp.increment(v1, p0(v1), &rng)) { + llh += log(p0(v1)); + } + v2[0] = b; + w *= crp.prob(v2, p0(v2)); + if (crp.increment(v2, p0(v2), &rng)) { + llh += log(p0(v2)); + } + } + return seg; + } + void increment(int a, int b, int seg) { + if (seg == 0) { + v[0] = a; + v[1] = b; + if (crp.increment(v, p0(v), &rng)) { + llh += log(p0(v)); + } + } else { + v1[0] = a; + if (crp.increment(v1, p0(v1), &rng)) { + llh += log(p0(v1)); + } + v2[0] = b; + if (crp.increment(v2, p0(v2), &rng)) { + llh += log(p0(v2)); + } + } + } + void decrement(int a, int b, int seg) { + if (seg == 0) { + v[0] = a; + v[1] = b; + if (crp.decrement(v, &rng)) { + llh -= log(p0(v)); + } + } else { + v1[0] = a; + if (crp.decrement(v1, &rng)) { + llh -= log(p0(v1)); + } + v2[0] = b; + if (crp.decrement(v2, &rng)) { + llh -= log(p0(v2)); + } + } + } + double log_likelihood() const { + return llh + crp.log_crp_prob(); + } + double llh; + mutable vector<int> v, v1, v2; + CCRP<vector<int> > crp; +}; + +int main(int argc, char** argv) { + double tl = 0; + const int ITERS = 1000; + const int PARTICLES = 20; + const int DATAPOINTS = 50; + WordID x = TD::Convert("souvenons"); + WordID y = TD::Convert("remember"); + vector<WordID> src; TD::ConvertSentence("s o u v e n o n s", &src); + vector<WordID> trg; TD::ConvertSentence("r e m e m b e r", &trg); + Transliterations xx; + xx.Initialize(x, src, y, trg); + return 1; + + for (int j = 0; j < ITERS; ++j) { + Base b; + vector<int> segs(DATAPOINTS); + SampleSet<double> ss; + vector<int> sss; + for (int i = 0; i < DATAPOINTS; i++) { + ss.clear(); + sss.clear(); + int x = ((i / 10) % 3) + 1; + int y = (i % 3) + 1; + //double ep = b.est_marginal_prob(x,y); + //cerr << "est p(" << x << "," << y << ") = " << ep << endl; + for (int n = 0; n < PARTICLES; ++n) { + double w; + int seg = b.increment(x,y,&w); + //cerr << seg << " w=" << w << endl; + ss.add(w); + sss.push_back(seg); + b.decrement(x,y,seg); + } + int seg = sss[rng.SelectSample(ss)]; + b.increment(x, y, seg); + //cerr << "Selected: " << seg << endl; + //return 1; + segs[i] = seg; + } + tl += b.log_likelihood(); + } + cerr << "LLH=" << tl / ITERS << endl; +} + |