From b6c44f82ffca74cc38cf6039ac9ab3c2c66fd5d6 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 25 Nov 2013 01:03:58 -0500 Subject: l1 version of adagrad optimizer --- training/crf/mpi_adagrad_optimize.cc | 43 ++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) (limited to 'training/crf') diff --git a/training/crf/mpi_adagrad_optimize.cc b/training/crf/mpi_adagrad_optimize.cc index 05e7c35f..af963e3a 100644 --- a/training/crf/mpi_adagrad_optimize.cc +++ b/training/crf/mpi_adagrad_optimize.cc @@ -194,7 +194,7 @@ class AdaGradOptimizer { for (auto& gi : g) { #else for (SparseVector::const_iterator it = g.begin(); it != g.end(); ++it) { - const pair& gi = *it; + const pair& gi = *it; #endif if (gi.second) { G[gi.first] += gi.second * gi.second; @@ -206,6 +206,44 @@ class AdaGradOptimizer { vector G; }; +class AdaGradL1Optimizer { + public: + explicit AdaGradL1Optimizer(double e, double l) : + t(), + eta(e), + lambda(l), + G() {} + void update(const SparseVector& g, vector* x) { + t += 1.0; + if (x->size() > G.size()) { + G.resize(x->size(), 0.0); + u.resize(x->size(), 0.0); + } +#if HAVE_CXX11 + for (auto& gi : g) { +#else + for (SparseVector::const_iterator it = g.begin(); it != g.end(); ++it) { + const pair& gi = *it; +#endif + if (gi.second) { + u[gi.first] += gi.second; + G[gi.first] += gi.second * gi.second; + double z = fabs(u[gi.first] / t) - lambda; + double s = 1; + if (u[gi.first] > 0) s = -1; + if (z > 0 && G[gi.first]) + (*x)[gi.first] = eta * s * z * t / sqrt(G[gi.first]); + else + (*x)[gi.first] = 0.0; + } + } + } + double t; + const double eta; + const double lambda; + vector G, u; +}; + unsigned non_zeros(const vector& x) { unsigned nz = 0; for (unsigned i = 0; i < x.size(); ++i) @@ -286,7 +324,8 @@ int main(int argc, char** argv) { init_weights.clear(); } - AdaGradOptimizer adagrad(conf["eta"].as()); + //AdaGradOptimizer adagrad(conf["eta"].as()); + AdaGradL1Optimizer adagrad(conf["eta"].as(), conf["regularization_strength"].as()); int iter = -1; bool converged = false; -- cgit v1.2.3