summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-09-29 20:45:48 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-09-29 20:45:48 +0000
commitf412aaab3d10fb82b20a2190f2cb1424959c599a (patch)
tree1942e2a05777694cc81724f3206c8972813b4224 /training
parent7f56dd65ee706683444b012d0afcfff3e376bfff (diff)
another feature, another POS
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@664 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training')
-rw-r--r--training/mpi_online_optimize.cc18
1 files changed, 11 insertions, 7 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index 0c032c01..d662e8bd 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -215,10 +215,9 @@ int main(int argc, char** argv) {
mpi::communicator world;
const int size = world.size();
const int rank = world.rank();
- SetSilent(true); // turn off verbose decoder output
- cerr << "MPI: I am " << rank << '/' << size << endl;
+ if (size > 1) SetSilent(true); // turn off verbose decoder output
register_feature_functions();
- MT19937* rng = NULL;
+ std::tr1::shared_ptr<MT19937> rng;
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
@@ -272,9 +271,9 @@ int main(int argc, char** argv) {
for (unsigned i = 0; i < order.size(); ++i) order[i]=i;
// randomize corpus
if (conf.count("random_seed"))
- rng = new MT19937(conf["random_seed"].as<uint32_t>());
+ rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
else
- rng = new MT19937;
+ rng.reset(new MT19937);
}
SparseVector<double> x;
weights.InitSparseVector(&x);
@@ -283,6 +282,7 @@ int main(int argc, char** argv) {
double objective = 0;
bool converged = false;
+ int write_weights_every_ith = 100; // TODO configure
int iter = -1;
vector<double> lambdas;
while (!converged) {
@@ -296,6 +296,10 @@ int main(int argc, char** argv) {
ShowLargestFeatures(lambdas);
string fname = "weights.cur.gz";
if (converged) { fname = "weights.final.gz"; }
+ if (iter % write_weights_every_ith == 0) {
+ ostringstream o; o << "weights." << iter << ".gz";
+ fname = o.str();
+ }
ostringstream vv;
vv << "Objective = " << objective; // << " (eval count=" << o->EvaluationCount() << ")";
const string svv = vv.str();
@@ -304,12 +308,12 @@ int main(int argc, char** argv) {
if (fully_random || size * size_per_proc * miter > corpus.size()) {
if (rank == 0)
- Shuffle(&order, rng);
+ Shuffle(&order, rng.get());
miter = 0;
broadcast(world, order, 0);
}
if (rank == 0)
- cerr << "iter=" << iter << " minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << " passes_thru_data=" << (iter * batch_size / static_cast<double>(corpus.size())) << " eta=" << lr->eta(iter) << endl;
+ cerr << "iter=" << iter << " minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (iter * batch_size / static_cast<double>(corpus.size())) << " eta=" << lr->eta(iter) << endl;
const int beg = size * miter * size_per_proc + rank * size_per_proc;
const int end = beg + size_per_proc;