summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src/pyp-topics.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src/pyp-topics.cc')
-rw-r--r--gi/pyp-topics/src/pyp-topics.cc105
1 files changed, 54 insertions, 51 deletions
diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc
index 76f95b2a..e528a923 100644
--- a/gi/pyp-topics/src/pyp-topics.cc
+++ b/gi/pyp-topics/src/pyp-topics.cc
@@ -15,6 +15,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
std::cerr << "\n Training with " << m_word_pyps.size()-1 << " backoff level"
<< (m_word_pyps.size()==2 ? ":" : "s:") << std::endl;
+
for (int i=0; i<(int)m_word_pyps.size(); ++i)
{
m_word_pyps.at(i).reserve(m_num_topics);
@@ -76,6 +77,10 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
for (int i = 0; i < corpus.num_documents(); ++i)
randomDocIndices[i] = i;
+ if (num_jobs < max_threads)
+ num_jobs = max_threads;
+ int job_incr = (int) ( (float)m_document_pyps.size() / float(num_jobs) );
+
// Sampling phase
for (int curr_sample=0; curr_sample < samples; ++curr_sample) {
if (freq_cutoff_interval > 0 && curr_sample != 1
@@ -149,33 +154,38 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
if (curr_sample != 0 && curr_sample % 10 == 0) {
std::cerr << " ||| time=" << (timer.Elapsed() / 10.0) << " sec/sample" << std::endl;
timer.Reset();
- std::cerr << " ... Resampling hyperparameters (" << max_threads << " threads)"; std::cerr.flush();
-
+ std::cerr << " ... Resampling hyperparameters (";
+
// resample the hyperparamters
F log_p=0.0;
- for (std::vector<PYPs>::iterator levelIt=m_word_pyps.begin();
- levelIt != m_word_pyps.end(); ++levelIt) {
- for (PYPs::iterator pypIt=levelIt->begin();
- pypIt != levelIt->end(); ++pypIt) {
- pypIt->resample_prior();
- log_p += pypIt->log_restaurant_prob();
- }
- }
-
- WorkerPtrVect workers;
- for (int i = 0; i < max_threads; ++i)
- {
- JobReturnsF job = boost::bind(&PYPTopics::hresample_docs, this, max_threads, i);
- workers.push_back(new SimpleResampleWorker(job));
+ if (max_threads == 1)
+ {
+ std::cerr << "1 thread)" << std::endl; std::cerr.flush();
+ log_p += hresample_topics();
+ log_p += hresample_docs(0, m_document_pyps.size());
}
+ else
+ { //parallelize
+ std::cerr << max_threads << " threads, " << num_jobs << " jobs)" << std::endl; std::cerr.flush();
+
+ WorkerPool<JobReturnsF, F> pool(max_threads);
+ int i=0, sz = m_document_pyps.size();
+ //documents...
+ while (i <= sz - 2*job_incr)
+ {
+ JobReturnsF job = boost::bind(&PYPTopics::hresample_docs, this, i, i+job_incr);
+ pool.addJob(job);
+ i += job_incr;
+ }
+ // do all remaining documents
+ JobReturnsF job = boost::bind(&PYPTopics::hresample_docs, this, i,sz);
+ pool.addJob(job);
+
+ //topics...
+ JobReturnsF topics_job = boost::bind(&PYPTopics::hresample_topics, this);
+ pool.addJob(topics_job);
- WorkerPtrVect::iterator workerIt;
- for (workerIt = workers.begin(); workerIt != workers.end(); ++workerIt)
- {
- //std::cerr << "Retrieving worker result.."; std::cerr.flush();
- F wresult = workerIt->getResult(); //blocks until worker done
- log_p += wresult;
- //std::cerr << ".. got " << wresult << std::endl; std::cerr.flush();
+ log_p += pool.get_result(); //blocks
}
@@ -204,45 +214,38 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
delete [] randomDocIndices;
}
-PYPTopics::F PYPTopics::hresample_docs(int num_threads, int thread_id)
+PYPTopics::F PYPTopics::hresample_docs(int start, int end)
{
int resample_counter=0;
F log_p = 0.0;
- PYPs::iterator pypIt = m_document_pyps.begin();
- PYPs::iterator end = m_document_pyps.end();
- pypIt += thread_id;
-// std::cerr << thread_id << " started " << std::endl; std::cerr.flush();
-
- while (pypIt < end)
+ assert(start >= 0);
+ assert(end >= 0);
+ assert(start <= end);
+ for (int i=start; i < end; ++i)
{
- pypIt->resample_prior();
- log_p += pypIt->log_restaurant_prob();
+ m_document_pyps[i].resample_prior();
+ log_p += m_document_pyps[i].log_restaurant_prob();
if (resample_counter++ % 5000 == 0) {
std::cerr << "."; std::cerr.flush();
}
- pypIt += num_threads;
}
-// std::cerr << thread_id << " did " << resample_counter << " with answer " << log_p << std::endl; std::cerr.flush();
-
return log_p;
}
-//PYPTopics::F PYPTopics::hresample_topics()
-//{
-// F log_p = 0.0;
-// for (std::vector<PYPs>::iterator levelIt=m_word_pyps.begin();
-// levelIt != m_word_pyps.end(); ++levelIt) {
-// for (PYPs::iterator pypIt=levelIt->begin();
-// pypIt != levelIt->end(); ++pypIt) {
-//
-// pypIt->resample_prior();
-// log_p += pypIt->log_restaurant_prob();
-// }
-// }
-// //std::cerr << "topicworker has answer " << log_p << std::endl; std::cerr.flush();
-//
-// return log_p;
-//}
+PYPTopics::F PYPTopics::hresample_topics()
+{
+ F log_p = 0.0;
+ for (std::vector<PYPs>::iterator levelIt=m_word_pyps.begin();
+ levelIt != m_word_pyps.end(); ++levelIt) {
+ for (PYPs::iterator pypIt=levelIt->begin();
+ pypIt != levelIt->end(); ++pypIt) {
+
+ pypIt->resample_prior();
+ log_p += pypIt->log_restaurant_prob();
+ }
+ }
+ return log_p;
+}
void PYPTopics::decrement(const Term& term, int topic, int level) {
//std::cerr << "PYPTopics::decrement(" << term << "," << topic << "," << level << ")" << std::endl;