summaryrefslogtreecommitdiff
path: root/rst_parser/arc_factored_marginals.cc
blob: 16360b0d72db82c27a0f7191a7101b9f9e8c227a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include "arc_factored.h"

#include <iostream>

#include "config.h"

using namespace std;

#if HAVE_EIGEN

#include <Eigen/Dense>
typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix;
typedef Eigen::Matrix<double, Eigen::Dynamic, 1> RootVector;

void ArcFactoredForest::EdgeMarginals(double *plog_z) {
  ArcMatrix A(num_words_,num_words_);
  RootVector r(num_words_);
  for (int h = 0; h < num_words_; ++h) {
    for (int m = 0; m < num_words_; ++m) {
      if (h != m)
        A(h,m) = edges_(h,m).edge_prob.as_float();
      else
        A(h,m) = 0;
    }
    r(h) = root_edges_[h].edge_prob.as_float();
  }

  ArcMatrix L = -A;
  L.diagonal() = A.colwise().sum();
  L.row(0) = r;
  ArcMatrix Linv = L.inverse();
  if (plog_z) *plog_z = log(Linv.determinant());
  RootVector rootMarginals = r.cwiseProduct(Linv.col(0));
//  ArcMatrix T = Linv;
  for (int h = 0; h < num_words_; ++h) {
    for (int m = 0; m < num_words_; ++m) {
      const double marginal = (m == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,m) -
                              (h == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,h);
      edges_(h,m).edge_prob = prob_t(marginal);
//      T(h,m) = marginal;
    }
    root_edges_[h].edge_prob = prob_t(rootMarginals(h));
  }
//   cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl;
//  cerr << "M:\n" << T << endl;
}

#else

// Fallback for builds configured without Eigen. No linear-algebra backend is
// available for the matrix-tree computation, so this reports the missing
// build option and terminates rather than leaving the weights unchanged.
void ArcFactoredForest::EdgeMarginals(double* /*plog_z: unused*/) {
  std::cerr << "EdgeMarginals() requires --with-eigen!\n";
  abort();
}

#endif