diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-10-15 20:13:01 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-10-15 20:13:01 +0000 |
commit | 68a3c1423c4c602a27b0211cf6b0c217135548d3 (patch) | |
tree | f9d9514fa0da81e27ce342130b4547cb3b2bd740 /training/online_optimizer.h | |
parent | bb70b3e2fc8c0eca56bbed8132a69cf50ad819bc (diff) |
new multi-epoch online optimizer
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@675 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/online_optimizer.h')
-rw-r--r-- | training/online_optimizer.h | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/training/online_optimizer.h b/training/online_optimizer.h index 963c0380..312aabae 100644 --- a/training/online_optimizer.h +++ b/training/online_optimizer.h @@ -58,6 +58,7 @@ class OnlineOptimizer { OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s, size_t batch_size) : N_(batch_size),schedule_(s),k_() {} + void ResetEpoch() { k_ = 0; ResetEpochImpl(); } void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) { ++k_; const double eta = schedule_->eta(k_); @@ -65,6 +66,7 @@ class OnlineOptimizer { } protected: + virtual void ResetEpochImpl(); virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0; const size_t N_; // number of training instances per batch @@ -80,6 +82,7 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer { OnlineOptimizer(s, training_instances), C_(C), u_() {} protected: + void ResetEpochImpl() { u_ = 0; } void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) { u_ += eta * C_ / N_; (*weights) += eta * approx_g; |