diff options
Diffstat (limited to 'rst_parser/arc_factored_marginals.cc')
| -rw-r--r-- | rst_parser/arc_factored_marginals.cc | 52 | 
1 files changed, 52 insertions, 0 deletions
| diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc new file mode 100644 index 00000000..9851b59a --- /dev/null +++ b/rst_parser/arc_factored_marginals.cc @@ -0,0 +1,52 @@ +#include "arc_factored.h" + +#include <iostream> + +#include "config.h" + +using namespace std; + +#if HAVE_EIGEN + +#include <Eigen/Dense> +typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix; +typedef Eigen::Matrix<double, Eigen::Dynamic, 1> RootVector; + +void ArcFactoredForest::EdgeMarginals(double *plog_z) { +  ArcMatrix A(num_words_,num_words_); +  RootVector r(num_words_); +  for (int h = 0; h < num_words_; ++h) { +    for (int m = 0; m < num_words_; ++m) { +      if (h != m) +        A(h,m) = edges_(h,m).edge_prob.as_float(); +      else +        A(h,m) = 0; +    } +    r(h) = root_edges_[h].edge_prob.as_float(); +  } + +  ArcMatrix L = -A; +  L.diagonal() = A.colwise().sum(); +  L.row(0) = r; +  ArcMatrix Linv = L.inverse(); +  if (plog_z) *plog_z = log(Linv.determinant()); +  RootVector rootMarginals = r.cwiseProduct(Linv.col(0)); +  for (int h = 0; h < num_words_; ++h) { +    for (int m = 0; m < num_words_; ++m) { +      edges_(h,m).edge_prob = prob_t((m == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,m) - +                                     (h == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,h)); +    } +    root_edges_[h].edge_prob = prob_t(rootMarginals(h)); +  } +  // cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl; +} + +#else + +void ArcFactoredForest::EdgeMarginals(double*) { +  cerr << "EdgeMarginals() requires --with-eigen!\n"; +  abort(); +} + +#endif + | 
