diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/em_utils.h | 24 | ||||
-rw-r--r-- | training/model1.cc | 1 | ||||
-rw-r--r-- | training/mr_em_adapted_reduce.cc | 6 | ||||
-rw-r--r-- | training/ttables.h | 4 |
4 files changed, 5 insertions, 30 deletions
diff --git a/training/em_utils.h b/training/em_utils.h deleted file mode 100644 index 37762978..00000000 --- a/training/em_utils.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _EM_UTILS_H_ -#define _EM_UTILS_H_ - -#include "config.h" -#ifdef HAVE_BOOST_DIGAMMA -#include <boost/math/special_functions/digamma.hpp> -using boost::math::digamma; -#else -#warning Using Mark Johnsons digamma() -#include <cmath> -inline double digamma(double x) { - double result = 0, xx, xx2, xx4; - assert(x > 0); - for ( ; x < 7; ++x) - result -= 1/x; - x -= 1.0/2.0; - xx = 1.0/x; - xx2 = xx*xx; - xx4 = xx2*xx2; - result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; - return result; -} -#endif -#endif diff --git a/training/model1.cc b/training/model1.cc index 40249aa3..a87d388f 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -9,7 +9,6 @@ #include "filelib.h" #include "ttables.h" #include "tdict.h" -#include "em_utils.h" namespace po = boost::program_options; using namespace std; diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc index d4c16a2f..f65b5440 100644 --- a/training/mr_em_adapted_reduce.cc +++ b/training/mr_em_adapted_reduce.cc @@ -10,7 +10,7 @@ #include "fdict.h" #include "weights.h" #include "sparse_vector.h" -#include "em_utils.h" +#include "m.h" using namespace std; namespace po = boost::program_options; @@ -63,11 +63,11 @@ void Maximize(const bool use_vb, assert(tot > 0.0); double ltot = log(tot); if (use_vb) - ltot = digamma(tot + total_event_types * alpha); + ltot = Md::digamma(tot + total_event_types * alpha); for (SparseVector<double>::const_iterator it = counts.begin(); it != counts.end(); ++it) { if (use_vb) { - pc->set_value(it->first, NoZero(digamma(it->second + alpha) - ltot)); + pc->set_value(it->first, NoZero(Md::digamma(it->second + alpha) - ltot)); } else { pc->set_value(it->first, NoZero(log(it->second) - ltot)); } diff --git a/training/ttables.h b/training/ttables.h index 50d85a68..bf3351d2 100644 --- a/training/ttables.h +++ b/training/ttables.h @@ -4,9 +4,9 @@ #include <iostream> #include <tr1/unordered_map> +#include "m.h" #include "wordid.h" #include "tdict.h" -#include "em_utils.h" class TTable { public: @@ -39,7 +39,7 @@ class TTable { for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) tot += it->second + alpha; for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) - it->second = exp(digamma(it->second + alpha) - digamma(tot)); + it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot)); } counts.clear(); } |