From 68a3c1423c4c602a27b0211cf6b0c217135548d3 Mon Sep 17 00:00:00 2001 From: redpony Date: Fri, 15 Oct 2010 20:13:01 +0000 Subject: new multi-epoch online optimizer git-svn-id: https://ws10smt.googlecode.com/svn/trunk@675 ec762483-ff6d-05da-a07a-a48fb63a330f --- training/online_optimizer.h | 3 +++ 1 file changed, 3 insertions(+) (limited to 'training/online_optimizer.h') diff --git a/training/online_optimizer.h b/training/online_optimizer.h index 963c0380..312aabae 100644 --- a/training/online_optimizer.h +++ b/training/online_optimizer.h @@ -58,6 +58,7 @@ class OnlineOptimizer { OnlineOptimizer(const std::tr1::shared_ptr& s, size_t batch_size) : N_(batch_size),schedule_(s),k_() {} + void ResetEpoch() { k_ = 0; ResetEpochImpl(); } void UpdateWeights(const SparseVector& approx_g, int max_feat, SparseVector* weights) { ++k_; const double eta = schedule_->eta(k_); @@ -65,6 +66,7 @@ class OnlineOptimizer { } protected: + virtual void ResetEpochImpl(); virtual void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) = 0; const size_t N_; // number of training instances per batch @@ -80,6 +82,7 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer { OnlineOptimizer(s, training_instances), C_(C), u_() {} protected: + void ResetEpochImpl() { u_ = 0; } void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) { u_ += eta * C_ / N_; (*weights) += eta * approx_g; -- cgit v1.2.3