From 74615686493ad495c8e7802c96e5257da7e7f934 Mon Sep 17 00:00:00 2001
From: redpony
Date: Wed, 1 Dec 2010 00:03:35 +0000
Subject: optional variational bayes

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@734 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 training/ttables.h | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

(limited to 'training/ttables.h')

diff --git a/training/ttables.h b/training/ttables.h
index 53f5f2ab..50d85a68 100644
--- a/training/ttables.h
+++ b/training/ttables.h
@@ -6,6 +6,7 @@
 
 #include "wordid.h"
 #include "tdict.h"
+#include "em_utils.h"
 
 class TTable {
  public:
@@ -29,6 +30,19 @@ class TTable {
   inline void Increment(const int& e, const int& f, double x) {
     counts[e][f] += x;
   }
+  void NormalizeVB(const double alpha) {
+    ttable.swap(counts);
+    for (Word2Word2Double::iterator cit = ttable.begin();
+         cit != ttable.end(); ++cit) {
+      double tot = 0;
+      Word2Double& cpd = cit->second;
+      for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+        tot += it->second + alpha;
+      for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+        it->second = exp(digamma(it->second + alpha) - digamma(tot));
+    }
+    counts.clear();
+  }
   void Normalize() {
     ttable.swap(counts);
     for (Word2Word2Double::iterator cit = ttable.begin();
-- 
cgit v1.2.3