Diffstat (limited to 'training')
 training/model1.cc | 15 ++++++++++++++-
 training/ttables.h | 14 ++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)
diff --git a/training/model1.cc b/training/model1.cc
index eacf4b32..4023735c 100644
--- a/training/model1.cc
+++ b/training/model1.cc
@@ -20,6 +20,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
         ("iterations,i",po::value<unsigned>()->default_value(5),"Number of iterations of EM training")
         ("beam_threshold,t",po::value<double>()->default_value(-4),"log_10 of beam threshold (-10000 to include everything, 0 max)")
         ("no_null_word,N","Do not generate from the null token")
+        ("variational_bayes,v","Add a symmetric Dirichlet prior and infer VB estimate of weights")
+        ("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior")
         ("no_add_viterbi,V","Do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)");
   po::options_description clo("Command line options");
   clo.add_options()
@@ -53,6 +55,12 @@ int main(int argc, char** argv) {
   const bool use_null = (conf.count("no_null_word") == 0);
   const WordID kNULL = TD::Convert("<eps>");
   const bool add_viterbi = (conf.count("no_add_viterbi") == 0);
+  const bool variational_bayes = (conf.count("variational_bayes") > 0);
+  const double alpha = conf["alpha"].as<double>();
+  if (variational_bayes && alpha <= 0.0) {
+    cerr << "--alpha must be > 0\n";
+    return 1;
+  }
 
   TTable tt;
   TTable::Word2Word2Double was_viterbi;
@@ -125,7 +133,12 @@ int main(int argc, char** argv) {
     cerr << "  log likelihood: " << likelihood << endl;
     cerr << "   cross entropy: " << (-likelihood / denom) << endl;
     cerr << "      perplexity: " << pow(2.0, -likelihood / denom) << endl;
-    if (!final_iteration) tt.Normalize();
+    if (!final_iteration) {
+      if (variational_bayes)
+        tt.NormalizeVB(alpha);
+      else
+        tt.Normalize();
+    }
   }
   for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) {
     const TTable::Word2Double& cpd = ei->second;
diff --git a/training/ttables.h b/training/ttables.h
index 53f5f2ab..50d85a68 100644
--- a/training/ttables.h
+++ b/training/ttables.h
@@ -6,6 +6,7 @@
 
 #include "wordid.h"
 #include "tdict.h"
+#include "em_utils.h"
 
 class TTable {
  public:
@@ -29,6 +30,19 @@ class TTable {
   inline void Increment(const int& e, const int& f, double x) {
     counts[e][f] += x;
   }
+  void NormalizeVB(const double alpha) {
+    ttable.swap(counts);
+    for (Word2Word2Double::iterator cit = ttable.begin();
+         cit != ttable.end(); ++cit) {
+      double tot = 0;
+      Word2Double& cpd = cit->second;
+      for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+        tot += it->second + alpha;
+      for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+        it->second = exp(digamma(it->second + alpha) - digamma(tot));
+    }
+    counts.clear();
+  }
   void Normalize() {
     ttable.swap(counts);
     for (Word2Word2Double::iterator cit = ttable.begin();
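In this change, NormalizeVB replaces the maximum-likelihood M-step p(f|e) = c(e,f) / sum_f' c(e,f') with the mean-field variational Bayes estimate under a symmetric Dirichlet(alpha) prior, p(f|e) = exp(psi(c(e,f) + alpha) - psi(sum_f' (c(e,f') + alpha))), where psi is the digamma function pulled in from em_utils.h and the sums run over the f' actually observed with e. Below is a minimal self-contained sketch of the same update on a toy count table; boost::math::digamma stands in for the project's digamma, and the counts are invented for illustration:

// vb_sketch.cc: minimal sketch of the NormalizeVB update on a toy count
// table. Assumption: boost::math::digamma stands in for the digamma
// declared in em_utils.h, which is not part of this diff.
#include <cmath>
#include <cstdio>
#include <map>
#include <boost/math/special_functions/digamma.hpp>

int main() {
  // Toy expected counts c(e,f) for one source word e (values invented).
  std::map<int, double> counts;
  counts[1] = 3.0;
  counts[2] = 1.0;
  counts[3] = 0.5;
  const double alpha = 0.01;  // symmetric Dirichlet hyperparameter (-a default)

  // tot accumulates sum_f (c(e,f) + alpha) over observed f, as in NormalizeVB.
  double tot = 0, ml_tot = 0;
  for (std::map<int, double>::const_iterator it = counts.begin();
       it != counts.end(); ++it) {
    tot += it->second + alpha;
    ml_tot += it->second;
  }

  // VB estimate: p(f|e) = exp(digamma(c + alpha) - digamma(tot)).
  // Unlike the ML estimate c/ml_tot, these weights sum to less than one;
  // the missing mass is the smoothing effect of the prior.
  for (std::map<int, double>::const_iterator it = counts.begin();
       it != counts.end(); ++it) {
    const double vb = std::exp(boost::math::digamma(it->second + alpha) -
                               boost::math::digamma(tot));
    std::printf("f=%d  ml=%.4f  vb=%.4f\n", it->first, it->second / ml_tot, vb);
  }
  return 0;
}

Since psi(x) < log(x) and the gap (roughly 1/(2x)) shrinks as x grows, the exp(psi(.)) transform discounts small counts proportionally more than large ones, which is why this estimator tends to produce sparser translation tables than plain EM.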

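Usage-wise, the new estimator is opt-in. Assuming the tool builds to a model1 binary that takes the training corpus as its positional argument (an assumption; only the flags themselves appear in this diff), it would be enabled with something like

  ./model1 -i 5 --variational_bayes --alpha 0.01 corpus.txt

--alpha has no effect unless --variational_bayes (-v) is given, and the new sanity check rejects non-positive values of --alpha when VB is enabled.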