summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/crf/mpi_adagrad_optimize.cc43
1 files changed, 41 insertions, 2 deletions
diff --git a/training/crf/mpi_adagrad_optimize.cc b/training/crf/mpi_adagrad_optimize.cc
index 05e7c35f..af963e3a 100644
--- a/training/crf/mpi_adagrad_optimize.cc
+++ b/training/crf/mpi_adagrad_optimize.cc
@@ -194,7 +194,7 @@ class AdaGradOptimizer {
for (auto& gi : g) {
#else
for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) {
- const pair<unsigne,double>& gi = *it;
+ const pair<unsigned,double>& gi = *it;
#endif
if (gi.second) {
G[gi.first] += gi.second * gi.second;
@@ -206,6 +206,44 @@ class AdaGradOptimizer {
vector<double> G;
};
+class AdaGradL1Optimizer {
+ public:
+ explicit AdaGradL1Optimizer(double e, double l) :
+ t(),
+ eta(e),
+ lambda(l),
+ G() {}
+ void update(const SparseVector<double>& g, vector<double>* x) {
+ t += 1.0;
+ if (x->size() > G.size()) {
+ G.resize(x->size(), 0.0);
+ u.resize(x->size(), 0.0);
+ }
+#if HAVE_CXX11
+ for (auto& gi : g) {
+#else
+ for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) {
+ const pair<unsigned,double>& gi = *it;
+#endif
+ if (gi.second) {
+ u[gi.first] += gi.second;
+ G[gi.first] += gi.second * gi.second;
+ double z = fabs(u[gi.first] / t) - lambda;
+ double s = 1;
+ if (u[gi.first] > 0) s = -1;
+ if (z > 0 && G[gi.first])
+ (*x)[gi.first] = eta * s * z * t / sqrt(G[gi.first]);
+ else
+ (*x)[gi.first] = 0.0;
+ }
+ }
+ }
+ double t;
+ const double eta;
+ const double lambda;
+ vector<double> G, u;
+};
+
unsigned non_zeros(const vector<double>& x) {
unsigned nz = 0;
for (unsigned i = 0; i < x.size(); ++i)
@@ -286,7 +324,8 @@ int main(int argc, char** argv) {
init_weights.clear();
}
- AdaGradOptimizer adagrad(conf["eta"].as<double>());
+ //AdaGradOptimizer adagrad(conf["eta"].as<double>());
+ AdaGradL1Optimizer adagrad(conf["eta"].as<double>(), conf["regularization_strength"].as<double>());
int iter = -1;
bool converged = false;