summaryrefslogtreecommitdiff
path: root/gi/pf/quasi_model2.h
blob: 588c8f84235aeb86b6be61bd3e19287d00012452 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#ifndef _QUASI_MODEL2_H_
#define _QUASI_MODEL2_H_

#include <vector>
#include <cmath>
#include <tr1/unordered_map>
#include "boost/functional.hpp"
#include "prob.h"
#include "array2d.h"
#include "slice_sampler.h"
#include "m.h"

struct AlignmentObservation {
  AlignmentObservation() : src_len(), trg_len(), j(), a_j() {}
  AlignmentObservation(unsigned sl, unsigned tl, unsigned tw, unsigned sw) :
      src_len(sl), trg_len(tl), j(tw), a_j(sw) {}
  unsigned short src_len;
  unsigned short trg_len;
  unsigned short j;
  unsigned short a_j;
};

inline size_t hash_value(const AlignmentObservation& o) {
  return reinterpret_cast<const size_t&>(o);
}

inline bool operator==(const AlignmentObservation& a, const AlignmentObservation& b) {
  return hash_value(a) == hash_value(b);
}

struct QuasiModel2 {
  explicit QuasiModel2(double alpha, double pnull = 0.1) :
      alpha_(alpha),
      pnull_(pnull),
      pnotnull_(1 - pnull) {}

  // a_j = 0 => NULL; src_len does *not* include null
  prob_t Prob(unsigned a_j, unsigned j, unsigned src_len, unsigned trg_len) const {
    if (!a_j) return pnull_;
    return pnotnull_ *
       prob_t(UnnormalizedProb(a_j, j, src_len, trg_len, alpha_) / GetOrComputeZ(j, src_len, trg_len));
  }

  void Increment(unsigned a_j, unsigned j, unsigned src_len, unsigned trg_len) {
    assert(a_j <= src_len);
    assert(j < trg_len);
    ++obs_[AlignmentObservation(src_len, trg_len, j, a_j)];
  }

  void Decrement(unsigned a_j, unsigned j, unsigned src_len, unsigned trg_len) {
    const AlignmentObservation ao(src_len, trg_len, j, a_j);
    int &cc = obs_[ao];
    assert(cc > 0);
    --cc;
    if (!cc) obs_.erase(ao);
  }

  struct PNullResampler {
    PNullResampler(const QuasiModel2& m) : m_(m) {}
    const QuasiModel2& m_;
    double operator()(const double& proposed_pnull) const {
      return log(m_.Likelihood(m_.alpha_, proposed_pnull));
    }
  };

  struct AlphaResampler {
    AlphaResampler(const QuasiModel2& m) : m_(m) {}
    const QuasiModel2& m_;
    double operator()(const double& proposed_alpha) const {
      return log(m_.Likelihood(proposed_alpha, m_.pnull_.as_float()));
    }
  };

  void ResampleHyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) {
    const PNullResampler dr(*this);
    const AlphaResampler ar(*this);
    for (unsigned i = 0; i < nloop; ++i) {
      double pnull = slice_sampler1d(dr, pnull_.as_float(), *rng, 0.00000001,
                            1.0, 0.0, niterations, 100*niterations);
      pnull_ = prob_t(pnull);
      alpha_ = slice_sampler1d(ar, alpha_, *rng, 0.00000001,
                              std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
    }
    std::cerr << "QuasiModel2(alpha=" << alpha_ << ",p_null="
              << pnull_.as_float() << ") = " << Likelihood() << std::endl;
    zcache_.clear();
  }

  prob_t Likelihood() const {
    return Likelihood(alpha_, pnull_.as_float());
  }

  prob_t Likelihood(double alpha, double ppnull) const {
    const prob_t pnull(ppnull);
    const prob_t pnotnull(1 - ppnull);

    prob_t p;
    p.logeq(Md::log_gamma_density(alpha, 0.1, 25));  // TODO configure
    assert(!p.is_0());
    prob_t prob_of_ppnull; prob_of_ppnull.logeq(Md::log_beta_density(ppnull, 2, 10));
    assert(!prob_of_ppnull.is_0());
    p *= prob_of_ppnull;
    for (ObsCount::const_iterator it = obs_.begin(); it != obs_.end(); ++it) {
      const AlignmentObservation& ao = it->first;
      if (ao.a_j) {
        prob_t u = XUnnormalizedProb(ao.a_j, ao.j, ao.src_len, ao.trg_len, alpha);
        prob_t z = XComputeZ(ao.j, ao.src_len, ao.trg_len, alpha);
        prob_t pa(u / z);
        pa *= pnotnull;
        pa.poweq(it->second);
        p *= pa;
      } else {
        p *= pnull.pow(it->second);
      }
    }
    return p;
  }

 private:
  static prob_t XUnnormalizedProb(unsigned a_j, unsigned j, unsigned src_len, unsigned trg_len, double alpha) {
    prob_t p;
    p.logeq(-fabs(double(a_j - 1) / src_len - double(j) / trg_len) * alpha);
    return p;
  }

  static prob_t XComputeZ(unsigned j, unsigned src_len, unsigned trg_len, double alpha) {
    prob_t z = prob_t::Zero();
    for (int a_j = 1; a_j <= src_len; ++a_j)
      z += XUnnormalizedProb(a_j, j, src_len, trg_len, alpha);
    return z;
  }

  static double UnnormalizedProb(unsigned a_j, unsigned j, unsigned src_len, unsigned trg_len, double alpha) {
    return exp(-fabs(double(a_j - 1) / src_len - double(j) / trg_len) * alpha);
  }

  static double ComputeZ(unsigned j, unsigned src_len, unsigned trg_len, double alpha) {
    double z = 0;
    for (int a_j = 1; a_j <= src_len; ++a_j)
      z += UnnormalizedProb(a_j, j, src_len, trg_len, alpha);
    return z;
  }

  const double& GetOrComputeZ(unsigned j, unsigned src_len, unsigned trg_len) const {
    if (src_len >= zcache_.size())
      zcache_.resize(src_len + 1);
    if (trg_len >= zcache_[src_len].size())
      zcache_[src_len].resize(trg_len + 1);
    std::vector<double>& zv = zcache_[src_len][trg_len];
    if (zv.size() == 0)
      zv.resize(trg_len);
    double& z = zv[j];
    if (!z)
      z = ComputeZ(j, src_len, trg_len, alpha_);
    return z;
  }

  double alpha_;
  prob_t pnull_;
  prob_t pnotnull_;
  mutable std::vector<std::vector<std::vector<double> > > zcache_;
  typedef std::tr1::unordered_map<AlignmentObservation, int, boost::hash<AlignmentObservation> > ObsCount;
  ObsCount obs_;
};

#endif