summaryrefslogtreecommitdiff
path: root/training/online_optimizer.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/online_optimizer.h')
-rw-r--r--  training/online_optimizer.h  3
1 file changed, 3 insertions(+), 0 deletions(-)
diff --git a/training/online_optimizer.h b/training/online_optimizer.h
index 963c0380..312aabae 100644
--- a/training/online_optimizer.h
+++ b/training/online_optimizer.h
@@ -58,6 +58,7 @@ class OnlineOptimizer {
OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
size_t batch_size)
: N_(batch_size),schedule_(s),k_() {}
+ void ResetEpoch() { k_ = 0; ResetEpochImpl(); }
void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
++k_;
const double eta = schedule_->eta(k_);
@@ -65,6 +66,7 @@ class OnlineOptimizer {
}
protected:
+ virtual void ResetEpochImpl();
virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0;
const size_t N_; // number of training instances per batch
@@ -80,6 +82,7 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
OnlineOptimizer(s, training_instances), C_(C), u_() {}
protected:
+ void ResetEpochImpl() { u_ = 0; }
void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
u_ += eta * C_ / N_;
(*weights) += eta * approx_g;