summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/em_utils.h24
-rw-r--r--training/model1.cc1
-rw-r--r--training/mr_em_adapted_reduce.cc6
-rw-r--r--training/ttables.h4
4 files changed, 5 insertions, 30 deletions
diff --git a/training/em_utils.h b/training/em_utils.h
deleted file mode 100644
index 37762978..00000000
--- a/training/em_utils.h
+++ /dev/null
@@ -1,24 +0,0 @@
-#ifndef _EM_UTILS_H_
-#define _EM_UTILS_H_
-
-#include "config.h"
-#ifdef HAVE_BOOST_DIGAMMA
-#include <boost/math/special_functions/digamma.hpp>
-using boost::math::digamma;
-#else
-#warning Using Mark Johnsons digamma()
-#include <cmath>
-inline double digamma(double x) {
- double result = 0, xx, xx2, xx4;
- assert(x > 0);
- for ( ; x < 7; ++x)
- result -= 1/x;
- x -= 1.0/2.0;
- xx = 1.0/x;
- xx2 = xx*xx;
- xx4 = xx2*xx2;
- result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4;
- return result;
-}
-#endif
-#endif
diff --git a/training/model1.cc b/training/model1.cc
index 40249aa3..a87d388f 100644
--- a/training/model1.cc
+++ b/training/model1.cc
@@ -9,7 +9,6 @@
#include "filelib.h"
#include "ttables.h"
#include "tdict.h"
-#include "em_utils.h"
namespace po = boost::program_options;
using namespace std;
diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc
index d4c16a2f..f65b5440 100644
--- a/training/mr_em_adapted_reduce.cc
+++ b/training/mr_em_adapted_reduce.cc
@@ -10,7 +10,7 @@
#include "fdict.h"
#include "weights.h"
#include "sparse_vector.h"
-#include "em_utils.h"
+#include "m.h"
using namespace std;
namespace po = boost::program_options;
@@ -63,11 +63,11 @@ void Maximize(const bool use_vb,
assert(tot > 0.0);
double ltot = log(tot);
if (use_vb)
- ltot = digamma(tot + total_event_types * alpha);
+ ltot = Md::digamma(tot + total_event_types * alpha);
for (SparseVector<double>::const_iterator it = counts.begin();
it != counts.end(); ++it) {
if (use_vb) {
- pc->set_value(it->first, NoZero(digamma(it->second + alpha) - ltot));
+ pc->set_value(it->first, NoZero(Md::digamma(it->second + alpha) - ltot));
} else {
pc->set_value(it->first, NoZero(log(it->second) - ltot));
}
diff --git a/training/ttables.h b/training/ttables.h
index 50d85a68..bf3351d2 100644
--- a/training/ttables.h
+++ b/training/ttables.h
@@ -4,9 +4,9 @@
#include <iostream>
#include <tr1/unordered_map>
+#include "m.h"
#include "wordid.h"
#include "tdict.h"
-#include "em_utils.h"
class TTable {
public:
@@ -39,7 +39,7 @@ class TTable {
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
tot += it->second + alpha;
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
- it->second = exp(digamma(it->second + alpha) - digamma(tot));
+ it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot));
}
counts.clear();
}