diff options
| author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-09-17 01:08:45 +0100 | 
|---|---|---|
| committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-09-17 01:08:45 +0100 | 
| commit | a28c48d07df4e426a875f5381c80ebf4fbbd1de2 (patch) | |
| tree | 96368878b6ea59dd235517fd712e4eff6fd6214b | |
| parent | b70a0be1c34bd177e8ac7c53cb466f226008cc52 (diff) | |
enable ramdisk scratch for per-sentence-grammars
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | training/mpi_batch_optimize.cc | 35 |
| -rw-r--r-- | utils/filelib.cc | 19 |
| -rw-r--r-- | utils/filelib.h | 5 |
3 files changed, 55 insertions, 4 deletions
| diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index cc5953f6..0ba8c530 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -22,6 +22,7 @@ namespace mpi = boost::mpi;  #include "ff_register.h"  #include "decoder.h"  #include "filelib.h" +#include "stringlib.h"  #include "optimize.h"  #include "fdict.h"  #include "weights.h" @@ -42,6 +43,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {  	("correction_buffers,M", po::value<int>()->default_value(10), "Number of gradients for LBFGS to maintain in memory")          ("gaussian_prior,p","Use a Gaussian prior on the weights")          ("means,u", po::value<string>(), "File containing the means for Gaussian prior") +        ("per_sentence_grammar_scratch,P", po::value<string>(), "(Optional) location of scratch space to copy per-sentence grammars for fast access, useful if a RAM disk is available")          ("sigma_squared", po::value<double>()->default_value(1.0), "Sigma squared term for spherical Gaussian prior");    po::options_description clo("Command line options");    clo.add_options() @@ -186,6 +188,36 @@ struct VectorPlus : public binary_function<vector<T>, vector<T>, vector<T> >  {    }   };  +void MovePerSentenceGrammars(const string& root, int size, int rank, vector<string>* c) { +  if (!DirectoryExists(root)) { +    cerr << "Can't find scratch space at " << root << endl; +    abort(); +  } +  ostringstream os; +  os << root << "/psg." 
<< size << "_of_" << rank; +  const string path = os.str(); +  MkDirP(path); +  string sent; +  map<string, string> attr; +  for (unsigned i = 0; i < c->size(); ++i) { +    sent = (*c)[i]; +    attr.clear(); +    ProcessAndStripSGML(&sent, &attr); +    map<string, string>::iterator it = attr.find("grammar"); +    if (it != attr.end()) { +      string src_file = it->second; +      bool is_gzipped = (src_file.size() > 3) && (src_file.rfind(".gz") == (src_file.size() - 3)); +      string new_name = path + "/" + md5(sent); +      if (is_gzipped) new_name += ".gz"; +      CopyFile(src_file, new_name); +      it->second = new_name; +    } +    ostringstream ns; +    ns << SGMLOpenSegTag(attr) << ' ' << sent << " </seg>"; +    (*c)[i] = ns.str(); +  } +} +  int main(int argc, char** argv) {  #ifdef HAVE_MPI    mpi::environment env(argc, argv); @@ -257,6 +289,9 @@ int main(int argc, char** argv) {    ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus);    assert(corpus.size() > 0); +  if (conf.count("per_sentence_grammar_scratch")) +    MovePerSentenceGrammars(conf["per_sentence_grammar_scratch"].as<string>(), rank, size, &corpus); +    TrainingObserver observer;    while (!converged) {      observer.Reset(); diff --git a/utils/filelib.cc b/utils/filelib.cc index a0969b1a..d206fc19 100644 --- a/utils/filelib.cc +++ b/utils/filelib.cc @@ -2,6 +2,12 @@  #include <unistd.h>  #include <sys/stat.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <cstdlib> +#include <cstdio> +#include <sys/stat.h> +#include <sys/types.h>  using namespace std; @@ -32,3 +38,16 @@ void MkDirP(const string& dir) {    }  } +#if 0 +void CopyFile(const string& inf, const string& outf) { +  WriteFile w(outf); +  CopyFile(inf,*w); +} +#else +void CopyFile(const string& inf, const string& outf) { +  ofstream of(outf.c_str(), fstream::trunc|fstream::binary); +  ifstream in(inf.c_str(), fstream::binary); +  of << in.rdbuf(); +} +#endif + diff --git a/utils/filelib.h 
b/utils/filelib.h index a8622246..bb6e7415 100644 --- a/utils/filelib.h +++ b/utils/filelib.h @@ -113,9 +113,6 @@ inline void CopyFile(std::string const& inf,std::ostream &out) {    CopyFile(*r,out);  } -inline void CopyFile(std::string const& inf,std::string const& outf) { -  WriteFile w(outf); -  CopyFile(inf,*w); -} +void CopyFile(std::string const& inf,std::string const& outf);  #endif | 
