diff options
Diffstat (limited to 'training/online_optimizer.h')
-rw-r--r-- | training/online_optimizer.h | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/training/online_optimizer.h b/training/online_optimizer.h index 963c0380..312aabae 100644 --- a/training/online_optimizer.h +++ b/training/online_optimizer.h @@ -58,6 +58,7 @@ class OnlineOptimizer { OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s, size_t batch_size) : N_(batch_size),schedule_(s),k_() {} + void ResetEpoch() { k_ = 0; ResetEpochImpl(); } void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) { ++k_; const double eta = schedule_->eta(k_); @@ -65,6 +66,7 @@ class OnlineOptimizer { } protected: + virtual void ResetEpochImpl(); virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0; const size_t N_; // number of training instances per batch @@ -80,6 +82,7 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer { OnlineOptimizer(s, training_instances), C_(C), u_() {} protected: + void ResetEpochImpl() { u_ = 0; } void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) { u_ += eta * C_ / N_; (*weights) += eta * approx_g; |