summaryrefslogtreecommitdiff
path: root/training/online_optimizer.h
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-10-15 20:13:01 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-10-15 20:13:01 +0000
commit68a3c1423c4c602a27b0211cf6b0c217135548d3 (patch)
treef9d9514fa0da81e27ce342130b4547cb3b2bd740 /training/online_optimizer.h
parentbb70b3e2fc8c0eca56bbed8132a69cf50ad819bc (diff)
new multi-epoch online optimizer
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@675 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/online_optimizer.h')
-rw-r--r--training/online_optimizer.h3
1 files changed, 3 insertions, 0 deletions
diff --git a/training/online_optimizer.h b/training/online_optimizer.h
index 963c0380..312aabae 100644
--- a/training/online_optimizer.h
+++ b/training/online_optimizer.h
@@ -58,6 +58,7 @@ class OnlineOptimizer {
OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
size_t batch_size)
: N_(batch_size),schedule_(s),k_() {}
+ void ResetEpoch() { k_ = 0; ResetEpochImpl(); }
void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
++k_;
const double eta = schedule_->eta(k_);
@@ -65,6 +66,7 @@ class OnlineOptimizer {
}
protected:
+ virtual void ResetEpochImpl();
virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0;
const size_t N_; // number of training instances per batch
@@ -80,6 +82,7 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
OnlineOptimizer(s, training_instances), C_(C), u_() {}
protected:
+ void ResetEpochImpl() { u_ = 0; }
void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
u_ += eta * C_ / N_;
(*weights) += eta * approx_g;