diff options
Diffstat (limited to 'training/compute_cllh.cc')
-rw-r--r-- | training/compute_cllh.cc | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/training/compute_cllh.cc b/training/compute_cllh.cc index f25e17c3..332f6d0c 100644 --- a/training/compute_cllh.cc +++ b/training/compute_cllh.cc @@ -5,8 +5,10 @@ #include <cassert> #include <cmath> -#include <mpi.h> +#include "config.h" +#ifdef HAVE_MPI #include <boost/mpi.hpp> +#endif #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> @@ -22,7 +24,7 @@ using namespace std; namespace po = boost::program_options; -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("weights,w",po::value<string>(),"Input feature weights file") @@ -45,9 +47,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { cerr << dcmdline_options << endl; - MPI::Finalize(); - exit(1); + return false; } + return true; } void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c, vector<int>* ids) { @@ -125,18 +127,26 @@ struct TrainingObserver : public DecoderObserver { int state; }; +#ifdef HAVE_MPI namespace mpi = boost::mpi; +#endif int main(int argc, char** argv) { +#ifdef HAVE_MPI mpi::environment env(argc, argv); mpi::communicator world; const int size = world.size(); const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif if (size > 1) SetSilent(true); // turn off verbose decoder output register_feature_functions(); po::variables_map conf; - InitCommandLine(argc, argv, &conf); + if (!InitCommandLine(argc, argv, &conf)) + return false; // load initial weights Weights weights; @@ -176,7 +186,11 @@ int main(int argc, char** argv) { decoder.Decode(corpus[i], &observer); } +#ifdef HAVE_MPI reduce(world, observer.acc_obj, objective, std::plus<double>(), 0); +#else + objective = observer.acc_obj; +#endif if (rank == 0) cout << "OBJECTIVE: " << objective << endl; |