From 69b0bf8d618338c82fda17878defff77fb35a69f Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 13 May 2012 16:18:43 -0700 Subject: fast creg training code for univariate linear and logistic regression --- training/liblbfgs/lbfgs++.h | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) (limited to 'training/liblbfgs/lbfgs++.h') diff --git a/training/liblbfgs/lbfgs++.h b/training/liblbfgs/lbfgs++.h index 342f9b0e..92ead955 100644 --- a/training/liblbfgs/lbfgs++.h +++ b/training/liblbfgs/lbfgs++.h @@ -16,28 +16,33 @@ template class LBFGS { public: - LBFGS(size_t n, // number of variables - const Function& f, // function to optimize - double l1_c = 0.0, // l1 penalty strength - size_t m = 10 // number of memory buffers - // TODO should use custom allocator here: + LBFGS(size_t n, // number of variables + const Function& f, // function to optimize + size_t m = 10, // number of memory buffers + double l1_c = 0.0, // l1 penalty strength + unsigned l1_start = 0, // l1 penalty starting index + double eps = 1e-5 // convergence epsilon + // TODO should use custom allocator here: ) : p_x(new std::vector(n, 0.0)), owned(true), m_x(*p_x), func(f) { - Init(m, l1_c); + Init(m, l1_c, l1_start, eps); } // constructor where external vector storage for variables is used LBFGS(std::vector* px, const Function& f, - double l1_c = 0.0, // l1 penalty strength - size_t m = 10 + size_t m = 10, // number of memory buffers + double l1_c = 0.0, // l1 penalty strength + unsigned l1_start = 0, // l1 penalty starting index + double eps = 1e-5 // convergence epsilon + // TODO should use custom allocator here: ) : p_x(px), owned(false), m_x(*p_x), func(f) { - Init(m, l1_c); + Init(m, l1_c, l1_start, eps); } ~LBFGS() { @@ -60,12 +65,14 @@ class LBFGS { } private: - void Init(size_t m, double l1_c) { + void Init(size_t m, double l1_c, unsigned l1_start, double eps) { lbfgs_parameter_init(¶m); param.m = m; + param.epsilon = eps; if (l1_c > 0.0) { param.linesearch = LBFGS_LINESEARCH_BACKTRACKING; - param.orthantwise_c = 1.0; + param.orthantwise_c = l1_c; + param.orthantwise_start = l1_start; } silence = false; } -- cgit v1.2.3