summaryrefslogtreecommitdiff
path: root/training/mpi_online_optimize.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-09-13 17:36:23 +0100
committerChris Dyer <cdyer@cs.cmu.edu>2011-09-13 17:36:23 +0100
commitbb86637332d49f71c485df34576e464eaf053656 (patch)
treeefaa1cb07db897f3443c9dc69712999a530921f3 /training/mpi_online_optimize.cc
parent7fadd06330c015d7ebc51ebd50e30332d187acbb (diff)
get rid of bad Weights class so it no longer keeps a copy of a vector inside it
Diffstat (limited to 'training/mpi_online_optimize.cc')
-rw-r--r--training/mpi_online_optimize.cc69
1 files changed, 24 insertions, 45 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index 32033c19..2ef4a2e7 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -31,35 +31,6 @@ namespace mpi = boost::mpi;
using namespace std;
namespace po = boost::program_options;
-void SanityCheck(const vector<double>& w) {
- for (int i = 0; i < w.size(); ++i) {
- assert(!isnan(w[i]));
- assert(!isinf(w[i]));
- }
-}
-
-struct FComp {
- const vector<double>& w_;
- FComp(const vector<double>& w) : w_(w) {}
- bool operator()(int a, int b) const {
- return fabs(w_[a]) > fabs(w_[b]);
- }
-};
-
-void ShowLargestFeatures(const vector<double>& w) {
- vector<int> fnums(w.size());
- for (int i = 0; i < w.size(); ++i)
- fnums[i] = i;
- vector<int>::iterator mid = fnums.begin();
- mid += (w.size() > 10 ? 10 : w.size());
- partial_sort(fnums.begin(), mid, fnums.end(), FComp(w));
- cerr << "TOP FEATURES:";
- for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) {
- cerr << ' ' << FD::Convert(*i) << '=' << w[*i];
- }
- cerr << endl;
-}
-
bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
@@ -250,10 +221,25 @@ int main(int argc, char** argv) {
if (!InitCommandLine(argc, argv, &conf))
return 1;
+ vector<pair<string, int> > agenda;
+ if (!LoadAgenda(conf["training_agenda"].as<string>(), &agenda))
+ return 1;
+ if (rank == 0)
+ cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n";
+
+ assert(agenda.size() > 0);
+
+ if (1) { // hack to load the feature hash functions -- TODO this should not be in cdec.ini
+ const string& cur_config = agenda[0].first;
+ const unsigned max_iteration = agenda[0].second;
+ ReadFile ini_rf(cur_config);
+ Decoder decoder(ini_rf.stream());
+ }
+
// load initial weights
- Weights weights;
+ vector<weight_t> init_weights;
if (conf.count("input_weights"))
- weights.InitFromFile(conf["input_weights"].as<string>());
+ Weights::InitFromFile(conf["input_weights"].as<string>(), &init_weights);
vector<int> frozen_fids;
if (conf.count("frozen_features")) {
@@ -310,19 +296,12 @@ int main(int argc, char** argv) {
rng.reset(new MT19937);
SparseVector<double> x;
- weights.InitSparseVector(&x);
+ Weights::InitSparseVector(init_weights, &x);
TrainingObserver observer;
int write_weights_every_ith = 100; // TODO configure
int titer = -1;
- vector<pair<string, int> > agenda;
- if (!LoadAgenda(conf["training_agenda"].as<string>(), &agenda))
- return 1;
- if (rank == 0)
- cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n";
-
- vector<double> lambdas;
for (int ai = 0; ai < agenda.size(); ++ai) {
const string& cur_config = agenda[ai].first;
const unsigned max_iteration = agenda[ai].second;
@@ -331,6 +310,8 @@ int main(int argc, char** argv) {
// load cdec.ini and set up decoder
ReadFile ini_rf(cur_config);
Decoder decoder(ini_rf.stream());
+ vector<weight_t>& lambdas = decoder.CurrentWeightVector();
+ if (ai == 0) { lambdas.swap(init_weights); init_weights.clear(); }
if (rank == 0)
o->ResetEpoch(); // resets the learning rate-- TODO is this good?
@@ -341,15 +322,13 @@ int main(int argc, char** argv) {
#ifdef HAVE_MPI
mpi::timer timer;
#endif
- weights.InitFromVector(x);
- weights.InitVector(&lambdas);
+ x.init_vector(&lambdas);
++iter; ++titer;
observer.Reset();
- decoder.SetWeights(lambdas);
if (rank == 0) {
converged = (iter == max_iteration);
- SanityCheck(lambdas);
- ShowLargestFeatures(lambdas);
+ Weights::SanityCheck(lambdas);
+ Weights::ShowLargestFeatures(lambdas);
string fname = "weights.cur.gz";
if (iter % write_weights_every_ith == 0) {
ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz";
@@ -360,7 +339,7 @@ int main(int argc, char** argv) {
vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())) << " eta=" << lr->eta(titer);
const string svv = vv.str();
cerr << svv << endl;
- weights.WriteToFile(fname, true, &svv);
+ Weights::WriteToFile(fname, lambdas, true, &svv);
}
for (int i = 0; i < size_per_proc; ++i) {