From ed2b1ce7181fb21d2363b0d5ce04d7fa52aef0fa Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 01:08:45 +0100 Subject: enable ramdisk scratch for per-sentence-grammars --- training/mpi_batch_optimize.cc | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) (limited to 'training/mpi_batch_optimize.cc') diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index cc5953f6..0ba8c530 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -22,6 +22,7 @@ namespace mpi = boost::mpi; #include "ff_register.h" #include "decoder.h" #include "filelib.h" +#include "stringlib.h" #include "optimize.h" #include "fdict.h" #include "weights.h" @@ -42,6 +43,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("correction_buffers,M", po::value()->default_value(10), "Number of gradients for LBFGS to maintain in memory") ("gaussian_prior,p","Use a Gaussian prior on the weights") ("means,u", po::value(), "File containing the means for Gaussian prior") + ("per_sentence_grammar_scratch,P", po::value(), "(Optional) location of scratch space to copy per-sentence grammars for fast access, useful if a RAM disk is available") ("sigma_squared", po::value()->default_value(1.0), "Sigma squared term for spherical Gaussian prior"); po::options_description clo("Command line options"); clo.add_options() @@ -186,6 +188,36 @@ struct VectorPlus : public binary_function, vector, vector > { } }; +void MovePerSentenceGrammars(const string& root, int size, int rank, vector* c) { + if (!DirectoryExists(root)) { + cerr << "Can't find scratch space at " << root << endl; + abort(); + } + ostringstream os; + os << root << "/psg." << size << "_of_" << rank; + const string path = os.str(); + MkDirP(path); + string sent; + map attr; + for (unsigned i = 0; i < c->size(); ++i) { + sent = (*c)[i]; + attr.clear(); + ProcessAndStripSGML(&sent, &attr); + map::iterator it = attr.find("grammar"); + if (it != attr.end()) { + string src_file = it->second; + bool is_gzipped = (src_file.size() > 3) && (src_file.rfind(".gz") == (src_file.size() - 3)); + string new_name = path + "/" + md5(sent); + if (is_gzipped) new_name += ".gz"; + CopyFile(src_file, new_name); + it->second = new_name; + } + ostringstream ns; + ns << SGMLOpenSegTag(attr) << ' ' << sent << " "; + (*c)[i] = ns.str(); + } +} + int main(int argc, char** argv) { #ifdef HAVE_MPI mpi::environment env(argc, argv); @@ -257,6 +289,9 @@ int main(int argc, char** argv) { ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); assert(corpus.size() > 0); + if (conf.count("per_sentence_grammar_scratch")) + MovePerSentenceGrammars(conf["per_sentence_grammar_scratch"].as(), rank, size, &corpus); + TrainingObserver observer; while (!converged) { observer.Reset(); -- cgit v1.2.3