summaryrefslogtreecommitdiff
path: root/training/online_optimizer.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/online_optimizer.h')
-rw-r--r--training/online_optimizer.h17
1 files changed, 12 insertions, 5 deletions
diff --git a/training/online_optimizer.h b/training/online_optimizer.h
index 312aabae..61d62a37 100644
--- a/training/online_optimizer.h
+++ b/training/online_optimizer.h
@@ -2,6 +2,7 @@
#define _ONL_OPTIMIZE_H_
#include <tr1/memory>
+#include <set>
#include <string>
#include <cmath>
#include "sparse_vector.h"
@@ -56,8 +57,12 @@ class OnlineOptimizer {
public:
virtual ~OnlineOptimizer();
OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
- size_t batch_size)
- : N_(batch_size),schedule_(s),k_() {}
+ size_t batch_size,
+ const std::vector<int>& frozen_feats = std::vector<int>())
+ : N_(batch_size),schedule_(s),k_() {
+ for (int i = 0; i < frozen_feats.size(); ++i)
+ frozen_.insert(frozen_feats[i]);
+ }
void ResetEpoch() { k_ = 0; ResetEpochImpl(); }
void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
++k_;
@@ -69,6 +74,7 @@ class OnlineOptimizer {
virtual void ResetEpochImpl();
virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0;
const size_t N_; // number of training instances per batch
+ std::set<int> frozen_; // frozen (non-optimizing) features
private:
std::tr1::shared_ptr<LearningRateSchedule> schedule_;
@@ -78,8 +84,9 @@ class OnlineOptimizer {
class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
public:
CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
- size_t training_instances, double C) :
- OnlineOptimizer(s, training_instances), C_(C), u_() {}
+ size_t training_instances, double C,
+ const std::vector<int>& frozen) :
+ OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {}
protected:
void ResetEpochImpl() { u_ = 0; }
@@ -87,7 +94,7 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
u_ += eta * C_ / N_;
(*weights) += eta * approx_g;
for (int i = 1; i < max_feat; ++i)
- ApplyPenalty(i, weights);
+ if (frozen_.count(i) == 0) ApplyPenalty(i, weights);
}
private: