summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-10-15 20:30:49 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-10-15 20:30:49 +0000
commitb4e3c270d454115df8fd36a321c513a5605646f4 (patch)
treed3e02b4fcb2364e95a8b230ceed95a81351d9140 /training
parent99fc029872c94d59851237e2f64e66951ceec1d2 (diff)
few fixes
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@676 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training')
-rw-r--r--training/mpi_online_optimize.cc13
1 file changed, 8 insertions, 5 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index 4c08b181..96f3bcfb 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -220,7 +220,6 @@ bool LoadAgenda(const string& file, vector<pair<string, int> >* a) {
x.first = line.substr(0,d);
x.second = atoi(line.substr(d+1).c_str());
a->push_back(x);
- cerr << "X: " << x.second << " - " << x.first << "'\n";
if (!FileExists(x.first)) {
cerr << "Can't find file " << x.first << endl;
return false;
@@ -261,7 +260,11 @@ int main(int argc, char** argv) {
return 1;
}
+ size_t total_corpus_size = 0;
+ reduce(world, corpus.size(), total_corpus_size, std::plus<size_t>(), 0);
+
if (rank == 0) {
+ cerr << "Total corpus size: " << total_corpus_size << endl;
const unsigned batch_size = size_per_proc * size;
// TODO config
lr.reset(new ExponentialDecayLearningRate(batch_size, conf["eta_0"].as<double>()));
@@ -269,7 +272,7 @@ int main(int argc, char** argv) {
const string omethod = conf["optimization_method"].as<string>();
if (omethod == "sgd") {
const double C = conf["regularization_strength"].as<double>();
- o.reset(new CumulativeL1OnlineOptimizer(lr, corpus.size(), C));
+ o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C));
} else {
assert(!"fail");
}
@@ -302,7 +305,8 @@ int main(int argc, char** argv) {
ReadFile ini_rf(cur_config);
Decoder decoder(ini_rf.stream());
- o->ResetEpoch(); // resets the learning rate-- TODO is this good?
+ if (rank == 0)
+ o->ResetEpoch(); // resets the learning rate-- TODO is this good?
int iter = -1;
bool converged = false;
@@ -324,7 +328,7 @@ int main(int argc, char** argv) {
}
if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; }
ostringstream vv;
- vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size * size_per_proc / static_cast<double>(corpus.size())) << " eta=" << lr->eta(titer);
+ vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())) << " eta=" << lr->eta(titer);
const string svv = vv.str();
cerr << svv << endl;
weights.WriteToFile(fname, true, &svv);
@@ -343,7 +347,6 @@ int main(int argc, char** argv) {
if (rank == 0) {
g /= (size_per_proc * size);
o->UpdateWeights(g, FD::NumFeats(), &x);
- cerr << "XX: " << x << endl;
}
broadcast(world, x, 0);
broadcast(world, converged, 0);