Diffstat (limited to 'gi/pyp-topics/src/pyp-topics.cc')
-rw-r--r-- | gi/pyp-topics/src/pyp-topics.cc | 105
1 files changed, 54 insertions, 51 deletions
diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc
index 76f95b2a..e528a923 100644
--- a/gi/pyp-topics/src/pyp-topics.cc
+++ b/gi/pyp-topics/src/pyp-topics.cc
@@ -15,6 +15,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
   std::cerr << "\n Training with " << m_word_pyps.size()-1 << " backoff level"
     << (m_word_pyps.size()==2 ? ":" : "s:") << std::endl;
+
   for (int i=0; i<(int)m_word_pyps.size(); ++i)
   {
     m_word_pyps.at(i).reserve(m_num_topics);
@@ -76,6 +77,10 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
   for (int i = 0; i < corpus.num_documents(); ++i)
 	  randomDocIndices[i] = i;
 
+  if (num_jobs < max_threads)
+    num_jobs = max_threads;
+  int job_incr = (int) ( (float)m_document_pyps.size() / float(num_jobs) );
+
   // Sampling phase
   for (int curr_sample=0; curr_sample < samples; ++curr_sample) {
     if (freq_cutoff_interval > 0 && curr_sample != 1
@@ -149,33 +154,38 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
     if (curr_sample != 0 && curr_sample % 10 == 0) {
       std::cerr << " ||| time=" << (timer.Elapsed() / 10.0) << " sec/sample" << std::endl;
       timer.Reset();
-      std::cerr << "     ... Resampling hyperparameters (" << max_threads << " threads)"; std::cerr.flush();
-
+      std::cerr << "     ... Resampling hyperparameters (";
+
       // resample the hyperparamters
       F log_p=0.0;
-      for (std::vector<PYPs>::iterator levelIt=m_word_pyps.begin();
-           levelIt != m_word_pyps.end(); ++levelIt) {
-        for (PYPs::iterator pypIt=levelIt->begin();
-             pypIt != levelIt->end(); ++pypIt) {
-          pypIt->resample_prior();
-          log_p += pypIt->log_restaurant_prob();
-        }
-      }
-
-      WorkerPtrVect workers;
-      for (int i = 0; i < max_threads; ++i)
-      {
-        JobReturnsF job = boost::bind(&PYPTopics::hresample_docs, this, max_threads, i);
-        workers.push_back(new SimpleResampleWorker(job));
+      if (max_threads == 1)
+      {
+        std::cerr << "1 thread)" << std::endl; std::cerr.flush();
+        log_p += hresample_topics();
+        log_p += hresample_docs(0, m_document_pyps.size());
       }
+      else
+      { //parallelize
+        std::cerr << max_threads << " threads, " << num_jobs << " jobs)" << std::endl; std::cerr.flush();
+
+        WorkerPool<JobReturnsF, F> pool(max_threads);
+        int i=0, sz = m_document_pyps.size();
+        //documents...
+        while (i <= sz - 2*job_incr)
+        {
+          JobReturnsF job = boost::bind(&PYPTopics::hresample_docs, this, i, i+job_incr);
+          pool.addJob(job);
+          i += job_incr;
+        }
+        //  do all remaining documents
+        JobReturnsF job = boost::bind(&PYPTopics::hresample_docs, this, i,sz);
+        pool.addJob(job);
+
+        //topics...
+        JobReturnsF topics_job = boost::bind(&PYPTopics::hresample_topics, this);
+        pool.addJob(topics_job);
 
-      WorkerPtrVect::iterator workerIt;
-      for (workerIt = workers.begin(); workerIt != workers.end(); ++workerIt)
-      {
-        //std::cerr << "Retrieving worker result.."; std::cerr.flush();
-        F wresult = workerIt->getResult(); //blocks until worker done
-        log_p += wresult;
-        //std::cerr << ".. got " << wresult << std::endl; std::cerr.flush();
+        log_p += pool.get_result(); //blocks
       }
@@ -204,45 +214,38 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
   delete [] randomDocIndices;
 }
 
-PYPTopics::F PYPTopics::hresample_docs(int num_threads, int thread_id)
+PYPTopics::F PYPTopics::hresample_docs(int start, int end)
 {
   int resample_counter=0;
   F log_p = 0.0;
-  PYPs::iterator pypIt = m_document_pyps.begin();
-  PYPs::iterator end = m_document_pyps.end();
-  pypIt += thread_id;
-//  std::cerr << thread_id << " started " << std::endl; std::cerr.flush();
-
-  while (pypIt < end)
+  assert(start >= 0);
+  assert(end >= 0);
+  assert(start <= end);
+  for (int i=start; i < end; ++i)
   {
-    pypIt->resample_prior();
-    log_p += pypIt->log_restaurant_prob();
+    m_document_pyps[i].resample_prior();
+    log_p += m_document_pyps[i].log_restaurant_prob();
     if (resample_counter++ % 5000 == 0) {
       std::cerr << "."; std::cerr.flush();
     }
-    pypIt += num_threads;
   }
-//  std::cerr << thread_id << " did " << resample_counter << " with answer " << log_p << std::endl; std::cerr.flush();
-
   return log_p;
 }
 
-//PYPTopics::F PYPTopics::hresample_topics()
-//{
-//  F log_p = 0.0;
-//  for (std::vector<PYPs>::iterator levelIt=m_word_pyps.begin();
-//      levelIt != m_word_pyps.end(); ++levelIt) {
-//    for (PYPs::iterator pypIt=levelIt->begin();
-//        pypIt != levelIt->end(); ++pypIt) {
-//
-//      pypIt->resample_prior();
-//      log_p += pypIt->log_restaurant_prob();
-//    }
-//  }
-//  //std::cerr << "topicworker has answer " << log_p << std::endl; std::cerr.flush();
-//
-// return log_p;
-//}
+PYPTopics::F PYPTopics::hresample_topics()
+{
+  F log_p = 0.0;
+  for (std::vector<PYPs>::iterator levelIt=m_word_pyps.begin();
+      levelIt != m_word_pyps.end(); ++levelIt) {
+    for (PYPs::iterator pypIt=levelIt->begin();
+        pypIt != levelIt->end(); ++pypIt) {
+
+      pypIt->resample_prior();
+      log_p += pypIt->log_restaurant_prob();
+    }
+  }
+  return log_p;
+}
 
 void PYPTopics::decrement(const Term& term, int topic, int level) {
   //std::cerr << "PYPTopics::decrement(" << term << "," << topic << "," << level << ")" << std::endl;
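Note on the scheduling change above: hyperparameter resampling now splits m_document_pyps into contiguous [start, end) chunks of roughly job_incr documents, submits each chunk as a job, and sums the partial log-probabilities, instead of striding the PYP vector by thread id. The following is a minimal, self-contained sketch of that partitioning pattern only. It uses std::async as a stand-in for the repository's WorkerPool (not shown in this diff), and resample_range plus the dummy document vector are illustrative placeholders, not code from this commit.

#include <functional>
#include <future>
#include <iostream>
#include <vector>

// Stand-in for PYPTopics::hresample_docs(start, end): process documents
// in the half-open range [start, end) and return a partial log-probability.
double resample_range(const std::vector<double>& docs, int start, int end) {
  double log_p = 0.0;
  for (int i = start; i < end; ++i) log_p += docs[i];
  return log_p;
}

int main() {
  std::vector<double> docs(1003, 0.5);   // dummy "documents"
  int max_threads = 4;
  int num_jobs = 8;
  if (num_jobs < max_threads) num_jobs = max_threads;
  int sz = static_cast<int>(docs.size());
  int job_incr = sz / num_jobs;          // chunk size per job

  std::vector<std::future<double>> jobs;
  int i = 0;
  // Emit full-size chunks while at least 2*job_incr documents remain,
  // so the final job can absorb the leftover documents in one go
  // (mirrors the while (i <= sz - 2*job_incr) loop in the diff).
  while (i <= sz - 2 * job_incr) {
    jobs.push_back(std::async(std::launch::async, resample_range,
                              std::cref(docs), i, i + job_incr));
    i += job_incr;
  }
  // Last job covers [i, sz), including the remainder when sz is not a
  // multiple of job_incr.
  jobs.push_back(std::async(std::launch::async, resample_range,
                            std::cref(docs), i, sz));

  double log_p = 0.0;
  for (auto& f : jobs) log_p += f.get();  // blocks, like pool.get_result()
  std::cout << "log_p = " << log_p << "\n";
  return 0;
}

One caveat on the stand-in: std::launch::async starts one thread per job rather than capping concurrency at max_threads the way a fixed-size worker pool does; the point of the sketch is only the contiguous-range partitioning and the summed partial results.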
