summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-09-22 18:28:00 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-09-22 18:28:00 +0000
commit46c1d1c56535cdaf71e920b12f1cbdcbd1bd9d4f (patch)
tree7ba40a52456b967200a0f09e9aed939e79af81e4 /training
parent716356a5de1eb7d066d511af51b76e1294609d87 (diff)
few fixes
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@655 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training')
-rw-r--r--training/mpi_online_optimize.cc18
1 files changed, 14 insertions, 4 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index 95b462bb..62821aa3 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -85,7 +85,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
po::notify(*conf);
- if (conf->count("help") || !conf->count("input_weights") || !conf->count("training_data") || !conf->count("decoder_config")) {
+ if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) {
cerr << dcmdline_options << endl;
MPI::Finalize();
exit(1);
@@ -201,14 +201,14 @@ int main(int argc, char** argv) {
cerr << "MPI: I am " << rank << '/' << size << endl;
register_feature_functions();
MT19937* rng = NULL;
- if (rank == 0) rng = new MT19937;
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
// load initial weights
Weights weights;
- weights.InitFromFile(conf["input_weights"].as<string>());
+ if (conf.count("input_weights"))
+ weights.InitFromFile(conf["input_weights"].as<string>());
// freeze feature set
const bool freeze_feature_set = conf.count("freeze_feature_set");
@@ -228,6 +228,7 @@ int main(int argc, char** argv) {
std::tr1::shared_ptr<OnlineOptimizer> o;
std::tr1::shared_ptr<LearningRateSchedule> lr;
+ vector<int> order;
if (rank == 0) {
// TODO config
lr.reset(new ExponentialDecayLearningRate(corpus.size(), conf["eta_0"].as<double>()));
@@ -239,12 +240,21 @@ int main(int argc, char** argv) {
} else {
assert(!"fail");
}
+
+ // randomize corpus
+ rng = new MT19937;
+ order.resize(corpus.size());
+ for (unsigned i = 0; i < order.size(); ++i) order[i]=i;
+ Shuffle(&order, rng);
}
double objective = 0;
vector<double> lambdas;
weights.InitVector(&lambdas);
bool converged = false;
-
+ const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>();
+ for (int i = 0; i < size_per_proc; ++i)
+ cerr << "i=" << i << ": " << order[i] << endl;
+ abort();
TrainingObserver observer;
while (!converged) {
observer.Reset();