summaryrefslogtreecommitdiff
path: root/training/ttables.h
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-12-01 00:03:35 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-12-01 00:03:35 +0000
commitfd4259da347b371a1a399b9130f62938e3db462b (patch)
tree93bef21d1f53e778f54f7f8bc867bef97672d6c1 /training/ttables.h
parent988461b40d14162dd93490c3e34de1f6c04cb540 (diff)
optional variational bayes
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@734 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/ttables.h')
-rw-r--r--training/ttables.h14
1 files changed, 14 insertions, 0 deletions
diff --git a/training/ttables.h b/training/ttables.h
index 53f5f2ab..50d85a68 100644
--- a/training/ttables.h
+++ b/training/ttables.h
@@ -6,6 +6,7 @@
#include "wordid.h"
#include "tdict.h"
+#include "em_utils.h"
class TTable {
public:
@@ -29,6 +30,19 @@ class TTable {
inline void Increment(const int& e, const int& f, double x) {
counts[e][f] += x;
}
+ void NormalizeVB(const double alpha) {
+ ttable.swap(counts);
+ for (Word2Word2Double::iterator cit = ttable.begin();
+ cit != ttable.end(); ++cit) {
+ double tot = 0;
+ Word2Double& cpd = cit->second;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ tot += it->second + alpha;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ it->second = exp(digamma(it->second + alpha) - digamma(tot));
+ }
+ counts.clear();
+ }
void Normalize() {
ttable.swap(counts);
for (Word2Word2Double::iterator cit = ttable.begin();