summaryrefslogtreecommitdiff
path: root/rst_parser/arc_factored_marginals.cc
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/arc_factored_marginals.cc')
-rw-r--r--rst_parser/arc_factored_marginals.cc24
1 files changed, 13 insertions, 11 deletions
diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc
index 16360b0d..acb8102a 100644
--- a/rst_parser/arc_factored_marginals.cc
+++ b/rst_parser/arc_factored_marginals.cc
@@ -9,37 +9,39 @@ using namespace std;
#if HAVE_EIGEN
#include <Eigen/Dense>
-typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix;
-typedef Eigen::Matrix<double, Eigen::Dynamic, 1> RootVector;
+typedef Eigen::Matrix<prob_t, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix;
+typedef Eigen::Matrix<prob_t, Eigen::Dynamic, 1> RootVector;
-void ArcFactoredForest::EdgeMarginals(double *plog_z) {
+void ArcFactoredForest::EdgeMarginals(prob_t *plog_z) {
ArcMatrix A(num_words_,num_words_);
RootVector r(num_words_);
for (int h = 0; h < num_words_; ++h) {
for (int m = 0; m < num_words_; ++m) {
if (h != m)
- A(h,m) = edges_(h,m).edge_prob.as_float();
+ A(h,m) = edges_(h,m).edge_prob;
else
- A(h,m) = 0;
+ A(h,m) = prob_t::Zero();
}
- r(h) = root_edges_[h].edge_prob.as_float();
+ r(h) = root_edges_[h].edge_prob;
}
ArcMatrix L = -A;
L.diagonal() = A.colwise().sum();
L.row(0) = r;
ArcMatrix Linv = L.inverse();
- if (plog_z) *plog_z = log(Linv.determinant());
+ if (plog_z) *plog_z = Linv.determinant();
RootVector rootMarginals = r.cwiseProduct(Linv.col(0));
+ static const prob_t ZERO(0);
+ static const prob_t ONE(1);
// ArcMatrix T = Linv;
for (int h = 0; h < num_words_; ++h) {
for (int m = 0; m < num_words_; ++m) {
- const double marginal = (m == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,m) -
- (h == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,h);
- edges_(h,m).edge_prob = prob_t(marginal);
+ const prob_t marginal = (m == 0 ? ZERO : ONE) * A(h,m) * Linv(m,m) -
+ (h == 0 ? ZERO : ONE) * A(h,m) * Linv(m,h);
+ edges_(h,m).edge_prob = marginal;
// T(h,m) = marginal;
}
- root_edges_[h].edge_prob = prob_t(rootMarginals(h));
+ root_edges_[h].edge_prob = rootMarginals(h);
}
// cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl;
// cerr << "M:\n" << T << endl;