diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 05:12:27 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 05:12:27 +0000 |
commit | 0172721855098ca02b207231a654dffa5e4eb1c9 (patch) | |
tree | 8069c3a62e2d72bd64a2cdeee9724b2679c8a56b /training/optimize.h | |
parent | 37728b8be4d0b3df9da81fdda2198ff55b4b2d91 (diff) |
initial checkin
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/optimize.h')
-rw-r--r-- | training/optimize.h | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/training/optimize.h b/training/optimize.h new file mode 100644 index 00000000..eddceaad --- /dev/null +++ b/training/optimize.h @@ -0,0 +1,104 @@ +#ifndef _OPTIMIZE_H_ +#define _OPTIMIZE_H_ + +#include <iostream> +#include <vector> +#include <string> +#include <cassert> + +#include "lbfgs.h" + +// abstract base class for first order optimizers +// order of invocation: new, Load(), Optimize(), Save(), delete +class Optimizer { + public: + Optimizer() : eval_(1), has_converged_(false) {} + virtual ~Optimizer(); + virtual std::string Name() const = 0; + int EvaluationCount() const { return eval_; } + bool HasConverged() const { return has_converged_; } + + void Optimize(const double& obj, + const std::vector<double>& g, + std::vector<double>* x) { + assert(g.size() == x->size()); + OptimizeImpl(obj, g, x); + scitbx::lbfgs::traditional_convergence_test<double> converged(g.size()); + has_converged_ = converged(&(*x)[0], &g[0]); + } + + void Save(std::ostream* out) const; + void Load(std::istream* in); + protected: + virtual void SaveImpl(std::ostream* out) const; + virtual void LoadImpl(std::istream* in); + virtual void OptimizeImpl(const double& obj, + const std::vector<double>& g, + std::vector<double>* x) = 0; + + int eval_; + private: + bool has_converged_; +}; + +class RPropOptimizer : public Optimizer { + public: + explicit RPropOptimizer(int num_vars, + double eta_plus = 1.2, + double eta_minus = 0.5, + double delta_0 = 0.1, + double delta_max = 50.0, + double delta_min = 1e-6) : + prev_g_(num_vars, 0.0), + delta_ij_(num_vars, delta_0), + eta_plus_(eta_plus), + eta_minus_(eta_minus), + delta_max_(delta_max), + delta_min_(delta_min) { + assert(eta_plus > 1.0); + assert(eta_minus > 0.0 && eta_minus < 1.0); + assert(delta_max > 0.0); + assert(delta_min > 0.0); + } + std::string Name() const; + void OptimizeImpl(const double& obj, + const std::vector<double>& g, + std::vector<double>* x); + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + private: + std::vector<double> prev_g_; + std::vector<double> delta_ij_; + const double eta_plus_; + const double eta_minus_; + const double delta_max_; + const double delta_min_; +}; + +class SGDOptimizer : public Optimizer { + public: + explicit SGDOptimizer(int num_vars, double eta = 0.1) : eta_(eta) { + (void) num_vars; + } + std::string Name() const; + void OptimizeImpl(const double& obj, + const std::vector<double>& g, + std::vector<double>* x); + private: + const double eta_; +}; + +class LBFGSOptimizer : public Optimizer { + public: + explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10); + std::string Name() const; + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + void OptimizeImpl(const double& obj, + const std::vector<double>& g, + std::vector<double>* x); + private: + scitbx::lbfgs::minimizer<double> opt_; +}; + +#endif |