From 9f78539edbbe00feeee618932fc5d51f5c5b9eb4 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Thu, 17 Mar 2011 22:29:43 -0400
Subject: enable weights to be frozen during training

---
 training/mpi_online_optimize.cc | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

(limited to 'training')

diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index 325ba030..1367581a 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -64,6 +64,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   po::options_description opts("Configuration options");
   opts.add_options()
         ("input_weights,w",po::value<string>(),"Input feature weights file")
+        ("frozen_features,z",po::value<string>(), "List of features not to optimize")
         ("training_data,t",po::value<string>(),"Training data corpus")
         ("training_agenda,a",po::value<string>(), "Text file listing a series of configuration files and the number of iterations to train using each configuration successively")
         ("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(5), "Number of training instances evaluated per processor in each minibatch")
@@ -254,6 +255,20 @@ int main(int argc, char** argv) {
   if (conf.count("input_weights"))
     weights.InitFromFile(conf["input_weights"].as<string>());
 
+  vector<int> frozen_fids;
+  if (conf.count("frozen_features")) {
+    ReadFile rf(conf["frozen_features"].as<string>());
+    istream& in = *rf.stream();
+    string line;
+    while(in) {
+      getline(in, line);
+      if (line.empty()) continue;
+      if (line[0] == ' ' || line[line.size() - 1] == ' ') { line = Trim(line); }
+      frozen_fids.push_back(FD::Convert(line));
+    }
+    if (rank == 0) cerr << "Freezing " << frozen_fids.size() << " features.\n";
+  }
+
   vector<string> corpus;
   vector<int> ids;
   ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus, &ids);
@@ -362,6 +377,8 @@ int main(int argc, char** argv) {
       g.swap(local_grad);
 #endif
       local_grad.clear();
+      for (int i = 0; i < frozen_fids.size(); ++i)
+        g.erase(frozen_fids[i]);
       if (rank == 0) {
         g /= (size_per_proc * size);
         o->UpdateWeights(g, FD::NumFeats(), &x);
--
cgit v1.2.3


From 7079e3685def6f231ecf9f0c3f31b5c03a46d858 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Thu, 17 Mar 2011 22:46:35 -0400
Subject: freeze features, including penalty

---
 training/mpi_online_optimize.cc |  4 +---
 training/online_optimizer.h     | 17 ++++++++++++-----
 2 files changed, 13 insertions(+), 8 deletions(-)

(limited to 'training')

diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index 1367581a..32033c19 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -299,7 +299,7 @@ int main(int argc, char** argv) {
   const string omethod = conf["optimization_method"].as<string>();
   if (omethod == "sgd") {
     const double C = conf["regularization_strength"].as<double>();
-    o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C));
+    o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C, frozen_fids));
   } else {
     assert(!"fail");
   }
@@ -377,8 +377,6 @@ int main(int argc, char** argv) {
       g.swap(local_grad);
 #endif
       local_grad.clear();
-      for (int i = 0; i < frozen_fids.size(); ++i)
-        g.erase(frozen_fids[i]);
       if (rank == 0) {
         g /= (size_per_proc * size);
         o->UpdateWeights(g, FD::NumFeats(), &x);
diff --git a/training/online_optimizer.h b/training/online_optimizer.h
index 312aabae..61d62a37 100644
--- a/training/online_optimizer.h
+++ b/training/online_optimizer.h
@@ -2,6 +2,7 @@
 #define _ONL_OPTIMIZE_H_
 
 #include <tr1/memory>
+#include <set>
 #include <string>
 #include <cmath>
 #include "sparse_vector.h"
@@ -56,8 +57,12 @@ class OnlineOptimizer {
  public:
   virtual ~OnlineOptimizer();
   OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
-                  size_t batch_size)
-    : N_(batch_size),schedule_(s),k_() {}
+                  size_t batch_size,
+                  const std::vector<int>& frozen_feats = std::vector<int>())
+    : N_(batch_size),schedule_(s),k_() {
+    for (int i = 0; i < frozen_feats.size(); ++i)
+      frozen_.insert(frozen_feats[i]);
+  }
   void ResetEpoch() { k_ = 0; ResetEpochImpl(); }
   void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
     ++k_;
@@ -69,6 +74,7 @@ class OnlineOptimizer {
   virtual void ResetEpochImpl();
   virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0;
   const size_t N_; // number of training instances per batch
+  std::set<int> frozen_;  // frozen (non-optimizing) features
 
  private:
   std::tr1::shared_ptr<LearningRateSchedule> schedule_;
@@ -78,8 +84,9 @@ class OnlineOptimizer {
 class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
  public:
   CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
-                              size_t training_instances, double C) :
-    OnlineOptimizer(s, training_instances), C_(C), u_() {}
+                              size_t training_instances, double C,
+                              const std::vector<int>& frozen) :
+    OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {}
 
  protected:
   void ResetEpochImpl() { u_ = 0; }
@@ -87,7 +94,7 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
     u_ += eta * C_ / N_;
     (*weights) += eta * approx_g;
     for (int i = 1; i < max_feat; ++i)
-      ApplyPenalty(i, weights);
+      if (frozen_.count(i) == 0) ApplyPenalty(i, weights);
   }
 
  private:
--
cgit v1.2.3


From 4482fe7a82e3f9a197bf65d60635885c4bfab195 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Thu, 17 Mar 2011 22:53:19 -0400
Subject: try 2

---
 training/online_optimizer.h | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

(limited to 'training')

diff --git a/training/online_optimizer.h b/training/online_optimizer.h
index 61d62a37..28d89344 100644
--- a/training/online_optimizer.h
+++ b/training/online_optimizer.h
@@ -92,7 +92,11 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
   void ResetEpochImpl() { u_ = 0; }
   void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
     u_ += eta * C_ / N_;
-    (*weights) += eta * approx_g;
+    for (SparseVector<double>::const_iterator it = approx_g.begin();
+         it != approx_g.end(); ++it) {
+      if (frozen_.count(it->first) == 0)
+        weights->add_value(it->first, eta * it->second);
+    }
     for (int i = 1; i < max_feat; ++i)
       if (frozen_.count(i) == 0) ApplyPenalty(i, weights);
   }
--
cgit v1.2.3


From ed47102885e52c52146fc8631ff624779bd7eb0a Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Fri, 18 Mar 2011 10:36:26 -0400
Subject: compile fix

---
 Makefile.am               | 4 +++-
 training/optimize_test.cc | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

(limited to 'training')

diff --git a/Makefile.am b/Makefile.am
index a808c211..bd46bd91 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -1,7 +1,9 @@
 # warning - the subdirectories in the following list should
 # be kept in topologically sorted order. Also, DO NOT introduce
 # cyclic dependencies between these directories!
-SUBDIRS = utils mteval klm/util klm/lm decoder phrasinator training vest extools gi/pyp-topics/src gi/clda/src gi/posterior-regularisation/prjava
+SUBDIRS = utils mteval klm/util klm/lm decoder phrasinator training vest extools
+
+#gi/pyp-topics/src gi/clda/src gi/posterior-regularisation/prjava
 
 AUTOMAKE_OPTIONS = foreign
 ACLOCAL_AMFLAGS = -I m4
diff --git a/training/optimize_test.cc b/training/optimize_test.cc
index 6fa5efd4..fe7ca70f 100644
--- a/training/optimize_test.cc
+++ b/training/optimize_test.cc
@@ -104,7 +104,7 @@ void TestOnline() {
   double eta0 = 0.2;
   shared_ptr<LearningRateSchedule> r(new ExponentialDecayLearningRate(N, eta0, 0.85));
   //shared_ptr<LearningRateSchedule> r(new StandardLearningRate(N, eta0));
-  CumulativeL1OnlineOptimizer opt(r, N, C);
+  CumulativeL1OnlineOptimizer opt(r, N, C, std::vector<int>());
   assert(r->eta(10) < r->eta(1));
 }
--
cgit v1.2.3
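
Taken together, the four patches above add a frozen_features option: a file listing one feature name per line whose weights are left untouched, both by the gradient step and by the cumulative L1 penalty in CumulativeL1OnlineOptimizer. The program below distills that update rule into a standalone sketch for readers who want to see the effect in isolation. It is an illustration, not cdec code: a plain std::map stands in for SparseVector<double>, the names FeatVec/ApplyPenalty/Update are invented for the example, and unlike the real optimizer (which sweeps every feature id up to FD::NumFeats()) it only visits features with nonzero weight.

// Minimal sketch of the frozen-feature SGD update implemented by these
// patches: a cumulative L1 penalty (in the style of Tsuruoka et al., 2009)
// that skips a set of "frozen" feature ids in both the gradient step and
// the penalty.  Names and types here are illustrative, not the cdec API.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <set>

typedef std::map<int, double> FeatVec;  // feature id -> value

// Clip *wi toward zero by the cumulative penalty u, tracking in *qi how much
// penalty has already been applied to this feature.
void ApplyPenalty(double u, double* wi, double* qi) {
  const double z = *wi;
  if (*wi > 0)      *wi = std::max(0.0, *wi - (u + *qi));
  else if (*wi < 0) *wi = std::min(0.0, *wi + (u - *qi));
  *qi += *wi - z;
}

// One update: u accumulates eta*C/N; frozen ids get neither the gradient
// step nor the L1 shrinkage, so their weights never move.
void Update(const FeatVec& grad, double eta, double C, std::size_t N,
            const std::set<int>& frozen, double* u, FeatVec* q, FeatVec* w) {
  *u += eta * C / N;
  for (FeatVec::const_iterator it = grad.begin(); it != grad.end(); ++it)
    if (frozen.count(it->first) == 0)
      (*w)[it->first] += eta * it->second;
  for (FeatVec::iterator it = w->begin(); it != w->end(); ++it)
    if (frozen.count(it->first) == 0)
      ApplyPenalty(*u, &it->second, &(*q)[it->first]);
}

int main() {
  FeatVec w, g, q;
  w[1] = 0.5;  w[2] = -0.3;   // feature 2 will be frozen
  g[1] = 1.0;  g[2] = 1.0;
  std::set<int> frozen;
  frozen.insert(2);
  double u = 0.0;
  Update(g, 0.1, 1.0, 5, frozen, &u, &q, &w);
  std::cout << "w1=" << w[1] << " w2=" << w[2] << std::endl;
  return 0;
}

Running the sketch prints w1=0.58 w2=-0.3: feature 1 takes the gradient step (0.5 + 0.1) and is then shrunk by the cumulative penalty u = 0.02, while the frozen feature 2 keeps its initial weight exactly, which is the behavior the patches add to the online optimizer.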