summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--configure.ac21
-rw-r--r--training/em_utils.h24
-rw-r--r--training/model1.cc1
-rw-r--r--training/mr_em_adapted_reduce.cc6
-rw-r--r--training/ttables.h4
-rw-r--r--utils/m.h6
6 files changed, 26 insertions, 36 deletions
diff --git a/configure.ac b/configure.ac
index cd78ee72..aa79027f 100644
--- a/configure.ac
+++ b/configure.ac
@@ -9,7 +9,7 @@ esac
AC_PROG_CC
AC_PROG_CXX
AC_LANG_CPLUSPLUS
-BOOST_REQUIRE
+BOOST_REQUIRE([1.44])
BOOST_PROGRAM_OPTIONS
AC_ARG_ENABLE(mpi,
[ --enable-mpi Build MPI binaries, assumes mpi.h is present ],
@@ -38,7 +38,7 @@ then
CPPFLAGS="$CPPFLAGS -I${with_cmph}/include"
AC_CHECK_HEADER(cmph.h,
- [AC_DEFINE([HAVE_CMPH], [], [flag for cmph perfect hashing library])],
+ [AC_DEFINE([HAVE_CMPH], [1], [flag for cmph perfect hashing library])],
[AC_MSG_ERROR([Cannot find cmph library!])])
LDFLAGS="$LDFLAGS -L${with_cmph}/lib"
@@ -46,6 +46,18 @@ then
AM_CONDITIONAL([HAVE_CMPH], true)
fi
+if test "x$with_eigen" != 'xno'
+then
+ SAVE_CPPFLAGS="$CPPFLAGS"
+ CPPFLAGS="$CPPFLAGS -I${with_eigen}"
+
+ AC_CHECK_HEADER(Eigen,
+ [AC_DEFINE([HAVE_EIGEN], [1], [flag for Eigen linear algebra library])],
+ [AC_MSG_ERROR([Cannot find Eigen!])])
+
+ AM_CONDITIONAL([HAVE_EIGEN], true)
+fi
+
#BOOST_THREADS
CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS"
LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS"
@@ -53,11 +65,8 @@ LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS"
LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS"
# $BOOST_THREAD_LIBS"
-AC_CHECK_HEADER(boost/math/special_functions/digamma.hpp,
- [AC_DEFINE([HAVE_BOOST_DIGAMMA], [], [flag for boost::math::digamma])])
-
AC_CHECK_HEADER(google/dense_hash_map,
- [AC_DEFINE([HAVE_SPARSEHASH], [], [flag for google::dense_hash_map])])
+ [AC_DEFINE([HAVE_SPARSEHASH], [1], [flag for google::dense_hash_map])])
AC_PROG_INSTALL
GTEST_LIB_CHECK(1.0)
diff --git a/training/em_utils.h b/training/em_utils.h
deleted file mode 100644
index 37762978..00000000
--- a/training/em_utils.h
+++ /dev/null
@@ -1,24 +0,0 @@
-#ifndef _EM_UTILS_H_
-#define _EM_UTILS_H_
-
-#include "config.h"
-#ifdef HAVE_BOOST_DIGAMMA
-#include <boost/math/special_functions/digamma.hpp>
-using boost::math::digamma;
-#else
-#warning Using Mark Johnsons digamma()
-#include <cmath>
-inline double digamma(double x) {
- double result = 0, xx, xx2, xx4;
- assert(x > 0);
- for ( ; x < 7; ++x)
- result -= 1/x;
- x -= 1.0/2.0;
- xx = 1.0/x;
- xx2 = xx*xx;
- xx4 = xx2*xx2;
- result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4;
- return result;
-}
-#endif
-#endif
diff --git a/training/model1.cc b/training/model1.cc
index 40249aa3..a87d388f 100644
--- a/training/model1.cc
+++ b/training/model1.cc
@@ -9,7 +9,6 @@
#include "filelib.h"
#include "ttables.h"
#include "tdict.h"
-#include "em_utils.h"
namespace po = boost::program_options;
using namespace std;
diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc
index d4c16a2f..f65b5440 100644
--- a/training/mr_em_adapted_reduce.cc
+++ b/training/mr_em_adapted_reduce.cc
@@ -10,7 +10,7 @@
#include "fdict.h"
#include "weights.h"
#include "sparse_vector.h"
-#include "em_utils.h"
+#include "m.h"
using namespace std;
namespace po = boost::program_options;
@@ -63,11 +63,11 @@ void Maximize(const bool use_vb,
assert(tot > 0.0);
double ltot = log(tot);
if (use_vb)
- ltot = digamma(tot + total_event_types * alpha);
+ ltot = Md::digamma(tot + total_event_types * alpha);
for (SparseVector<double>::const_iterator it = counts.begin();
it != counts.end(); ++it) {
if (use_vb) {
- pc->set_value(it->first, NoZero(digamma(it->second + alpha) - ltot));
+ pc->set_value(it->first, NoZero(Md::digamma(it->second + alpha) - ltot));
} else {
pc->set_value(it->first, NoZero(log(it->second) - ltot));
}
diff --git a/training/ttables.h b/training/ttables.h
index 50d85a68..bf3351d2 100644
--- a/training/ttables.h
+++ b/training/ttables.h
@@ -4,9 +4,9 @@
#include <iostream>
#include <tr1/unordered_map>
+#include "m.h"
#include "wordid.h"
#include "tdict.h"
-#include "em_utils.h"
class TTable {
public:
@@ -39,7 +39,7 @@ class TTable {
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
tot += it->second + alpha;
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
- it->second = exp(digamma(it->second + alpha) - digamma(tot));
+ it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot));
}
counts.clear();
}
diff --git a/utils/m.h b/utils/m.h
index b25248c2..5e45efee 100644
--- a/utils/m.h
+++ b/utils/m.h
@@ -3,6 +3,7 @@
#include <cassert>
#include <cmath>
+#include <boost/math/special_functions/digamma.hpp>
template <typename F>
struct M {
@@ -81,6 +82,11 @@ struct M {
}
}
+ // digamma is the first derivative of the log-gamma function
+ static inline F digamma(const F& x) {
+ return boost::math::digamma(x);
+ }
+
};
typedef M<double> Md;