summaryrefslogtreecommitdiff
path: root/gi/pf/nuisance_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pf/nuisance_test.cc')
-rw-r--r--gi/pf/nuisance_test.cc161
1 files changed, 161 insertions, 0 deletions
diff --git a/gi/pf/nuisance_test.cc b/gi/pf/nuisance_test.cc
new file mode 100644
index 00000000..fc0af9cb
--- /dev/null
+++ b/gi/pf/nuisance_test.cc
@@ -0,0 +1,161 @@
+#include "ccrp.h"
+
+#include <vector>
+#include <iostream>
+
+#include "tdict.h"
+#include "transliterations.h"
+
+using namespace std;
+
+MT19937 rng;
+
+ostream& operator<<(ostream&os, const vector<int>& v) {
+ os << '[' << v[0];
+ if (v.size() == 2) os << ' ' << v[1];
+ return os << ']';
+}
+
+struct Base {
+ Base() : llh(), v(2), v1(1), v2(1), crp(0.25, 0.5) {}
+ inline double p0(const vector<int>& x) const {
+ double p = 0.75;
+ if (x.size() == 2) p = 0.25;
+ p *= 1.0 / 3.0;
+ if (x.size() == 2) p *= 1.0 / 3.0;
+ return p;
+ }
+ double est_deriv_prob(int a, int b, int seg) const {
+ assert(a > 0 && a < 4); // a \in {1,2,3}
+ assert(b > 0 && b < 4); // b \in {1,2,3}
+ assert(seg == 0 || seg == 1); // seg \in {0,1}
+ if (seg == 0) {
+ v[0] = a;
+ v[1] = b;
+ return crp.prob(v, p0(v));
+ } else {
+ v1[0] = a;
+ v2[0] = b;
+ return crp.prob(v1, p0(v1)) * crp.prob(v2, p0(v2));
+ }
+ }
+ double est_marginal_prob(int a, int b) const {
+ return est_deriv_prob(a,b,0) + est_deriv_prob(a,b,1);
+ }
+ int increment(int a, int b, double* pw = NULL) {
+ double p1 = est_deriv_prob(a, b, 0);
+ double p2 = est_deriv_prob(a, b, 1);
+ //p1 = 0.5; p2 = 0.5;
+ int seg = rng.SelectSample(p1,p2);
+ double tmp = 0;
+ if (!pw) pw = &tmp;
+ double& w = *pw;
+ if (seg == 0) {
+ v[0] = a;
+ v[1] = b;
+ w = crp.prob(v, p0(v)) / p1;
+ if (crp.increment(v, p0(v), &rng)) {
+ llh += log(p0(v));
+ }
+ } else {
+ v1[0] = a;
+ w = crp.prob(v1, p0(v1)) / p2;
+ if (crp.increment(v1, p0(v1), &rng)) {
+ llh += log(p0(v1));
+ }
+ v2[0] = b;
+ w *= crp.prob(v2, p0(v2));
+ if (crp.increment(v2, p0(v2), &rng)) {
+ llh += log(p0(v2));
+ }
+ }
+ return seg;
+ }
+ void increment(int a, int b, int seg) {
+ if (seg == 0) {
+ v[0] = a;
+ v[1] = b;
+ if (crp.increment(v, p0(v), &rng)) {
+ llh += log(p0(v));
+ }
+ } else {
+ v1[0] = a;
+ if (crp.increment(v1, p0(v1), &rng)) {
+ llh += log(p0(v1));
+ }
+ v2[0] = b;
+ if (crp.increment(v2, p0(v2), &rng)) {
+ llh += log(p0(v2));
+ }
+ }
+ }
+ void decrement(int a, int b, int seg) {
+ if (seg == 0) {
+ v[0] = a;
+ v[1] = b;
+ if (crp.decrement(v, &rng)) {
+ llh -= log(p0(v));
+ }
+ } else {
+ v1[0] = a;
+ if (crp.decrement(v1, &rng)) {
+ llh -= log(p0(v1));
+ }
+ v2[0] = b;
+ if (crp.decrement(v2, &rng)) {
+ llh -= log(p0(v2));
+ }
+ }
+ }
+ double log_likelihood() const {
+ return llh + crp.log_crp_prob();
+ }
+ double llh;
+ mutable vector<int> v, v1, v2;
+ CCRP<vector<int> > crp;
+};
+
+int main(int argc, char** argv) {
+ double tl = 0;
+ const int ITERS = 1000;
+ const int PARTICLES = 20;
+ const int DATAPOINTS = 50;
+ WordID x = TD::Convert("souvenons");
+ WordID y = TD::Convert("remember");
+ vector<WordID> src; TD::ConvertSentence("s o u v e n o n s", &src);
+ vector<WordID> trg; TD::ConvertSentence("r e m e m b e r", &trg);
+// Transliterations xx;
+// xx.Initialize(x, src, y, trg);
+// return 1;
+
+ for (int j = 0; j < ITERS; ++j) {
+ Base b;
+ vector<int> segs(DATAPOINTS);
+ SampleSet<double> ss;
+ vector<int> sss;
+ for (int i = 0; i < DATAPOINTS; i++) {
+ ss.clear();
+ sss.clear();
+ int x = ((i / 10) % 3) + 1;
+ int y = (i % 3) + 1;
+ //double ep = b.est_marginal_prob(x,y);
+ //cerr << "est p(" << x << "," << y << ") = " << ep << endl;
+ for (int n = 0; n < PARTICLES; ++n) {
+ double w;
+ int seg = b.increment(x,y,&w);
+ //cerr << seg << " w=" << w << endl;
+ ss.add(w);
+ sss.push_back(seg);
+ b.decrement(x,y,seg);
+ }
+ int seg = sss[rng.SelectSample(ss)];
+ b.increment(x, y, seg);
+ //cerr << "Selected: " << seg << endl;
+ //return 1;
+ segs[i] = seg;
+ }
+ tl += b.log_likelihood();
+ }
+ cerr << "LLH=" << tl / ITERS << endl;
+}
+