#include "ccrp.h"

#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>

#include "tdict.h"
#include "transliterations.h"

using namespace std;

// Random number generator shared by every sampling call in this file.
MT19937 rng;

// Prints a unit as its elements in brackets, space separated, e.g. "[2]" or
// "[1 3]". An empty vector prints as "[]".
ostream& operator<<(ostream& os, const vector<int>& v) {
  os << '[';
  for (size_t i = 0; i < v.size(); ++i) {
    if (i) os << ' ';
    os << v[i];
  }
  return os << ']';
}

// Small synthetic segmentation model.
// Each observation is a pair (a, b) with a, b in {1, 2, 3}. It is generated
// either as one joint unit [a b] (seg == 0) or as two singleton units [a]
// and [b] (seg == 1). All units share one CCRP, `crp`, whose base
// distribution is p0().
// `llh` accumulates log p0(x) for every unit x that crp.increment() reports
// as new (presumably a newly opened table -- confirm in ccrp.h).
// log_likelihood() adds crp.log_crp_prob() to it.
struct Base {
  Base() : llh(), v(2), v1(1), v2(1), crp(0.25, 0.5) {}

  // Base distribution over units.
  // A joint unit (size 2) gets 0.25 * (1/3) * (1/3) = 1/36. Any other size
  // is treated as a singleton and gets 0.75 * (1/3) = 1/4. The 9 pairs plus
  // the 3 singletons therefore sum to 1.
  inline double p0(const vector<int>& x) const {
    const bool joint = (x.size() == 2);
    double p = joint ? 0.25 : 0.75;
    p *= 1.0 / 3.0;
    if (joint) p *= 1.0 / 3.0;
    return p;
  }

  // Predictive probability of generating (a, b) with segmentation `seg`
  // under the current CRP state; the CRP is not modified.
  // For seg == 1 the two singleton probabilities are multiplied using the
  // same state, i.e. the effect of seating [a] on the probability of [b] is
  // ignored. increment() corrects for this in its importance weight.
  double est_deriv_prob(int a, int b, int seg) const {
    assert(a > 0 && a < 4);  // a \in {1,2,3}
    assert(b > 0 && b < 4);  // b \in {1,2,3}
    assert(seg == 0 || seg == 1);  // seg \in {0,1}
    if (seg == 0) {
      v[0] = a;
      v[1] = b;
      return crp.prob(v, p0(v));
    }
    v1[0] = a;
    v2[0] = b;
    return crp.prob(v1, p0(v1)) * crp.prob(v2, p0(v2));
  }

  // Sum of the two segmentation estimates for (a, b).
  double est_marginal_prob(int a, int b) const {
    return est_deriv_prob(a, b, 0) + est_deriv_prob(a, b, 1);
  }

  // Samples a segmentation for (a, b), seats the chosen unit(s) in the CRP,
  // and returns the segmentation (0 or 1). The choice is drawn with
  // rng.SelectSample() from the two est_deriv_prob() estimates.
  // If pw is non-NULL, *pw receives the importance weight: the sequential
  // probability of the seated units divided by the estimate used to choose
  // them. This is 1 for seg == 0.
  int increment(int a, int b, double* pw = NULL) {
    const double p1 = est_deriv_prob(a, b, 0);
    const double p2 = est_deriv_prob(a, b, 1);
    //p1 = 0.5; p2 = 0.5;
    const int seg = rng.SelectSample(p1, p2);
    double w;
    if (seg == 0) {
      v[0] = a;
      v[1] = b;
      w = crp.prob(v, p0(v)) / p1;
      seat(v);
    } else {
      v1[0] = a;
      w = crp.prob(v1, p0(v1)) / p2;
      seat(v1);
      v2[0] = b;
      // [b] is scored after [a] has been seated, so this factor can differ
      // from the one used inside the estimate p2.
      w *= crp.prob(v2, p0(v2));
      seat(v2);
    }
    if (pw) *pw = w;
    return seg;
  }

  // Seats the unit(s) of (a, b) for a known segmentation `seg`.
  void increment(int a, int b, int seg) {
    if (seg == 0) {
      v[0] = a;
      v[1] = b;
      seat(v);
    } else {
      v1[0] = a;
      seat(v1);
      v2[0] = b;
      seat(v2);
    }
  }

  // Removes the unit(s) of (a, b) for segmentation `seg`.
  // This is the inverse of increment(a, b, seg).
  void decrement(int a, int b, int seg) {
    if (seg == 0) {
      v[0] = a;
      v[1] = b;
      unseat(v);
    } else {
      v1[0] = a;
      unseat(v1);
      v2[0] = b;
      unseat(v2);
    }
  }

  // llh (accumulated log p0 terms) plus the CRP's own log probability.
  double log_likelihood() const {
    return llh + crp.log_crp_prob();
  }

  double llh;
  mutable vector<int> v, v1, v2;  // scratch units, reused to avoid allocation
  CCRP<vector<int> > crp;

 private:
  // Adds one customer for unit x; when crp.increment() reports a new
  // table, adds log p0(x) to llh.
  void seat(const vector<int>& x) {
    const double p = p0(x);
    if (crp.increment(x, p, &rng)) {
      llh += log(p);
    }
  }

  // Removes one customer for unit x; when crp.decrement() reports a
  // removed table, subtracts log p0(x) from llh.
  void unseat(const vector<int>& x) {
    if (crp.decrement(x, &rng)) {
      llh -= log(p0(x));
    }
  }
};

// Runs ITERS independent passes, each with a fresh model, over DATAPOINTS
// synthetic pairs.
// For each pair, PARTICLES segmentations are proposed from the current
// state. Each proposal is seated and immediately removed, so all proposals
// see the same state. One proposal is then picked with rng.SelectSample()
// over their importance weights and kept.
// Prints the average final log likelihood over all passes.
int main(int argc, char** argv) {
  double tl = 0;
  const int ITERS = 1000;
  const int PARTICLES = 20;
  const int DATAPOINTS = 50;

  // Inputs for the (currently disabled) Transliterations experiment below.
  WordID x = TD::Convert("souvenons");
  WordID y = TD::Convert("remember");
  vector<WordID> src;
  TD::ConvertSentence("s o u v e n o n s", &src);
  vector<WordID> trg;
  TD::ConvertSentence("r e m e m b e r", &trg);
  // Transliterations xx;
  // xx.Initialize(x, src, y, trg);
  // return 1;

  // Proposal weights and the matching segmentations; cleared per data point.
  SampleSet<double> ss;
  vector<int> sss;
  for (int j = 0; j < ITERS; ++j) {
    Base b;
    for (int i = 0; i < DATAPOINTS; i++) {
      ss.clear();
      sss.clear();
      const int xi = ((i / 10) % 3) + 1;
      const int yi = (i % 3) + 1;
      //double ep = b.est_marginal_prob(xi,yi);
      //cerr << "est p(" << xi << "," << yi << ") = " << ep << endl;
      for (int n = 0; n < PARTICLES; ++n) {
        double w;
        const int seg = b.increment(xi, yi, &w);
        //cerr << seg << " w=" << w << endl;
        ss.add(w);
        sss.push_back(seg);
        b.decrement(xi, yi, seg);
      }
      const int seg = sss[rng.SelectSample(ss)];
      b.increment(xi, yi, seg);
      //cerr << "Selected: " << seg << endl;
    }
    tl += b.log_likelihood();
  }
  cerr << "LLH=" << tl / ITERS << endl;
}