From 73dbb0343a895345a80d49da9d48edac8858e87a Mon Sep 17 00:00:00 2001 From: philblunsom Date: Mon, 19 Jul 2010 18:33:29 +0000 Subject: Vaguely working distributed implementation. Hierarchical topics doesn't yet work correctly. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@317 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/pyp-topics/src/mpi-pyp-topics.cc | 99 ++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 44 deletions(-) (limited to 'gi/pyp-topics/src/mpi-pyp-topics.cc') diff --git a/gi/pyp-topics/src/mpi-pyp-topics.cc b/gi/pyp-topics/src/mpi-pyp-topics.cc index 4525302e..50db61c1 100644 --- a/gi/pyp-topics/src/mpi-pyp-topics.cc +++ b/gi/pyp-topics/src/mpi-pyp-topics.cc @@ -8,7 +8,6 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples, int freq_cutoff_start, int freq_cutoff_end, int freq_cutoff_interval, int max_contexts_per_document) { - std::cout << "I am process " << m_rank << " of " << m_size << "." << std::endl; Timer timer; std::cout << m_am_root << std::endl; @@ -85,7 +84,7 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples, F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); int table_delta = m_document_pyps.at(i).increment(new_topic, p0); if (table_delta) - m_topic_pyp.increment(new_topic, m_topic_p0); + m_topic_pyp.increment(new_topic, m_topic_p0, rnd); } else m_document_pyps.at(i).increment(new_topic, m_topic_p0); } @@ -95,23 +94,7 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples, } // Synchronise the topic->word counds across the processes. - for (std::vector::iterator levelIt=m_word_pyps.begin(); - levelIt != m_word_pyps.end(); ++levelIt) { - for (MPIPYPs::iterator pypIt=levelIt->begin(); - pypIt != levelIt->end(); ++pypIt) { - if (!m_am_root) boost::mpi::communicator().barrier(); - std::cerr << "Before Sync Process " << m_rank << ":"; - pypIt->debug_info(std::cerr); std::cerr << std::endl; - if (m_am_root) boost::mpi::communicator().barrier(); - - pypIt->synchronise(); - - if (!m_am_root) boost::mpi::communicator().barrier(); - std::cerr << "After Sync Process " << m_rank << ":"; - pypIt->debug_info(std::cerr); std::cerr << std::endl; - if (m_am_root) boost::mpi::communicator().barrier(); - } - } + synchronise(); if (m_am_root) std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n"; @@ -172,7 +155,7 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples, int table_delta = m_document_pyps.at(doc_index).decrement(current_topic); if (m_use_topic_pyp && table_delta < 0) - m_topic_pyp.decrement(current_topic); + m_topic_pyp.decrement(current_topic, rnd); } // sample a new_topic @@ -186,7 +169,7 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples, F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); int table_delta = m_document_pyps.at(doc_index).increment(new_topic, p0); if (table_delta) - m_topic_pyp.increment(new_topic, m_topic_p0); + m_topic_pyp.increment(new_topic, m_topic_p0, rnd); } else m_document_pyps.at(doc_index).increment(new_topic, m_topic_p0); } @@ -194,19 +177,10 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples, if (m_am_root) std::cerr << "."; std::cerr.flush(); } } - m_world.barrier(); + std::cerr << "|"; std::cerr.flush(); + // Synchronise the topic->word counds across the processes. - for (std::vector::iterator levelIt=m_word_pyps.begin(); - levelIt != m_word_pyps.end(); ++levelIt) { - for (MPIPYPs::iterator pypIt=levelIt->begin(); - pypIt != levelIt->end(); ++pypIt) { - std::cerr << "Before Sync Process " << m_rank << ":"; - pypIt->debug_info(std::cerr); std::cerr << std::endl; - pypIt->synchronise(); - std::cerr << "After Sync Process " << m_rank << ":"; - pypIt->debug_info(std::cerr); std::cerr << std::endl; - } - } + synchronise(); if (m_am_root) std::cerr << " ||| sampled " << processed_terms << " terms."; @@ -221,19 +195,19 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples, levelIt != m_word_pyps.end(); ++levelIt) { for (MPIPYPs::iterator pypIt=levelIt->begin(); pypIt != levelIt->end(); ++pypIt) { - pypIt->resample_prior(); + pypIt->resample_prior(rnd); log_p += pypIt->log_restaurant_prob(); } } for (PYPs::iterator pypIt=m_document_pyps.begin(); pypIt != m_document_pyps.end(); ++pypIt) { - pypIt->resample_prior(); + pypIt->resample_prior(rnd); log_p += pypIt->log_restaurant_prob(); } if (m_use_topic_pyp) { - m_topic_pyp.resample_prior(); + m_topic_pyp.resample_prior(rnd); log_p += m_topic_pyp.log_restaurant_prob(); } @@ -257,10 +231,44 @@ void MPIPYPTopics::sample_corpus(const Corpus& corpus, int samples, delete [] randomDocIndices; } +void MPIPYPTopics::synchronise() { + // Synchronise the topic->word counds across the processes. + //for (std::vector::iterator levelIt=m_word_pyps.begin(); + // levelIt != m_word_pyps.end(); ++levelIt) { +// std::vector::iterator levelIt=m_word_pyps.begin(); +// { +// for (MPIPYPs::iterator pypIt=levelIt->begin(); pypIt != levelIt->end(); ++pypIt) { + for (size_t label=0; label < m_word_pyps.at(0).size(); ++label) { + MPIPYP& pyp = m_word_pyps.at(0).at(label); + + //if (!m_am_root) boost::mpi::communicator().barrier(); + //std::cerr << "Before Sync Process " << m_rank << ":"; + //pyp.debug_info(std::cerr); std::cerr << std::endl; + //if (m_am_root) boost::mpi::communicator().barrier(); + + MPIPYP::dish_delta_type delta; + pyp.synchronise(&delta); + + for (MPIPYP::dish_delta_type::const_iterator it=delta.begin(); it != delta.end(); ++it) { + int count = it->second; + if (count > 0) + for (int i=0; i < count; ++i) increment(it->first, label); + if (count < 0) + for (int i=0; i > count; --i) decrement(it->first, label); + } + pyp.reset_deltas(); + + //if (!m_am_root) boost::mpi::communicator().barrier(); + //std::cerr << "After Sync Process " << m_rank << ":"; + //pyp.debug_info(std::cerr); std::cerr << std::endl; + //if (m_am_root) boost::mpi::communicator().barrier(); + } +// } +} void MPIPYPTopics::decrement(const Term& term, int topic, int level) { //std::cerr << "MPIPYPTopics::decrement(" << term << "," << topic << "," << level << ")" << std::endl; - m_word_pyps.at(level).at(topic).decrement(term); + m_word_pyps.at(level).at(topic).decrement(term, rnd); if (m_backoff.get()) { Term backoff_term = (*m_backoff)[term]; if (!m_backoff->is_null(backoff_term)) @@ -270,7 +278,7 @@ void MPIPYPTopics::decrement(const Term& term, int topic, int level) { void MPIPYPTopics::increment(const Term& term, int topic, int level) { //std::cerr << "MPIPYPTopics::increment(" << term << "," << topic << "," << level << ")" << std::endl; - m_word_pyps.at(level).at(topic).increment(term, word_pyps_p0(term, topic, level)); + m_word_pyps.at(level).at(topic).increment(term, word_pyps_p0(term, topic, level), rnd); if (m_backoff.get()) { Term backoff_term = (*m_backoff)[term]; @@ -301,6 +309,7 @@ int MPIPYPTopics::sample(const DocumentId& doc, const Term& term) { if (cutoff <= sums[k]) return k; } + std::cerr << cutoff << " " << sum << std::endl; assert(false); } @@ -355,10 +364,11 @@ int MPIPYPTopics::max_topic() const { return current_topic; } -int MPIPYPTopics::max(const DocumentId& doc) const { +int MPIPYPTopics::max(const DocumentId& true_doc) const { //std::cerr << "MPIPYPTopics::max(" << doc << "," << term << ")" << std::endl; // collect probs F current_max=0.0; + DocumentId local_doc = true_doc - m_mpi_start; int current_topic=-1; for (int k=0; k current_max) { current_max = prob; @@ -380,10 +390,11 @@ int MPIPYPTopics::max(const DocumentId& doc) const { return current_topic; } -int MPIPYPTopics::max(const DocumentId& doc, const Term& term) const { +int MPIPYPTopics::max(const DocumentId& true_doc, const Term& term) const { //std::cerr << "MPIPYPTopics::max(" << doc << "," << term << ")" << std::endl; // collect probs F current_max=0.0; + DocumentId local_doc = true_doc - m_mpi_start; int current_topic=-1; for (int k=0; k current_max) { -- cgit v1.2.3