summaryrefslogtreecommitdiff
path: root/training/online_optimizer.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/online_optimizer.h')
-rw-r--r-- training/online_optimizer.h 50
1 files changed, 31 insertions, 19 deletions
diff --git a/training/online_optimizer.h b/training/online_optimizer.h
index d2718f93..963c0380 100644
--- a/training/online_optimizer.h
+++ b/training/online_optimizer.h
@@ -8,16 +8,22 @@
struct LearningRateSchedule {
virtual ~LearningRateSchedule();
- // returns the learning rate for iteration k
+ // returns the learning rate for the kth iteration
virtual double eta(int k) const = 0;
};
+// TODO in the Tsuruoka et al. (ACL 2009) paper, they use N
+// to mean the batch size in most places, but it doesn't completely
+// make sense to me in the learning rate schedules-- this needs
+// to be worked out to make sure they didn't mean corpus size
+// in some places and batch size in others (since in the paper they
+// only ever work with batch sizes of 1)
struct StandardLearningRate : public LearningRateSchedule {
StandardLearningRate(
- size_t training_instances,
+ size_t batch_size, // batch size, not corpus size!
double eta_0 = 0.2) :
eta_0_(eta_0),
- N_(static_cast<double>(training_instances)) {}
+ N_(static_cast<double>(batch_size)) {}
virtual double eta(int k) const;
@@ -28,11 +34,11 @@ struct StandardLearningRate : public LearningRateSchedule {
struct ExponentialDecayLearningRate : public LearningRateSchedule {
ExponentialDecayLearningRate(
- size_t training_instances,
+ size_t batch_size, // batch size, not corpus size!
double eta_0 = 0.2,
double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009)
) : eta_0_(eta_0),
- N_(static_cast<double>(training_instances)),
+ N_(static_cast<double>(batch_size)),
alpha_(alpha) {
assert(alpha > 0);
assert(alpha < 1.0);
@@ -50,17 +56,17 @@ class OnlineOptimizer {
public:
virtual ~OnlineOptimizer();
OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
- size_t training_instances)
- : N_(training_instances),schedule_(s),k_() {}
- void UpdateWeights(const SparseVector<double>& approx_g, SparseVector<double>* weights) {
+ size_t batch_size)
+ : N_(batch_size),schedule_(s),k_() {}
+ void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
++k_;
const double eta = schedule_->eta(k_);
- UpdateWeightsImpl(eta, approx_g, weights);
+ UpdateWeightsImpl(eta, approx_g, max_feat, weights);
}
protected:
- virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, SparseVector<double>* weights) = 0;
- const size_t N_; // number of training instances
+ virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0;
+ const size_t N_; // number of training instances per batch
private:
std::tr1::shared_ptr<LearningRateSchedule> schedule_;
@@ -74,11 +80,11 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
OnlineOptimizer(s, training_instances), C_(C), u_() {}
protected:
- void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, SparseVector<double>* weights) {
+ void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
u_ += eta * C_ / N_;
(*weights) += eta * approx_g;
- for (SparseVector<double>::const_iterator it = approx_g.begin(); it != approx_g.end(); ++it)
- ApplyPenalty(it->first, weights);
+ for (int i = 1; i < max_feat; ++i)
+ ApplyPenalty(i, weights);
}
private:
@@ -86,13 +92,19 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
const double z = w->value(i);
double w_i = z;
double q_i = q_.value(i);
- if (w_i > 0)
+ if (w_i > 0.0)
w_i = std::max(0.0, w_i - (u_ + q_i));
- else
- w_i = std::max(0.0, w_i + (u_ - q_i));
+ else if (w_i < 0.0)
+ w_i = std::min(0.0, w_i + (u_ - q_i));
q_i += w_i - z;
- q_.set_value(i, q_i);
- w->set_value(i, w_i);
+ if (q_i == 0.0)
+ q_.erase(i);
+ else
+ q_.set_value(i, q_i);
+ if (w_i == 0.0)
+ w->erase(i);
+ else
+ w->set_value(i, w_i);
}
const double C_; // regularization strength