summaryrefslogtreecommitdiff
path: root/training/mpi_online_optimize.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-03-17 22:29:43 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2011-03-17 22:29:43 -0400
commit9f78539edbbe00feeee618932fc5d51f5c5b9eb4 (patch)
tree69aa5bc909eb41e94756fcaef2e44d002db666ab /training/mpi_online_optimize.cc
parent95e50962fe307b930e835513e4d9998df91426a4 (diff)
enable weights to be frozen during training
Diffstat (limited to 'training/mpi_online_optimize.cc')
-rw-r--r--training/mpi_online_optimize.cc17
1 files changed, 17 insertions, 0 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index 325ba030..1367581a 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -64,6 +64,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
("input_weights,w",po::value<string>(),"Input feature weights file")
+ ("frozen_features,z",po::value<string>(), "List of features not to optimize")
("training_data,t",po::value<string>(),"Training data corpus")
("training_agenda,a",po::value<string>(), "Text file listing a series of configuration files and the number of iterations to train using each configuration successively")
("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(5), "Number of training instances evaluated per processor in each minibatch")
@@ -254,6 +255,20 @@ int main(int argc, char** argv) {
if (conf.count("input_weights"))
weights.InitFromFile(conf["input_weights"].as<string>());
+ vector<int> frozen_fids;
+ if (conf.count("frozen_features")) {
+ ReadFile rf(conf["frozen_features"].as<string>());
+ istream& in = *rf.stream();
+ string line;
+ while(in) {
+ getline(in, line);
+ if (line.empty()) continue;
+ if (line[0] == ' ' || line[line.size() - 1] == ' ') { line = Trim(line); }
+ frozen_fids.push_back(FD::Convert(line));
+ }
+ if (rank == 0) cerr << "Freezing " << frozen_fids.size() << " features.\n";
+ }
+
vector<string> corpus;
vector<int> ids;
ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus, &ids);
@@ -362,6 +377,8 @@ int main(int argc, char** argv) {
g.swap(local_grad);
#endif
local_grad.clear();
+ for (int i = 0; i < frozen_fids.size(); ++i)
+ g.erase(frozen_fids[i]);
if (rank == 0) {
g /= (size_per_proc * size);
o->UpdateWeights(g, FD::NumFeats(), &x);