diff options
author | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-11-25 01:03:58 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-11-25 01:03:58 -0500 |
commit | b6c44f82ffca74cc38cf6039ac9ab3c2c66fd5d6 (patch) | |
tree | 70c5d30405ffc7b4606dfd5a5546160287bb01fa | |
parent | 73aea215412a8a2837292d9eac1617df7c45ef01 (diff) |
l1 version of adagrad optimizer
-rw-r--r-- | training/crf/mpi_adagrad_optimize.cc | 43 |
1 files changed, 41 insertions, 2 deletions
diff --git a/training/crf/mpi_adagrad_optimize.cc b/training/crf/mpi_adagrad_optimize.cc index 05e7c35f..af963e3a 100644 --- a/training/crf/mpi_adagrad_optimize.cc +++ b/training/crf/mpi_adagrad_optimize.cc @@ -194,7 +194,7 @@ class AdaGradOptimizer { for (auto& gi : g) { #else for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) { - const pair<unsigne,double>& gi = *it; + const pair<unsigned,double>& gi = *it; #endif if (gi.second) { G[gi.first] += gi.second * gi.second; @@ -206,6 +206,44 @@ class AdaGradOptimizer { vector<double> G; }; +class AdaGradL1Optimizer { + public: + explicit AdaGradL1Optimizer(double e, double l) : + t(), + eta(e), + lambda(l), + G() {} + void update(const SparseVector<double>& g, vector<double>* x) { + t += 1.0; + if (x->size() > G.size()) { + G.resize(x->size(), 0.0); + u.resize(x->size(), 0.0); + } +#if HAVE_CXX11 + for (auto& gi : g) { +#else + for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) { + const pair<unsigned,double>& gi = *it; +#endif + if (gi.second) { + u[gi.first] += gi.second; + G[gi.first] += gi.second * gi.second; + double z = fabs(u[gi.first] / t) - lambda; + double s = 1; + if (u[gi.first] > 0) s = -1; + if (z > 0 && G[gi.first]) + (*x)[gi.first] = eta * s * z * t / sqrt(G[gi.first]); + else + (*x)[gi.first] = 0.0; + } + } + } + double t; + const double eta; + const double lambda; + vector<double> G, u; +}; + unsigned non_zeros(const vector<double>& x) { unsigned nz = 0; for (unsigned i = 0; i < x.size(); ++i) @@ -286,7 +324,8 @@ int main(int argc, char** argv) { init_weights.clear(); } - AdaGradOptimizer adagrad(conf["eta"].as<double>()); + //AdaGradOptimizer adagrad(conf["eta"].as<double>()); + AdaGradL1Optimizer adagrad(conf["eta"].as<double>(), conf["regularization_strength"].as<double>()); int iter = -1; bool converged = false; |