From d8f1acd9857a600d4970a1af910d66d66f44875b Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Wed, 22 Feb 2012 16:10:56 +0000
Subject: add regularization

---
 training/lbl_model.cc | 50 +++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 41 insertions(+), 9 deletions(-)

(limited to 'training')

diff --git a/training/lbl_model.cc b/training/lbl_model.cc
index eb3e194d..a114bba7 100644
--- a/training/lbl_model.cc
+++ b/training/lbl_model.cc
@@ -12,6 +12,7 @@
 #include <cstring> // memset
 #include <ctime>
 
+#include <boost/math/special_functions/fpclassify.hpp>
 #include <boost/shared_ptr.hpp>
 #include <boost/program_options.hpp>
 #include <boost/program_options/variables_map.hpp>
@@ -27,7 +28,7 @@
 namespace po = boost::program_options;
 using namespace std;
 
-#define kDIMENSIONS 8
+#define kDIMENSIONS 110
 typedef Eigen::Matrix<float, kDIMENSIONS, 1> RVector;
 typedef Eigen::Matrix<float, 1, kDIMENSIONS> RTVector;
 typedef Eigen::Matrix<float, kDIMENSIONS, kDIMENSIONS> TMatrix;
@@ -38,8 +39,9 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   opts.add_options()
        ("input,i",po::value<string>(),"Input file")
        ("iterations,I",po::value<unsigned>()->default_value(1000),"Number of iterations of training")
+       ("regularization_strength,C",po::value<float>()->default_value(0.1),"L2 regularization strength (0 for no regularization)")
        ("eta,e", po::value<float>()->default_value(0.1f), "Eta for SGD")
-       ("random_seed", po::value<unsigned>(), "Random seed")
+       ("random_seed,s", po::value<unsigned>(), "Random seed")
        ("diagonal_tension,T", po::value<double>()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (0 = uniform, >0 sharpens)")
        ("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model");
   po::options_description clo("Command line options");
@@ -67,6 +69,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
 
 void Normalize(RVector* v) {
   float norm = v->norm();
+  assert(norm > 0.0f);
   *v /= norm;
 }
 
@@ -74,21 +77,42 @@ void Flatten(const TMatrix& m, vector<double>* v) {
   unsigned c = 0;
   v->resize(kDIMENSIONS * kDIMENSIONS);
   for (unsigned i = 0; i < kDIMENSIONS; ++i)
-    for (unsigned j = 0; j < kDIMENSIONS; ++j)
+    for (unsigned j = 0; j < kDIMENSIONS; ++j) {
+      assert(boost::math::isnormal(m(i, j)));
       (*v)[c++] = m(i,j);
+    }
 }
 
 void Unflatten(const vector<double>& v, TMatrix* m) {
   unsigned c = 0;
   for (unsigned i = 0; i < kDIMENSIONS; ++i)
-    for (unsigned j = 0; j < kDIMENSIONS; ++j)
+    for (unsigned j = 0; j < kDIMENSIONS; ++j) {
+      assert(boost::math::isnormal(v[c]));
       (*m)(i, j) = v[c++];
+    }
+}
+
+double ApplyRegularization(const double C,
+                           const vector<double>& weights,
+                           vector<double>* g) {
+  assert(weights.size() == g->size());
+  double reg = 0;
+  for (size_t i = 0; i < weights.size(); ++i) {
+    const double& w_i = weights[i];
+    double& g_i = (*g)[i];
+    reg += C * w_i * w_i;
+    g_i += 2 * C * w_i;
+  }
+  return reg;
 }
 
 int main(int argc, char** argv) {
   po::variables_map conf;
   if (!InitCommandLine(argc, argv, &conf)) return 1;
   const string fname = conf["input"].as<string>();
+  const float reg_strength = conf["regularization_strength"].as<float>();
+  const bool has_l2 = reg_strength;
+  assert(reg_strength >= 0.0f);
   const int ITERATIONS = conf["iterations"].as<unsigned>();
   const float eta = conf["eta"].as<float>();
   const double diagonal_tension = conf["diagonal_tension"].as<double>();
@@ -147,7 +171,7 @@ int main(int argc, char** argv) {
     cerr << "Random seed: " << seed << endl;
     srand(seed);
   }
-  TMatrix t = TMatrix::Random() / 1024.0;
+  TMatrix t = TMatrix::Random() / 50.0;
   for (unsigned i = 1; i < r_trg.size(); ++i) {
     r_trg[i] = RVector::Random();
     r_src[i] = RVector::Random();
@@ -159,7 +183,7 @@ int main(int argc, char** argv) {
   vector<set<unsigned> > trg_pos(TD::NumWords() + 1);
 
   // do optimization
-  TMatrix g;
+  TMatrix g = TMatrix::Zero();
   vector<TMatrix> exp_src;
   vector<double> z_src;
   vector<double> flat_g, flat_t;
@@ -265,11 +289,19 @@ int main(int argc, char** argv) {
     const double base2_likelihood = likelihood / log(2);
     cerr << "  log_e likelihood: " << likelihood << endl;
     cerr << "  log_2 likelihood: " << base2_likelihood << endl;
-    cerr << "  cross entropy: " << (-base2_likelihood / denom) << endl;
-    cerr << "  perplexity: " << pow(2.0, -base2_likelihood / denom) << endl;
+    cerr << "     cross entropy: " << (-base2_likelihood / denom) << endl;
+    cerr << "        perplexity: " << pow(2.0, -base2_likelihood / denom) << endl;
     if (!SGD) {
       Flatten(g, &flat_g);
-      lbfgs.Optimize(-likelihood, flat_g, &flat_t);
+      double obj = -likelihood;
+      if (has_l2) {
+        const double r = ApplyRegularization(reg_strength,
+                                             flat_t,
+                                             &flat_g);
+        obj += r;
+        cerr << "    regularization: " << r << endl;
+      }
+      lbfgs.Optimize(obj, flat_g, &flat_t);
       Unflatten(flat_t, &t);
       if (lbfgs.HasConverged()) break;
     }
-- 
cgit v1.2.3