summaryrefslogtreecommitdiff
path: root/rst_parser/arc_factored_marginals.cc
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/arc_factored_marginals.cc')
-rw-r--r--rst_parser/arc_factored_marginals.cc52
1 files changed, 52 insertions, 0 deletions
diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc
new file mode 100644
index 00000000..9851b59a
--- /dev/null
+++ b/rst_parser/arc_factored_marginals.cc
@@ -0,0 +1,52 @@
+#include "arc_factored.h"
+
+#include <iostream>
+
+#include "config.h"
+
+using namespace std;
+
+#if HAVE_EIGEN
+
+#include <Eigen/Dense>
+typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix;
+typedef Eigen::Matrix<double, Eigen::Dynamic, 1> RootVector;
+
+void ArcFactoredForest::EdgeMarginals(double *plog_z) {
+ ArcMatrix A(num_words_,num_words_);
+ RootVector r(num_words_);
+ for (int h = 0; h < num_words_; ++h) {
+ for (int m = 0; m < num_words_; ++m) {
+ if (h != m)
+ A(h,m) = edges_(h,m).edge_prob.as_float();
+ else
+ A(h,m) = 0;
+ }
+ r(h) = root_edges_[h].edge_prob.as_float();
+ }
+
+ ArcMatrix L = -A;
+ L.diagonal() = A.colwise().sum();
+ L.row(0) = r;
+ ArcMatrix Linv = L.inverse();
+ if (plog_z) *plog_z = log(Linv.determinant());
+ RootVector rootMarginals = r.cwiseProduct(Linv.col(0));
+ for (int h = 0; h < num_words_; ++h) {
+ for (int m = 0; m < num_words_; ++m) {
+ edges_(h,m).edge_prob = prob_t((m == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,m) -
+ (h == 0 ? 0.0 : 1.0) * A(h,m) * Linv(m,h));
+ }
+ root_edges_[h].edge_prob = prob_t(rootMarginals(h));
+ }
+ // cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl;
+}
+
+#else
+
+void ArcFactoredForest::EdgeMarginals(double*) {
+ cerr << "EdgeMarginals() requires --with-eigen!\n";
+ abort();
+}
+
+#endif
+