summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/model1.cc15
-rw-r--r--training/ttables.h14
2 files changed, 28 insertions, 1 deletions
diff --git a/training/model1.cc b/training/model1.cc
index eacf4b32..4023735c 100644
--- a/training/model1.cc
+++ b/training/model1.cc
@@ -20,6 +20,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("iterations,i",po::value<unsigned>()->default_value(5),"Number of iterations of EM training")
("beam_threshold,t",po::value<double>()->default_value(-4),"log_10 of beam threshold (-10000 to include everything, 0 max)")
("no_null_word,N","Do not generate from the null token")
+ ("variational_bayes,v","Add a symmetric Dirichlet prior and infer VB estimate of weights")
+ ("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior")
("no_add_viterbi,V","Do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)");
po::options_description clo("Command line options");
clo.add_options()
@@ -53,6 +55,12 @@ int main(int argc, char** argv) {
const bool use_null = (conf.count("no_null_word") == 0);
const WordID kNULL = TD::Convert("<eps>");
const bool add_viterbi = (conf.count("no_add_viterbi") == 0);
+ const bool variational_bayes = (conf.count("variational_bayes") > 0);
+ const double alpha = conf["alpha"].as<double>();
+ if (variational_bayes && alpha <= 0.0) {
+ cerr << "--alpha must be > 0\n";
+ return 1;
+ }
TTable tt;
TTable::Word2Word2Double was_viterbi;
@@ -125,7 +133,12 @@ int main(int argc, char** argv) {
cerr << " log likelihood: " << likelihood << endl;
cerr << " cross entropy: " << (-likelihood / denom) << endl;
cerr << " perplexity: " << pow(2.0, -likelihood / denom) << endl;
- if (!final_iteration) tt.Normalize();
+ if (!final_iteration) {
+ if (variational_bayes)
+ tt.NormalizeVB(alpha);
+ else
+ tt.Normalize();
+ }
}
for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) {
const TTable::Word2Double& cpd = ei->second;
diff --git a/training/ttables.h b/training/ttables.h
index 53f5f2ab..50d85a68 100644
--- a/training/ttables.h
+++ b/training/ttables.h
@@ -6,6 +6,7 @@
#include "wordid.h"
#include "tdict.h"
+#include "em_utils.h"
class TTable {
public:
@@ -29,6 +30,19 @@ class TTable {
inline void Increment(const int& e, const int& f, double x) {
counts[e][f] += x;
}
+ void NormalizeVB(const double alpha) {
+ ttable.swap(counts);
+ for (Word2Word2Double::iterator cit = ttable.begin();
+ cit != ttable.end(); ++cit) {
+ double tot = 0;
+ Word2Double& cpd = cit->second;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ tot += it->second + alpha;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ it->second = exp(digamma(it->second + alpha) - digamma(tot));
+ }
+ counts.clear();
+ }
void Normalize() {
ttable.swap(counts);
for (Word2Word2Double::iterator cit = ttable.begin();