summaryrefslogtreecommitdiff
path: root/training/optimize.h
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 05:12:27 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 05:12:27 +0000
commit7cc92b65a3185aa242088d830e166e495674efc9 (patch)
tree681fe5237612a4e96ce36fb9fabef00042c8ee61 /training/optimize.h
parent37728b8be4d0b3df9da81fdda2198ff55b4b2d91 (diff)
initial checkin
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/optimize.h')
-rw-r--r--training/optimize.h104
1 files changed, 104 insertions, 0 deletions
diff --git a/training/optimize.h b/training/optimize.h
new file mode 100644
index 00000000..eddceaad
--- /dev/null
+++ b/training/optimize.h
@@ -0,0 +1,104 @@
+#ifndef _OPTIMIZE_H_
+#define _OPTIMIZE_H_
+
+#include <iostream>
+#include <vector>
+#include <string>
+#include <cassert>
+
+#include "lbfgs.h"
+
+// abstract base class for first order optimizers
+// order of invocation: new, Load(), Optimize(), Save(), delete
+class Optimizer {
+ public:
+ Optimizer() : eval_(1), has_converged_(false) {}
+ virtual ~Optimizer();
+ virtual std::string Name() const = 0;
+ int EvaluationCount() const { return eval_; }
+ bool HasConverged() const { return has_converged_; }
+
+ void Optimize(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x) {
+ assert(g.size() == x->size());
+ OptimizeImpl(obj, g, x);
+ scitbx::lbfgs::traditional_convergence_test<double> converged(g.size());
+ has_converged_ = converged(&(*x)[0], &g[0]);
+ }
+
+ void Save(std::ostream* out) const;
+ void Load(std::istream* in);
+ protected:
+ virtual void SaveImpl(std::ostream* out) const;
+ virtual void LoadImpl(std::istream* in);
+ virtual void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x) = 0;
+
+ int eval_;
+ private:
+ bool has_converged_;
+};
+
+class RPropOptimizer : public Optimizer {
+ public:
+ explicit RPropOptimizer(int num_vars,
+ double eta_plus = 1.2,
+ double eta_minus = 0.5,
+ double delta_0 = 0.1,
+ double delta_max = 50.0,
+ double delta_min = 1e-6) :
+ prev_g_(num_vars, 0.0),
+ delta_ij_(num_vars, delta_0),
+ eta_plus_(eta_plus),
+ eta_minus_(eta_minus),
+ delta_max_(delta_max),
+ delta_min_(delta_min) {
+ assert(eta_plus > 1.0);
+ assert(eta_minus > 0.0 && eta_minus < 1.0);
+ assert(delta_max > 0.0);
+ assert(delta_min > 0.0);
+ }
+ std::string Name() const;
+ void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x);
+ void SaveImpl(std::ostream* out) const;
+ void LoadImpl(std::istream* in);
+ private:
+ std::vector<double> prev_g_;
+ std::vector<double> delta_ij_;
+ const double eta_plus_;
+ const double eta_minus_;
+ const double delta_max_;
+ const double delta_min_;
+};
+
+class SGDOptimizer : public Optimizer {
+ public:
+ explicit SGDOptimizer(int num_vars, double eta = 0.1) : eta_(eta) {
+ (void) num_vars;
+ }
+ std::string Name() const;
+ void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x);
+ private:
+ const double eta_;
+};
+
+class LBFGSOptimizer : public Optimizer {
+ public:
+ explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10);
+ std::string Name() const;
+ void SaveImpl(std::ostream* out) const;
+ void LoadImpl(std::istream* in);
+ void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x);
+ private:
+ scitbx::lbfgs::minimizer<double> opt_;
+};
+
+#endif