From 3db6b004ae1e2319f52862d428c20be5a1538993 Mon Sep 17 00:00:00 2001 From: redpony Date: Tue, 28 Sep 2010 17:06:08 +0000 Subject: use boost mpi, fix L1 stochastic optimizer git-svn-id: https://ws10smt.googlecode.com/svn/trunk@659 ec762483-ff6d-05da-a07a-a48fb63a330f --- training/online_optimizer.h | 50 ++++++++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 19 deletions(-) (limited to 'training/online_optimizer.h') diff --git a/training/online_optimizer.h b/training/online_optimizer.h index d2718f93..963c0380 100644 --- a/training/online_optimizer.h +++ b/training/online_optimizer.h @@ -8,16 +8,22 @@ struct LearningRateSchedule { virtual ~LearningRateSchedule(); - // returns the learning rate for iteration k + // returns the learning rate for the kth iteration virtual double eta(int k) const = 0; }; +// TODO in the Tsuruoka et al. (ACL 2009) paper, they use N +// to mean the batch size in most places, but it doesn't completely +// make sense to me in the learning rate schedules-- this needs +// to be worked out to make sure they didn't mean corpus size +// in some places and batch size in others (since in the paper they +// only ever work with batch sizes of 1) struct StandardLearningRate : public LearningRateSchedule { StandardLearningRate( - size_t training_instances, + size_t batch_size, // batch size, not corpus size! double eta_0 = 0.2) : eta_0_(eta_0), - N_(static_cast(training_instances)) {} + N_(static_cast(batch_size)) {} virtual double eta(int k) const; @@ -28,11 +34,11 @@ struct StandardLearningRate : public LearningRateSchedule { struct ExponentialDecayLearningRate : public LearningRateSchedule { ExponentialDecayLearningRate( - size_t training_instances, + size_t batch_size, // batch size, not corpus size! double eta_0 = 0.2, double alpha = 0.85 // recommended by Tsuruoka et al.
(ACL 2009) ) : eta_0_(eta_0), - N_(static_cast(training_instances)), + N_(static_cast(batch_size)), alpha_(alpha) { assert(alpha > 0); assert(alpha < 1.0); @@ -50,17 +56,17 @@ class OnlineOptimizer { public: virtual ~OnlineOptimizer(); OnlineOptimizer(const std::tr1::shared_ptr& s, - size_t training_instances) - : N_(training_instances),schedule_(s),k_() {} - void UpdateWeights(const SparseVector& approx_g, SparseVector* weights) { + size_t batch_size) + : N_(batch_size),schedule_(s),k_() {} + void UpdateWeights(const SparseVector& approx_g, int max_feat, SparseVector* weights) { ++k_; const double eta = schedule_->eta(k_); - UpdateWeightsImpl(eta, approx_g, weights); + UpdateWeightsImpl(eta, approx_g, max_feat, weights); } protected: - virtual void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, SparseVector* weights) = 0; - const size_t N_; // number of training instances + virtual void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) = 0; + const size_t N_; // number of training instances per batch private: std::tr1::shared_ptr schedule_; @@ -74,11 +80,11 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer { OnlineOptimizer(s, training_instances), C_(C), u_() {} protected: - void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, SparseVector* weights) { + void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) { u_ += eta * C_ / N_; (*weights) += eta * approx_g; - for (SparseVector::const_iterator it = approx_g.begin(); it != approx_g.end(); ++it) - ApplyPenalty(it->first, weights); + for (int i = 1; i < max_feat; ++i) + ApplyPenalty(i, weights); } private: @@ -86,13 +92,19 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer { const double z = w->value(i); double w_i = z; double q_i = q_.value(i); - if (w_i > 0) + if (w_i > 0.0) w_i = std::max(0.0, w_i - (u_ + q_i)); - else - w_i = std::max(0.0, w_i + 
(u_ - q_i)); + else if (w_i < 0.0) + w_i = std::min(0.0, w_i + (u_ - q_i)); q_i += w_i - z; - q_.set_value(i, q_i); - w->set_value(i, w_i); + if (q_i == 0.0) + q_.erase(i); + else + q_.set_value(i, q_i); + if (w_i == 0.0) + w->erase(i); + else + w->set_value(i, w_i); } const double C_; // reguarlization strength -- cgit v1.2.3