diff options
Diffstat (limited to 'training/model1.cc')
-rw-r--r-- | training/model1.cc | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/training/model1.cc b/training/model1.cc index eacf4b32..4023735c 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -20,6 +20,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("iterations,i",po::value<unsigned>()->default_value(5),"Number of iterations of EM training") ("beam_threshold,t",po::value<double>()->default_value(-4),"log_10 of beam threshold (-10000 to include everything, 0 max)") ("no_null_word,N","Do not generate from the null token") + ("variational_bayes,v","Add a symmetric Dirichlet prior and infer VB estimate of weights") + ("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior") ("no_add_viterbi,V","Do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)"); po::options_description clo("Command line options"); clo.add_options() @@ -53,6 +55,12 @@ int main(int argc, char** argv) { const bool use_null = (conf.count("no_null_word") == 0); const WordID kNULL = TD::Convert("<eps>"); const bool add_viterbi = (conf.count("no_add_viterbi") == 0); + const bool variational_bayes = (conf.count("variational_bayes") > 0); + const double alpha = conf["alpha"].as<double>(); + if (variational_bayes && alpha <= 0.0) { + cerr << "--alpha must be > 0\n"; + return 1; + } TTable tt; TTable::Word2Word2Double was_viterbi; @@ -125,7 +133,12 @@ int main(int argc, char** argv) { cerr << " log likelihood: " << likelihood << endl; cerr << " cross entropy: " << (-likelihood / denom) << endl; cerr << " perplexity: " << pow(2.0, -likelihood / denom) << endl; - if (!final_iteration) tt.Normalize(); + if (!final_iteration) { + if (variational_bayes) + tt.NormalizeVB(alpha); + else + tt.Normalize(); + } } for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) { const TTable::Word2Double& cpd = ei->second; |