summaryrefslogtreecommitdiff
path: root/training/utils/online_optimizer.h
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2012-11-18 13:35:42 -0500
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2012-11-18 13:35:42 -0500
commit1b8181bf0d6e9137e6b9ccdbe414aec37377a1a9 (patch)
tree33e5f3aa5abff1f41314cf8f6afbd2c2c40e4bfd /training/utils/online_optimizer.h
parent7c4665949fb93fb3de402e4ce1d19bef67850d05 (diff)
major restructure of the training code
Diffstat (limited to 'training/utils/online_optimizer.h')
-rw-r--r--training/utils/online_optimizer.h129
1 files changed, 129 insertions, 0 deletions
diff --git a/training/utils/online_optimizer.h b/training/utils/online_optimizer.h
new file mode 100644
index 00000000..28d89344
--- /dev/null
+++ b/training/utils/online_optimizer.h
@@ -0,0 +1,129 @@
+#ifndef _ONL_OPTIMIZE_H_
+#define _ONL_OPTIMIZE_H_
+
+#include <tr1/memory>
+#include <set>
+#include <string>
+#include <cmath>
+#include "sparse_vector.h"
+
+struct LearningRateSchedule {
+ virtual ~LearningRateSchedule();
+ // returns the learning rate for the kth iteration
+ virtual double eta(int k) const = 0;
+};
+
+// TODO in the Tsoruoaka et al. (ACL 2009) paper, they use N
+// to mean the batch size in most places, but it doesn't completely
+// make sense to me in the learning rate schedules-- this needs
+// to be worked out to make sure they didn't mean corpus size
+// in some places and batch size in others (since in the paper they
+// only ever work with batch sizes of 1)
+struct StandardLearningRate : public LearningRateSchedule {
+ StandardLearningRate(
+ size_t batch_size, // batch size, not corpus size!
+ double eta_0 = 0.2) :
+ eta_0_(eta_0),
+ N_(static_cast<double>(batch_size)) {}
+
+ virtual double eta(int k) const;
+
+ private:
+ const double eta_0_;
+ const double N_;
+};
+
+struct ExponentialDecayLearningRate : public LearningRateSchedule {
+ ExponentialDecayLearningRate(
+ size_t batch_size, // batch size, not corpus size!
+ double eta_0 = 0.2,
+ double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009)
+ ) : eta_0_(eta_0),
+ N_(static_cast<double>(batch_size)),
+ alpha_(alpha) {
+ assert(alpha > 0);
+ assert(alpha < 1.0);
+ }
+
+ virtual double eta(int k) const;
+
+ private:
+ const double eta_0_;
+ const double N_;
+ const double alpha_;
+};
+
+class OnlineOptimizer {
+ public:
+ virtual ~OnlineOptimizer();
+ OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
+ size_t batch_size,
+ const std::vector<int>& frozen_feats = std::vector<int>())
+ : N_(batch_size),schedule_(s),k_() {
+ for (int i = 0; i < frozen_feats.size(); ++i)
+ frozen_.insert(frozen_feats[i]);
+ }
+ void ResetEpoch() { k_ = 0; ResetEpochImpl(); }
+ void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
+ ++k_;
+ const double eta = schedule_->eta(k_);
+ UpdateWeightsImpl(eta, approx_g, max_feat, weights);
+ }
+
+ protected:
+ virtual void ResetEpochImpl();
+ virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0;
+ const size_t N_; // number of training instances per batch
+ std::set<int> frozen_; // frozen (non-optimizing) features
+
+ private:
+ std::tr1::shared_ptr<LearningRateSchedule> schedule_;
+ int k_; // iteration count
+};
+
+class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
+ public:
+ CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
+ size_t training_instances, double C,
+ const std::vector<int>& frozen) :
+ OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {}
+
+ protected:
+ void ResetEpochImpl() { u_ = 0; }
+ void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
+ u_ += eta * C_ / N_;
+ for (SparseVector<double>::const_iterator it = approx_g.begin();
+ it != approx_g.end(); ++it) {
+ if (frozen_.count(it->first) == 0)
+ weights->add_value(it->first, eta * it->second);
+ }
+ for (int i = 1; i < max_feat; ++i)
+ if (frozen_.count(i) == 0) ApplyPenalty(i, weights);
+ }
+
+ private:
+ void ApplyPenalty(int i, SparseVector<double>* w) {
+ const double z = w->value(i);
+ double w_i = z;
+ double q_i = q_.value(i);
+ if (w_i > 0.0)
+ w_i = std::max(0.0, w_i - (u_ + q_i));
+ else if (w_i < 0.0)
+ w_i = std::min(0.0, w_i + (u_ - q_i));
+ q_i += w_i - z;
+ if (q_i == 0.0)
+ q_.erase(i);
+ else
+ q_.set_value(i, q_i);
+ if (w_i == 0.0)
+ w->erase(i);
+ else
+ w->set_value(i, w_i);
+ }
+
+ const double C_; // reguarlization strength
+ double u_;
+ SparseVector<double> q_;
+};
+
+#endif