summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src/mpi-pyp-topics.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src/mpi-pyp-topics.cc')
-rw-r--r--gi/pyp-topics/src/mpi-pyp-topics.cc99
1 files changed, 55 insertions, 44 deletions
diff --git a/gi/pyp-topics/src/mpi-pyp-topics.cc b/gi/pyp-topics/src/mpi-pyp-topics.cc
index 4525302e..50db61c1 100644
--- a/gi/pyp-topics/src/mpi-pyp-topics.cc
+++ b/gi/pyp-topics/src/mpi-pyp-topics.cc
@@ -8,7 +8,6 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples,
int freq_cutoff_start, int freq_cutoff_end,
int freq_cutoff_interval,
int max_contexts_per_document) {
- std::cout << "I am process " << m_rank << " of " << m_size << "." << std::endl;
Timer timer;
std::cout << m_am_root << std::endl;
@@ -85,7 +84,7 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples,
F p0 = m_topic_pyp.prob(new_topic, m_topic_p0);
int table_delta = m_document_pyps.at(i).increment(new_topic, p0);
if (table_delta)
- m_topic_pyp.increment(new_topic, m_topic_p0);
+ m_topic_pyp.increment(new_topic, m_topic_p0, rnd);
}
else m_document_pyps.at(i).increment(new_topic, m_topic_p0);
}
@@ -95,23 +94,7 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples,
}
// Synchronise the topic->word counds across the processes.
- for (std::vector<MPIPYPs>::iterator levelIt=m_word_pyps.begin();
- levelIt != m_word_pyps.end(); ++levelIt) {
- for (MPIPYPs::iterator pypIt=levelIt->begin();
- pypIt != levelIt->end(); ++pypIt) {
- if (!m_am_root) boost::mpi::communicator().barrier();
- std::cerr << "Before Sync Process " << m_rank << ":";
- pypIt->debug_info(std::cerr); std::cerr << std::endl;
- if (m_am_root) boost::mpi::communicator().barrier();
-
- pypIt->synchronise();
-
- if (!m_am_root) boost::mpi::communicator().barrier();
- std::cerr << "After Sync Process " << m_rank << ":";
- pypIt->debug_info(std::cerr); std::cerr << std::endl;
- if (m_am_root) boost::mpi::communicator().barrier();
- }
- }
+ synchronise();
if (m_am_root) std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n";
@@ -172,7 +155,7 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples,
int table_delta = m_document_pyps.at(doc_index).decrement(current_topic);
if (m_use_topic_pyp && table_delta < 0)
- m_topic_pyp.decrement(current_topic);
+ m_topic_pyp.decrement(current_topic, rnd);
}
// sample a new_topic
@@ -186,7 +169,7 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples,
F p0 = m_topic_pyp.prob(new_topic, m_topic_p0);
int table_delta = m_document_pyps.at(doc_index).increment(new_topic, p0);
if (table_delta)
- m_topic_pyp.increment(new_topic, m_topic_p0);
+ m_topic_pyp.increment(new_topic, m_topic_p0, rnd);
}
else m_document_pyps.at(doc_index).increment(new_topic, m_topic_p0);
}
@@ -194,19 +177,10 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples,
if (m_am_root) std::cerr << "."; std::cerr.flush();
}
}
- m_world.barrier();
+ std::cerr << "|"; std::cerr.flush();
+
// Synchronise the topic->word counds across the processes.
- for (std::vector<MPIPYPs>::iterator levelIt=m_word_pyps.begin();
- levelIt != m_word_pyps.end(); ++levelIt) {
- for (MPIPYPs::iterator pypIt=levelIt->begin();
- pypIt != levelIt->end(); ++pypIt) {
- std::cerr << "Before Sync Process " << m_rank << ":";
- pypIt->debug_info(std::cerr); std::cerr << std::endl;
- pypIt->synchronise();
- std::cerr << "After Sync Process " << m_rank << ":";
- pypIt->debug_info(std::cerr); std::cerr << std::endl;
- }
- }
+ synchronise();
if (m_am_root) std::cerr << " ||| sampled " << processed_terms << " terms.";
@@ -221,19 +195,19 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples,
levelIt != m_word_pyps.end(); ++levelIt) {
for (MPIPYPs::iterator pypIt=levelIt->begin();
pypIt != levelIt->end(); ++pypIt) {
- pypIt->resample_prior();
+ pypIt->resample_prior(rnd);
log_p += pypIt->log_restaurant_prob();
}
}
for (PYPs::iterator pypIt=m_document_pyps.begin();
pypIt != m_document_pyps.end(); ++pypIt) {
- pypIt->resample_prior();
+ pypIt->resample_prior(rnd);
log_p += pypIt->log_restaurant_prob();
}
if (m_use_topic_pyp) {
- m_topic_pyp.resample_prior();
+ m_topic_pyp.resample_prior(rnd);
log_p += m_topic_pyp.log_restaurant_prob();
}
@@ -257,10 +231,44 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples,
delete [] randomDocIndices;
}
+void MPIPYPTopics::synchronise() {
+ // Synchronise the topic->word counds across the processes.
+ //for (std::vector<MPIPYPs>::iterator levelIt=m_word_pyps.begin();
+ // levelIt != m_word_pyps.end(); ++levelIt) {
+// std::vector<MPIPYPs>::iterator levelIt=m_word_pyps.begin();
+// {
+// for (MPIPYPs::iterator pypIt=levelIt->begin(); pypIt != levelIt->end(); ++pypIt) {
+ for (size_t label=0; label < m_word_pyps.at(0).size(); ++label) {
+ MPIPYP<int>& pyp = m_word_pyps.at(0).at(label);
+
+ //if (!m_am_root) boost::mpi::communicator().barrier();
+ //std::cerr << "Before Sync Process " << m_rank << ":";
+ //pyp.debug_info(std::cerr); std::cerr << std::endl;
+ //if (m_am_root) boost::mpi::communicator().barrier();
+
+ MPIPYP<int>::dish_delta_type delta;
+ pyp.synchronise(&delta);
+
+ for (MPIPYP<int>::dish_delta_type::const_iterator it=delta.begin(); it != delta.end(); ++it) {
+ int count = it->second;
+ if (count > 0)
+ for (int i=0; i < count; ++i) increment(it->first, label);
+ if (count < 0)
+ for (int i=0; i > count; --i) decrement(it->first, label);
+ }
+ pyp.reset_deltas();
+
+ //if (!m_am_root) boost::mpi::communicator().barrier();
+ //std::cerr << "After Sync Process " << m_rank << ":";
+ //pyp.debug_info(std::cerr); std::cerr << std::endl;
+ //if (m_am_root) boost::mpi::communicator().barrier();
+ }
+// }
+}
void MPIPYPTopics::decrement(const Term& term, int topic, int level) {
//std::cerr << "MPIPYPTopics::decrement(" << term << "," << topic << "," << level << ")" << std::endl;
- m_word_pyps.at(level).at(topic).decrement(term);
+ m_word_pyps.at(level).at(topic).decrement(term, rnd);
if (m_backoff.get()) {
Term backoff_term = (*m_backoff)[term];
if (!m_backoff->is_null(backoff_term))
@@ -270,7 +278,7 @@ void MPIPYPTopics::decrement(const Term& term, int topic, int level) {
void MPIPYPTopics::increment(const Term& term, int topic, int level) {
//std::cerr << "MPIPYPTopics::increment(" << term << "," << topic << "," << level << ")" << std::endl;
- m_word_pyps.at(level).at(topic).increment(term, word_pyps_p0(term, topic, level));
+ m_word_pyps.at(level).at(topic).increment(term, word_pyps_p0(term, topic, level), rnd);
if (m_backoff.get()) {
Term backoff_term = (*m_backoff)[term];
@@ -301,6 +309,7 @@ int MPIPYPTopics::sample(const DocumentId& doc, const Term& term) {
if (cutoff <= sums[k])
return k;
}
+ std::cerr << cutoff << " " << sum << std::endl;
assert(false);
}
@@ -355,10 +364,11 @@ int MPIPYPTopics::max_topic() const {
return current_topic;
}
-int MPIPYPTopics::max(const DocumentId& doc) const {
+int MPIPYPTopics::max(const DocumentId& true_doc) const {
//std::cerr << "MPIPYPTopics::max(" << doc << "," << term << ")" << std::endl;
// collect probs
F current_max=0.0;
+ DocumentId local_doc = true_doc - m_mpi_start;
int current_topic=-1;
for (int k=0; k<m_num_topics; ++k) {
//F p_w_k = prob(term, k);
@@ -368,8 +378,8 @@ int MPIPYPTopics::max(const DocumentId& doc) const {
topic_prob = m_topic_pyp.prob(k, m_topic_p0);
F prob = 0;
- if (doc < 0) prob = topic_prob;
- else prob = m_document_pyps[doc].prob(k, topic_prob);
+ if (local_doc < 0) prob = topic_prob;
+ else prob = m_document_pyps.at(local_doc).prob(k, topic_prob);
if (prob > current_max) {
current_max = prob;
@@ -380,10 +390,11 @@ int MPIPYPTopics::max(const DocumentId& doc) const {
return current_topic;
}
-int MPIPYPTopics::max(const DocumentId& doc, const Term& term) const {
+int MPIPYPTopics::max(const DocumentId& true_doc, const Term& term) const {
//std::cerr << "MPIPYPTopics::max(" << doc << "," << term << ")" << std::endl;
// collect probs
F current_max=0.0;
+ DocumentId local_doc = true_doc - m_mpi_start;
int current_topic=-1;
for (int k=0; k<m_num_topics; ++k) {
F p_w_k = prob(term, k);
@@ -393,8 +404,8 @@ int MPIPYPTopics::max(const DocumentId& doc, const Term& term) const {
topic_prob = m_topic_pyp.prob(k, m_topic_p0);
F p_k_d = 0;
- if (doc < 0) p_k_d = topic_prob;
- else p_k_d = m_document_pyps[doc].prob(k, topic_prob);
+ if (local_doc < 0) p_k_d = topic_prob;
+ else p_k_d = m_document_pyps.at(local_doc).prob(k, topic_prob);
F prob = (p_w_k*p_k_d);
if (prob > current_max) {