summaryrefslogtreecommitdiff
path: root/training/online_optimizer.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/online_optimizer.h')
-rw-r--r--training/online_optimizer.h129
1 files changed, 0 insertions, 129 deletions
diff --git a/training/online_optimizer.h b/training/online_optimizer.h
deleted file mode 100644
index 28d89344..00000000
--- a/training/online_optimizer.h
+++ /dev/null
@@ -1,129 +0,0 @@
-#ifndef _ONL_OPTIMIZE_H_
-#define _ONL_OPTIMIZE_H_
-
-#include <tr1/memory>
-#include <set>
-#include <string>
-#include <cmath>
-#include "sparse_vector.h"
-
-struct LearningRateSchedule {
- virtual ~LearningRateSchedule();
- // returns the learning rate for the kth iteration
- virtual double eta(int k) const = 0;
-};
-
-// TODO in the Tsoruoaka et al. (ACL 2009) paper, they use N
-// to mean the batch size in most places, but it doesn't completely
-// make sense to me in the learning rate schedules-- this needs
-// to be worked out to make sure they didn't mean corpus size
-// in some places and batch size in others (since in the paper they
-// only ever work with batch sizes of 1)
-struct StandardLearningRate : public LearningRateSchedule {
- StandardLearningRate(
- size_t batch_size, // batch size, not corpus size!
- double eta_0 = 0.2) :
- eta_0_(eta_0),
- N_(static_cast<double>(batch_size)) {}
-
- virtual double eta(int k) const;
-
- private:
- const double eta_0_;
- const double N_;
-};
-
-struct ExponentialDecayLearningRate : public LearningRateSchedule {
- ExponentialDecayLearningRate(
- size_t batch_size, // batch size, not corpus size!
- double eta_0 = 0.2,
- double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009)
- ) : eta_0_(eta_0),
- N_(static_cast<double>(batch_size)),
- alpha_(alpha) {
- assert(alpha > 0);
- assert(alpha < 1.0);
- }
-
- virtual double eta(int k) const;
-
- private:
- const double eta_0_;
- const double N_;
- const double alpha_;
-};
-
-class OnlineOptimizer {
- public:
- virtual ~OnlineOptimizer();
- OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
- size_t batch_size,
- const std::vector<int>& frozen_feats = std::vector<int>())
- : N_(batch_size),schedule_(s),k_() {
- for (int i = 0; i < frozen_feats.size(); ++i)
- frozen_.insert(frozen_feats[i]);
- }
- void ResetEpoch() { k_ = 0; ResetEpochImpl(); }
- void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
- ++k_;
- const double eta = schedule_->eta(k_);
- UpdateWeightsImpl(eta, approx_g, max_feat, weights);
- }
-
- protected:
- virtual void ResetEpochImpl();
- virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0;
- const size_t N_; // number of training instances per batch
- std::set<int> frozen_; // frozen (non-optimizing) features
-
- private:
- std::tr1::shared_ptr<LearningRateSchedule> schedule_;
- int k_; // iteration count
-};
-
-class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
- public:
- CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
- size_t training_instances, double C,
- const std::vector<int>& frozen) :
- OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {}
-
- protected:
- void ResetEpochImpl() { u_ = 0; }
- void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
- u_ += eta * C_ / N_;
- for (SparseVector<double>::const_iterator it = approx_g.begin();
- it != approx_g.end(); ++it) {
- if (frozen_.count(it->first) == 0)
- weights->add_value(it->first, eta * it->second);
- }
- for (int i = 1; i < max_feat; ++i)
- if (frozen_.count(i) == 0) ApplyPenalty(i, weights);
- }
-
- private:
- void ApplyPenalty(int i, SparseVector<double>* w) {
- const double z = w->value(i);
- double w_i = z;
- double q_i = q_.value(i);
- if (w_i > 0.0)
- w_i = std::max(0.0, w_i - (u_ + q_i));
- else if (w_i < 0.0)
- w_i = std::min(0.0, w_i + (u_ - q_i));
- q_i += w_i - z;
- if (q_i == 0.0)
- q_.erase(i);
- else
- q_.set_value(i, q_i);
- if (w_i == 0.0)
- w->erase(i);
- else
- w->set_value(i, w_i);
- }
-
- const double C_; // reguarlization strength
- double u_;
- SparseVector<double> q_;
-};
-
-#endif