From 8aa29810bb77611cc20b7a384897ff6703783ea1 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 18 Nov 2012 13:35:42 -0500 Subject: major restructure of the training code --- training/utils/online_optimizer.h | 129 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 training/utils/online_optimizer.h (limited to 'training/utils/online_optimizer.h') diff --git a/training/utils/online_optimizer.h b/training/utils/online_optimizer.h new file mode 100644 index 00000000..28d89344 --- /dev/null +++ b/training/utils/online_optimizer.h @@ -0,0 +1,129 @@ +#ifndef _ONL_OPTIMIZE_H_ +#define _ONL_OPTIMIZE_H_ + +#include +#include +#include +#include +#include "sparse_vector.h" + +struct LearningRateSchedule { + virtual ~LearningRateSchedule(); + // returns the learning rate for the kth iteration + virtual double eta(int k) const = 0; +}; + +// TODO in the Tsoruoaka et al. (ACL 2009) paper, they use N +// to mean the batch size in most places, but it doesn't completely +// make sense to me in the learning rate schedules-- this needs +// to be worked out to make sure they didn't mean corpus size +// in some places and batch size in others (since in the paper they +// only ever work with batch sizes of 1) +struct StandardLearningRate : public LearningRateSchedule { + StandardLearningRate( + size_t batch_size, // batch size, not corpus size! + double eta_0 = 0.2) : + eta_0_(eta_0), + N_(static_cast(batch_size)) {} + + virtual double eta(int k) const; + + private: + const double eta_0_; + const double N_; +}; + +struct ExponentialDecayLearningRate : public LearningRateSchedule { + ExponentialDecayLearningRate( + size_t batch_size, // batch size, not corpus size! + double eta_0 = 0.2, + double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009) + ) : eta_0_(eta_0), + N_(static_cast(batch_size)), + alpha_(alpha) { + assert(alpha > 0); + assert(alpha < 1.0); + } + + virtual double eta(int k) const; + + private: + const double eta_0_; + const double N_; + const double alpha_; +}; + +class OnlineOptimizer { + public: + virtual ~OnlineOptimizer(); + OnlineOptimizer(const std::tr1::shared_ptr& s, + size_t batch_size, + const std::vector& frozen_feats = std::vector()) + : N_(batch_size),schedule_(s),k_() { + for (int i = 0; i < frozen_feats.size(); ++i) + frozen_.insert(frozen_feats[i]); + } + void ResetEpoch() { k_ = 0; ResetEpochImpl(); } + void UpdateWeights(const SparseVector& approx_g, int max_feat, SparseVector* weights) { + ++k_; + const double eta = schedule_->eta(k_); + UpdateWeightsImpl(eta, approx_g, max_feat, weights); + } + + protected: + virtual void ResetEpochImpl(); + virtual void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) = 0; + const size_t N_; // number of training instances per batch + std::set frozen_; // frozen (non-optimizing) features + + private: + std::tr1::shared_ptr schedule_; + int k_; // iteration count +}; + +class CumulativeL1OnlineOptimizer : public OnlineOptimizer { + public: + CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr& s, + size_t training_instances, double C, + const std::vector& frozen) : + OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {} + + protected: + void ResetEpochImpl() { u_ = 0; } + void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) { + u_ += eta * C_ / N_; + for (SparseVector::const_iterator it = approx_g.begin(); + it != approx_g.end(); ++it) { + if (frozen_.count(it->first) == 0) + weights->add_value(it->first, eta * it->second); + } + for (int i = 1; i < max_feat; ++i) + if (frozen_.count(i) == 0) ApplyPenalty(i, weights); + } + + private: + void ApplyPenalty(int i, SparseVector* w) { + const double z = w->value(i); + double w_i = z; + double q_i = q_.value(i); + if (w_i > 0.0) + w_i = std::max(0.0, w_i - (u_ + q_i)); + else if (w_i < 0.0) + w_i = std::min(0.0, w_i + (u_ - q_i)); + q_i += w_i - z; + if (q_i == 0.0) + q_.erase(i); + else + q_.set_value(i, q_i); + if (w_i == 0.0) + w->erase(i); + else + w->set_value(i, w_i); + } + + const double C_; // reguarlization strength + double u_; + SparseVector q_; +}; + +#endif -- cgit v1.2.3