From 594cc714737708bb7c90c24a1ab1537b052f45ee Mon Sep 17 00:00:00 2001 From: redpony Date: Sun, 29 Aug 2010 00:36:09 +0000 Subject: online optimizer git-svn-id: https://ws10smt.googlecode.com/svn/trunk@631 ec762483-ff6d-05da-a07a-a48fb63a330f --- training/Makefile.am | 8 +++- training/mr_optimize_reduce.cc | 6 +-- training/online_optimizer.cc | 14 ++++++ training/online_optimizer.h | 102 +++++++++++++++++++++++++++++++++++++++++ training/online_train.cc | 8 ++++ training/optimize.cc | 22 ++------- training/optimize.h | 23 ++-------- training/optimize_test.cc | 19 ++++++-- 8 files changed, 158 insertions(+), 44 deletions(-) create mode 100644 training/online_optimizer.cc create mode 100644 training/online_optimizer.h create mode 100644 training/online_train.cc (limited to 'training') diff --git a/training/Makefile.am b/training/Makefile.am index 48b19932..a947e4a5 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -7,12 +7,16 @@ bin_PROGRAMS = \ grammar_convert \ atools \ plftools \ - collapse_weights + collapse_weights \ + online_train noinst_PROGRAMS = \ lbfgs_test \ optimize_test +online_train_SOURCES = online_train.cc online_optimizer.cc +online_train_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz + atools_SOURCES = atools.cc atools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz @@ -22,7 +26,7 @@ model1_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -l grammar_convert_SOURCES = grammar_convert.cc grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz -optimize_test_SOURCES = optimize_test.cc optimize.cc +optimize_test_SOURCES = optimize_test.cc optimize.cc online_optimizer.cc optimize_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz collapse_weights_SOURCES = collapse_weights.cc diff --git a/training/mr_optimize_reduce.cc b/training/mr_optimize_reduce.cc index 42727ecb..b931991d 100644 --- a/training/mr_optimize_reduce.cc +++ b/training/mr_optimize_reduce.cc @@ -108,11 +108,9 @@ int main(int argc, char** argv) { } wm.InitVector(&means); } - shared_ptr o; + shared_ptr o; const string omethod = conf["optimization_method"].as(); - if (omethod == "sgd") - o.reset(new SGDOptimizer(conf["eta"].as())); - else if (omethod == "rprop") + if (omethod == "rprop") o.reset(new RPropOptimizer(num_feats)); // TODO add configuration else o.reset(new LBFGSOptimizer(num_feats, conf["correction_buffers"].as())); diff --git a/training/online_optimizer.cc b/training/online_optimizer.cc new file mode 100644 index 00000000..db55c95e --- /dev/null +++ b/training/online_optimizer.cc @@ -0,0 +1,14 @@ +#include "online_optimizer.h" + +LearningRateSchedule::~LearningRateSchedule() {} + +double StandardLearningRate::eta(int k) const { + return eta_0_ / (1.0 + k / N_); +} + +double ExponentialDecayLearningRate::eta(int k) const { + return eta_0_ * pow(alpha_, k / N_); +} + +OnlineOptimizer::~OnlineOptimizer() {} + diff --git a/training/online_optimizer.h b/training/online_optimizer.h new file mode 100644 index 00000000..0cd748c4 --- /dev/null +++ b/training/online_optimizer.h @@ -0,0 +1,102 @@ +#ifndef _ONL_OPTIMIZE_H_ +#define _ONL_OPTIMIZE_H_ + +#include +#include +#include +#include "sparse_vector.h" + +struct LearningRateSchedule { + virtual ~LearningRateSchedule(); + // returns the learning rate for iteration k + virtual double eta(int k) const = 0; +}; + +struct StandardLearningRate : public LearningRateSchedule { + StandardLearningRate( + size_t training_instances, + double eta_0 = 0.2) : + eta_0_(eta_0), + N_(static_cast(training_instances)) {} + + virtual double eta(int k) const; + + private: + const double eta_0_; + const double N_; +}; + +struct ExponentialDecayLearningRate : public LearningRateSchedule { + ExponentialDecayLearningRate( + size_t training_instances, + double eta_0 = 0.2, + double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009) + ) : eta_0_(eta_0), + N_(static_cast(training_instances)), + alpha_(alpha) { + assert(alpha > 0); + assert(alpha < 1.0); + } + + virtual double eta(int k) const; + + private: + const double eta_0_; + const double N_; + const double alpha_; +}; + +class OnlineOptimizer { + public: + virtual ~OnlineOptimizer(); + OnlineOptimizer(const std::tr1::shared_ptr& s, + size_t training_instances) : schedule_(s), k_(), N_(training_instances) {} + void UpdateWeights(const SparseVector& approx_g, SparseVector* weights) { + ++k_; + const double eta = schedule_->eta(k_); + UpdateWeightsImpl(eta, approx_g, weights); + } + + protected: + virtual void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, SparseVector* weights) = 0; + const size_t N_; // number of training instances + + private: + std::tr1::shared_ptr schedule_; + int k_; // iteration count +}; + +class CumulativeL1OnlineOptimizer : public OnlineOptimizer { + public: + CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr& s, + size_t training_instances, double C) : + OnlineOptimizer(s, training_instances), C_(C), u_() {} + + protected: + void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, SparseVector* weights) { + u_ += eta * C_ / N_; + (*weights) += eta * approx_g; + for (SparseVector::const_iterator it = approx_g.begin(); it != approx_g.end(); ++it) + ApplyPenalty(it->first, weights); + } + + private: + void ApplyPenalty(int i, SparseVector* w) { + const double z = w->value(i); + double w_i = z; + double q_i = q_.value(i); + if (w_i > 0) + w_i = std::max(0.0, w_i - (u_ + q_i)); + else + w_i = std::max(0.0, w_i + (u_ - q_i)); + q_i += w_i - z; + q_.set_value(i, q_i); + w->set_value(i, w_i); + } + + const double C_; // reguarlization strength + double u_; + SparseVector q_; +}; + +#endif diff --git a/training/online_train.cc b/training/online_train.cc new file mode 100644 index 00000000..2e906913 --- /dev/null +++ b/training/online_train.cc @@ -0,0 +1,8 @@ +#include + +#include "online_optimizer.h" + +int main(int argc, char** argv) { + return 0; +} + diff --git a/training/optimize.cc b/training/optimize.cc index 5194752e..1377caa6 100644 --- a/training/optimize.cc +++ b/training/optimize.cc @@ -7,9 +7,9 @@ using namespace std; -Optimizer::~Optimizer() {} +BatchOptimizer::~BatchOptimizer() {} -void Optimizer::Save(ostream* out) const { +void BatchOptimizer::Save(ostream* out) const { out->write((const char*)&eval_, sizeof(eval_)); out->write((const char*)&has_converged_, sizeof(has_converged_)); SaveImpl(out); @@ -17,7 +17,7 @@ void Optimizer::Save(ostream* out) const { out->write((const char*)&magic, sizeof(magic)); } -void Optimizer::Load(istream* in) { +void BatchOptimizer::Load(istream* in) { in->read((char*)&eval_, sizeof(eval_)); ++eval_; in->read((char*)&has_converged_, sizeof(has_converged_)); @@ -28,11 +28,11 @@ void Optimizer::Load(istream* in) { cerr << Name() << " EVALUATION #" << eval_ << endl; } -void Optimizer::SaveImpl(ostream* out) const { +void BatchOptimizer::SaveImpl(ostream* out) const { (void)out; } -void Optimizer::LoadImpl(istream* in) { +void BatchOptimizer::LoadImpl(istream* in) { (void)in; } @@ -78,18 +78,6 @@ void RPropOptimizer::LoadImpl(istream* in) { in->read((char*)&delta_ij_[0], sizeof(double) * n); } -string SGDOptimizer::Name() const { - return "SGDOptimizer"; -} - -void SGDOptimizer::OptimizeImpl(const double& obj, - const vector& g, - vector* x) { - (void)obj; - for (int i = 0; i < g.size(); ++i) - (*x)[i] -= g[i] * eta_; -} - string LBFGSOptimizer::Name() const { return "LBFGSOptimizer"; } diff --git a/training/optimize.h b/training/optimize.h index eddceaad..e2620f93 100644 --- a/training/optimize.h +++ b/training/optimize.h @@ -10,10 +10,10 @@ // abstract base class for first order optimizers // order of invocation: new, Load(), Optimize(), Save(), delete -class Optimizer { +class BatchOptimizer { public: - Optimizer() : eval_(1), has_converged_(false) {} - virtual ~Optimizer(); + BatchOptimizer() : eval_(1), has_converged_(false) {} + virtual ~BatchOptimizer(); virtual std::string Name() const = 0; int EvaluationCount() const { return eval_; } bool HasConverged() const { return has_converged_; } @@ -41,7 +41,7 @@ class Optimizer { bool has_converged_; }; -class RPropOptimizer : public Optimizer { +class RPropOptimizer : public BatchOptimizer { public: explicit RPropOptimizer(int num_vars, double eta_plus = 1.2, @@ -75,20 +75,7 @@ class RPropOptimizer : public Optimizer { const double delta_min_; }; -class SGDOptimizer : public Optimizer { - public: - explicit SGDOptimizer(int num_vars, double eta = 0.1) : eta_(eta) { - (void) num_vars; - } - std::string Name() const; - void OptimizeImpl(const double& obj, - const std::vector& g, - std::vector* x); - private: - const double eta_; -}; - -class LBFGSOptimizer : public Optimizer { +class LBFGSOptimizer : public BatchOptimizer { public: explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10); std::string Name() const; diff --git a/training/optimize_test.cc b/training/optimize_test.cc index 0ada7cbb..6fa5efd4 100644 --- a/training/optimize_test.cc +++ b/training/optimize_test.cc @@ -3,12 +3,13 @@ #include #include #include "optimize.h" +#include "online_optimizer.h" #include "sparse_vector.h" #include "fdict.h" using namespace std; -double TestOptimizer(Optimizer* opt) { +double TestOptimizer(BatchOptimizer* opt) { cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 @@ -34,7 +35,7 @@ double TestOptimizer(Optimizer* opt) { return obj; } -double TestPersistentOptimizer(Optimizer* opt) { +double TestPersistentOptimizer(BatchOptimizer* opt) { cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 // df/dx1 = 8*x1 + x2 @@ -95,11 +96,23 @@ void TestOptimizerVariants(int num_vars) { cerr << oa.Name() << " SUCCESS\n"; } +using namespace std::tr1; + +void TestOnline() { + size_t N = 20; + double C = 1.0; + double eta0 = 0.2; + shared_ptr r(new ExponentialDecayLearningRate(N, eta0, 0.85)); + //shared_ptr r(new StandardLearningRate(N, eta0)); + CumulativeL1OnlineOptimizer opt(r, N, C); + assert(r->eta(10) < r->eta(1)); +} + int main() { int n = 3; - TestOptimizerVariants(n); TestOptimizerVariants(n); TestOptimizerVariants(n); + TestOnline(); return 0; } -- cgit v1.2.3