diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-17 22:46:35 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-17 22:46:35 -0400 |
commit | 7079e3685def6f231ecf9f0c3f31b5c03a46d858 (patch) | |
tree | 685d6e7a3d9a487e11a628bf7a7b88fde36a1e5b /training | |
parent | 9f78539edbbe00feeee618932fc5d51f5c5b9eb4 (diff) |
freeze features, including penalty
Diffstat (limited to 'training')
-rw-r--r-- | training/mpi_online_optimize.cc | 4 | ||||
-rw-r--r-- | training/online_optimizer.h | 17 |
2 files changed, 13 insertions, 8 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 1367581a..32033c19 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -299,7 +299,7 @@ int main(int argc, char** argv) { const string omethod = conf["optimization_method"].as<string>(); if (omethod == "sgd") { const double C = conf["regularization_strength"].as<double>(); - o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C)); + o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C, frozen_fids)); } else { assert(!"fail"); } @@ -377,8 +377,6 @@ int main(int argc, char** argv) { g.swap(local_grad); #endif local_grad.clear(); - for (int i = 0; i < frozen_fids.size(); ++i) - g.erase(frozen_fids[i]); if (rank == 0) { g /= (size_per_proc * size); o->UpdateWeights(g, FD::NumFeats(), &x); diff --git a/training/online_optimizer.h b/training/online_optimizer.h index 312aabae..61d62a37 100644 --- a/training/online_optimizer.h +++ b/training/online_optimizer.h @@ -2,6 +2,7 @@ #define _ONL_OPTIMIZE_H_ #include <tr1/memory> +#include <set> #include <string> #include <cmath> #include "sparse_vector.h" @@ -56,8 +57,12 @@ class OnlineOptimizer { public: virtual ~OnlineOptimizer(); OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s, - size_t batch_size) - : N_(batch_size),schedule_(s),k_() {} + size_t batch_size, + const std::vector<int>& frozen_feats = std::vector<int>()) + : N_(batch_size),schedule_(s),k_() { + for (int i = 0; i < frozen_feats.size(); ++i) + frozen_.insert(frozen_feats[i]); + } void ResetEpoch() { k_ = 0; ResetEpochImpl(); } void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) { ++k_; @@ -69,6 +74,7 @@ class OnlineOptimizer { virtual void ResetEpochImpl(); virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0; const size_t N_; // number of training instances per batch + std::set<int> frozen_; // frozen (non-optimizing) features private: std::tr1::shared_ptr<LearningRateSchedule> schedule_; @@ -78,8 +84,9 @@ class OnlineOptimizer { class CumulativeL1OnlineOptimizer : public OnlineOptimizer { public: CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s, - size_t training_instances, double C) : - OnlineOptimizer(s, training_instances), C_(C), u_() {} + size_t training_instances, double C, + const std::vector<int>& frozen) : + OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {} protected: void ResetEpochImpl() { u_ = 0; } @@ -87,7 +94,7 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer { u_ += eta * C_ / N_; (*weights) += eta * approx_g; for (int i = 1; i < max_feat; ++i) - ApplyPenalty(i, weights); + if (frozen_.count(i) == 0) ApplyPenalty(i, weights); } private: |