path: root/training/mpi_batch_optimize.cc
author    Chris Dyer <cdyer@cs.cmu.edu>    2011-09-17 01:08:45 +0100
committer Chris Dyer <cdyer@cs.cmu.edu>    2011-09-17 01:08:45 +0100
commit    a28c48d07df4e426a875f5381c80ebf4fbbd1de2 (patch)
tree      96368878b6ea59dd235517fd712e4eff6fd6214b /training/mpi_batch_optimize.cc
parent    b70a0be1c34bd177e8ac7c53cb466f226008cc52 (diff)
enable ramdisk scratch for per-sentence-grammars
Diffstat (limited to 'training/mpi_batch_optimize.cc')
-rw-r--r--    training/mpi_batch_optimize.cc    35
1 files changed, 35 insertions, 0 deletions
diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc
index cc5953f6..0ba8c530 100644
--- a/training/mpi_batch_optimize.cc
+++ b/training/mpi_batch_optimize.cc
@@ -22,6 +22,7 @@ namespace mpi = boost::mpi;
#include "ff_register.h"
#include "decoder.h"
#include "filelib.h"
+#include "stringlib.h"
#include "optimize.h"
#include "fdict.h"
#include "weights.h"
@@ -42,6 +43,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("correction_buffers,M", po::value<int>()->default_value(10), "Number of gradients for LBFGS to maintain in memory")
("gaussian_prior,p","Use a Gaussian prior on the weights")
("means,u", po::value<string>(), "File containing the means for Gaussian prior")
+ ("per_sentence_grammar_scratch,P", po::value<string>(), "(Optional) location of scratch space to copy per-sentence grammars for fast access, useful if a RAM disk is available")
("sigma_squared", po::value<double>()->default_value(1.0), "Sigma squared term for spherical Gaussian prior");
po::options_description clo("Command line options");
clo.add_options()
@@ -186,6 +188,36 @@ struct VectorPlus : public binary_function<vector<T>, vector<T>, vector<T> > {
}
};
+void MovePerSentenceGrammars(const string& root, int rank, int size, vector<string>* c) {
+ if (!DirectoryExists(root)) {
+ cerr << "Can't find scratch space at " << root << endl;
+ abort();
+ }
+ ostringstream os;
+ os << root << "/psg." << rank << "_of_" << size;
+ const string path = os.str();
+ MkDirP(path);
+ string sent;
+ map<string, string> attr;
+ for (unsigned i = 0; i < c->size(); ++i) {
+ sent = (*c)[i];
+ attr.clear();
+ ProcessAndStripSGML(&sent, &attr);
+ map<string, string>::iterator it = attr.find("grammar");
+ if (it != attr.end()) {
+ string src_file = it->second;
+ bool is_gzipped = (src_file.size() > 3) && (src_file.rfind(".gz") == (src_file.size() - 3));
+ string new_name = path + "/" + md5(sent);
+ if (is_gzipped) new_name += ".gz";
+ CopyFile(src_file, new_name);
+ it->second = new_name;
+ }
+ ostringstream ns;
+ ns << SGMLOpenSegTag(attr) << ' ' << sent << " </seg>";
+ (*c)[i] = ns.str();
+ }
+}
+
int main(int argc, char** argv) {
#ifdef HAVE_MPI
mpi::environment env(argc, argv);
@@ -257,6 +289,9 @@ int main(int argc, char** argv) {
ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus);
assert(corpus.size() > 0);
+ if (conf.count("per_sentence_grammar_scratch"))
+ MovePerSentenceGrammars(conf["per_sentence_grammar_scratch"].as<string>(), rank, size, &corpus);
+
TrainingObserver observer;
while (!converged) {
observer.Reset();
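
For readers skimming the patch, below is a minimal standalone sketch of the copy-to-scratch pattern that MovePerSentenceGrammars implements, written against the C++17 standard library only, rather than cdec's DirectoryExists/MkDirP/CopyFile/md5 helpers. The function name stage_grammar, the scratch_root parameter, and the /dev/shm example path are illustrative assumptions, not part of the commit.

// Standalone sketch (not part of the commit): stage a per-sentence grammar
// file into a scratch directory, e.g. a RAM disk, and return the new path.
#include <filesystem>
#include <iostream>
#include <string>

namespace fs = std::filesystem;

// Copy grammar_path into <scratch_root>/psg.<rank>_of_<size>/ and return the
// path of the copy; the caller would then rewrite the grammar="..." attribute
// of the <seg> tag to point at the staged copy.
std::string stage_grammar(const std::string& scratch_root, int rank, int size,
                          const std::string& grammar_path) {
  const fs::path dir = fs::path(scratch_root) /
      ("psg." + std::to_string(rank) + "_of_" + std::to_string(size));
  fs::create_directories(dir);                          // analogous to MkDirP
  const fs::path dest = dir / fs::path(grammar_path).filename();
  fs::copy_file(grammar_path, dest,
                fs::copy_options::overwrite_existing);  // analogous to CopyFile
  return dest.string();
}

int main() {
  try {
    // /dev/shm is a common tmpfs (RAM disk) mount on Linux.
    std::cout << stage_grammar("/dev/shm/psg-scratch", 0, 4, "grammar.0.gz")
              << std::endl;
  } catch (const fs::filesystem_error& e) {
    std::cerr << e.what() << std::endl;  // e.g. the source grammar is missing
    return 1;
  }
  return 0;
}

Note that the patch names each staged copy md5(sent) rather than reusing the source filename as this sketch does, presumably so that every segment gets a collision-free name even when many sentences point at grammar files sharing the same basename.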