From 9f78539edbbe00feeee618932fc5d51f5c5b9eb4 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 17 Mar 2011 22:29:43 -0400 Subject: enable weights to be frozen during training --- training/mpi_online_optimize.cc | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'training') diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 325ba030..1367581a 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -64,6 +64,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("input_weights,w",po::value(),"Input feature weights file") + ("frozen_features,z",po::value(), "List of features not to optimize") ("training_data,t",po::value(),"Training data corpus") ("training_agenda,a",po::value(), "Text file listing a series of configuration files and the number of iterations to train using each configuration successively") ("minibatch_size_per_proc,s", po::value()->default_value(5), "Number of training instances evaluated per processor in each minibatch") @@ -254,6 +255,20 @@ int main(int argc, char** argv) { if (conf.count("input_weights")) weights.InitFromFile(conf["input_weights"].as()); + vector frozen_fids; + if (conf.count("frozen_features")) { + ReadFile rf(conf["frozen_features"].as()); + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (line.empty()) continue; + if (line[0] == ' ' || line[line.size() - 1] == ' ') { line = Trim(line); } + frozen_fids.push_back(FD::Convert(line)); + } + if (rank == 0) cerr << "Freezing " << frozen_fids.size() << " features.\n"; + } + vector corpus; vector ids; ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); @@ -362,6 +377,8 @@ int main(int argc, char** argv) { g.swap(local_grad); #endif local_grad.clear(); + for (int i = 0; i < frozen_fids.size(); ++i) + g.erase(frozen_fids[i]); if (rank == 0) { g /= (size_per_proc * size); o->UpdateWeights(g, FD::NumFeats(), &x); -- cgit v1.2.3