From 8f97e6b03114761870f0c72f18f0928fac28d0f9 Mon Sep 17 00:00:00 2001 From: philblunsom Date: Wed, 14 Jul 2010 22:42:35 +0000 Subject: starting an mpi version. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@253 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/pyp-topics/src/Makefile.am | 9 +- gi/pyp-topics/src/contexts_corpus.cc | 22 +- gi/pyp-topics/src/contexts_lexer.h | 2 +- gi/pyp-topics/src/contexts_lexer.l | 5 +- gi/pyp-topics/src/mpi-pyp-topics.cc | 431 +++++++++++++++++++++++++ gi/pyp-topics/src/mpi-pyp-topics.hh | 97 ++++++ gi/pyp-topics/src/mpi-pyp.hh | 552 ++++++++++++++++++++++++++++++++ gi/pyp-topics/src/mpi-train-contexts.cc | 169 ++++++++++ gi/pyp-topics/src/pyp-topics.cc | 24 +- gi/pyp-topics/src/pyp-topics.hh | 3 +- gi/pyp-topics/src/pyp.hh | 52 ++- gi/pyp-topics/src/train-contexts.cc | 4 +- 12 files changed, 1334 insertions(+), 36 deletions(-) create mode 100644 gi/pyp-topics/src/mpi-pyp-topics.cc create mode 100644 gi/pyp-topics/src/mpi-pyp-topics.hh create mode 100644 gi/pyp-topics/src/mpi-pyp.hh create mode 100644 gi/pyp-topics/src/mpi-train-contexts.cc (limited to 'gi/pyp-topics') diff --git a/gi/pyp-topics/src/Makefile.am b/gi/pyp-topics/src/Makefile.am index abfc95ac..a3a30acd 100644 --- a/gi/pyp-topics/src/Makefile.am +++ b/gi/pyp-topics/src/Makefile.am @@ -1,13 +1,16 @@ -bin_PROGRAMS = pyp-topics-train pyp-contexts-train +bin_PROGRAMS = pyp-topics-train pyp-contexts-train mpi-pyp-contexts-train contexts_lexer.cc: contexts_lexer.l $(LEX) -s -CF -8 -o$@ $< -pyp_topics_train_SOURCES = corpus.cc gzstream.cc pyp-topics.cc train.cc contexts_lexer.cc contexts_corpus.cc +pyp_topics_train_SOURCES = mt19937ar.c corpus.cc gzstream.cc pyp-topics.cc train.cc contexts_lexer.cc contexts_corpus.cc pyp_topics_train_LDADD = $(top_srcdir)/decoder/libcdec.a -lz -pyp_contexts_train_SOURCES = corpus.cc gzstream.cc pyp-topics.cc contexts_lexer.cc contexts_corpus.cc train-contexts.cc +pyp_contexts_train_SOURCES = mt19937ar.c corpus.cc gzstream.cc pyp-topics.cc contexts_lexer.cc contexts_corpus.cc train-contexts.cc pyp_contexts_train_LDADD = $(top_srcdir)/decoder/libcdec.a -lz +mpi_pyp_contexts_train_SOURCES = mt19937ar.c corpus.cc gzstream.cc mpi-pyp-topics.cc contexts_lexer.cc contexts_corpus.cc mpi-train-contexts.cc +mpi_pyp_contexts_train_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + AM_CPPFLAGS = -W -Wall -Wno-sign-compare -funroll-loops diff --git a/gi/pyp-topics/src/contexts_corpus.cc b/gi/pyp-topics/src/contexts_corpus.cc index 280b2976..26d5718a 100644 --- a/gi/pyp-topics/src/contexts_corpus.cc +++ b/gi/pyp-topics/src/contexts_corpus.cc @@ -28,9 +28,12 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* Document* doc(new Document()); //cout << "READ: " << new_contexts.phrase << "\t"; - for (int i=0; i < new_contexts.contexts.size(); ++i) { + for (int i=0; i < new_contexts.counts.size(); ++i) { int cache_word_count = corpus_ptr->m_dict.max(); - string context_str = corpus_ptr->m_dict.toString(new_contexts.contexts[i]); + + //string context_str = corpus_ptr->m_dict.toString(new_contexts.contexts[i]); + int context_index = new_contexts.counts.at(i).first; + string context_str = corpus_ptr->m_dict.toString(new_contexts.contexts[context_index]); // filter out singleton contexts //if (!counts->empty()) { @@ -45,7 +48,8 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* corpus_ptr->m_num_types++; } - int count = new_contexts.counts[i]; + //int count = new_contexts.counts[i]; + int count = new_contexts.counts.at(i).second; for (int j=0; jpush_back(id); corpus_ptr->m_num_terms += count; @@ -54,7 +58,8 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* if (backoff_gen) { int order = 1; WordID backoff_id = id; - ContextsLexer::Context backedoff_context = new_contexts.contexts[i]; + //ContextsLexer::Context backedoff_context = new_contexts.contexts[i]; + ContextsLexer::Context backedoff_context = new_contexts.contexts[context_index]; while (true) { if (!corpus_ptr->m_backoff->has_backoff(backoff_id)) { //cerr << "Backing off from " << corpus_ptr->m_dict.Convert(backoff_id) << " to "; @@ -96,10 +101,13 @@ void filter_callback(const ContextsLexer::PhraseContextsType& new_contexts, void map* context_counts = (static_cast*>(extra)); - for (int i=0; i < new_contexts.contexts.size(); ++i) { - int count = new_contexts.counts[i]; + for (int i=0; i < new_contexts.counts.size(); ++i) { + int context_index = new_contexts.counts.at(i).first; + int count = new_contexts.counts.at(i).second; + //int count = new_contexts.counts[i]; pair::iterator,bool> result - = context_counts->insert(make_pair(Dict::toString(new_contexts.contexts[i]),count)); + = context_counts->insert(make_pair(Dict::toString(new_contexts.contexts[context_index]),count)); + //= context_counts->insert(make_pair(Dict::toString(new_contexts.contexts[i]),count)); if (!result.second) result.first->second += count; } diff --git a/gi/pyp-topics/src/contexts_lexer.h b/gi/pyp-topics/src/contexts_lexer.h index f9a1b21c..1b79c6fd 100644 --- a/gi/pyp-topics/src/contexts_lexer.h +++ b/gi/pyp-topics/src/contexts_lexer.h @@ -12,7 +12,7 @@ struct ContextsLexer { struct PhraseContextsType { std::string phrase; std::vector contexts; - std::vector counts; + std::vector< std::pair > counts; }; typedef void (*ContextsCallback)(const PhraseContextsType& new_contexts, void* extra); diff --git a/gi/pyp-topics/src/contexts_lexer.l b/gi/pyp-topics/src/contexts_lexer.l index 61189a73..7a5d9460 100644 --- a/gi/pyp-topics/src/contexts_lexer.l +++ b/gi/pyp-topics/src/contexts_lexer.l @@ -6,6 +6,7 @@ #include #include #include +#include int lex_line = 0; std::istream* contextslex_stream = NULL; @@ -69,7 +70,7 @@ INT [\-+]?[0-9]+|inf|[\-+]inf [ \t]+ { ; } C={INT} { - current_contexts.counts.push_back(atoi(yytext+2)); + current_contexts.counts.push_back(std::make_pair(current_contexts.counts.size(), atoi(yytext+2))); BEGIN(COUNT_END); } . { @@ -84,6 +85,8 @@ INT [\-+]?[0-9]+|inf|[\-+]inf \n { //std::cerr << "READ:" << current_contexts.phrase << " with " << current_contexts.contexts.size() // << " contexts, and " << current_contexts.counts.size() << " counts." << std::endl; + std::sort(current_contexts.counts.rbegin(), current_contexts.counts.rend()); + contexts_callback(current_contexts, contexts_callback_extra); current_contexts.phrase.clear(); current_contexts.contexts.clear(); diff --git a/gi/pyp-topics/src/mpi-pyp-topics.cc b/gi/pyp-topics/src/mpi-pyp-topics.cc new file mode 100644 index 00000000..d2daad4f --- /dev/null +++ b/gi/pyp-topics/src/mpi-pyp-topics.cc @@ -0,0 +1,431 @@ +#include "timing.h" +#include "mpi-pyp-topics.hh" + +//#include +void PYPTopics::sample_corpus(const Corpus& corpus, int samples, + int freq_cutoff_start, int freq_cutoff_end, + int freq_cutoff_interval, + int max_contexts_per_document) { + Timer timer; + + if (!m_backoff.get()) { + m_word_pyps.clear(); + m_word_pyps.push_back(PYPs()); + } + + std::cerr << "\n Training with " << m_word_pyps.size()-1 << " backoff level" + << (m_word_pyps.size()==2 ? ":" : "s:") << std::endl; + + for (int i=0; i<(int)m_word_pyps.size(); ++i) + { + m_word_pyps.at(i).reserve(m_num_topics); + for (int j=0; j(0.5, 1.0, m_seed)); + } + std::cerr << std::endl; + + m_document_pyps.reserve(corpus.num_documents()); + for (int j=0; j(0.5, 1.0, m_seed)); + + m_topic_p0 = 1.0/m_num_topics; + m_term_p0 = 1.0/corpus.num_types(); + m_backoff_p0 = 1.0/corpus.num_documents(); + + std::cerr << " Documents: " << corpus.num_documents() << " Terms: " + << corpus.num_types() << std::endl; + + int frequency_cutoff = freq_cutoff_start; + std::cerr << " Context frequency cutoff set to " << frequency_cutoff << std::endl; + + timer.Reset(); + // Initialisation pass + int document_id=0, topic_counter=0; + for (Corpus::const_iterator corpusIt=corpus.begin(); + corpusIt != corpus.end(); ++corpusIt, ++document_id) { + m_corpus_topics.push_back(DocumentTopics(corpusIt->size(), 0)); + + int term_index=0; + for (Document::const_iterator docIt=corpusIt->begin(); + docIt != corpusIt->end(); ++docIt, ++term_index) { + topic_counter++; + Term term = *docIt; + + // sample a new_topic + //int new_topic = (topic_counter % m_num_topics); + int freq = corpus.context_count(term); + int new_topic = -1; + if (freq > frequency_cutoff + && (!max_contexts_per_document || term_index < max_contexts_per_document)) { + new_topic = document_id % m_num_topics; + + // add the new topic to the PYPs + increment(term, new_topic); + + if (m_use_topic_pyp) { + F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); + int table_delta = m_document_pyps[document_id].increment(new_topic, p0); + if (table_delta) + m_topic_pyp.increment(new_topic, m_topic_p0); + } + else m_document_pyps[document_id].increment(new_topic, m_topic_p0); + } + + m_corpus_topics[document_id][term_index] = new_topic; + } + } + std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n"; + + int* randomDocIndices = new int[corpus.num_documents()]; + for (int i = 0; i < corpus.num_documents(); ++i) + randomDocIndices[i] = i; + + // Sampling phase + for (int curr_sample=0; curr_sample < samples; ++curr_sample) { + if (freq_cutoff_interval > 0 && curr_sample != 1 + && curr_sample % freq_cutoff_interval == 1 + && frequency_cutoff > freq_cutoff_end) { + frequency_cutoff--; + std::cerr << "\n Context frequency cutoff set to " << frequency_cutoff << std::endl; + } + + std::cerr << "\n -- Sample " << curr_sample << " "; std::cerr.flush(); + + // Randomize the corpus indexing array + int tmp; + int processed_terms=0; + for (int i = corpus.num_documents()-1; i > 0; --i) + { + //i+1 since j \in [0,i] but rnd() \in [0,1) + int j = (int)(rnd() * (i+1)); + assert(j >= 0 && j <= i); + tmp = randomDocIndices[i]; + randomDocIndices[i] = randomDocIndices[j]; + randomDocIndices[j] = tmp; + } + + // for each document in the corpus + int document_id; + for (int i=0; i max_contexts_per_document) + break; + + Term term = *docIt; + int freq = corpus.context_count(term); + if (freq < frequency_cutoff) + continue; + + processed_terms++; + + // remove the prevous topic from the PYPs + int current_topic = m_corpus_topics[document_id][term_index]; + // a negative label mean that term hasn't been sampled yet + if (current_topic >= 0) { + decrement(term, current_topic); + + int table_delta = m_document_pyps[document_id].decrement(current_topic); + if (m_use_topic_pyp && table_delta < 0) + m_topic_pyp.decrement(current_topic); + } + + // sample a new_topic + int new_topic = sample(document_id, term); + + // add the new topic to the PYPs + m_corpus_topics[document_id][term_index] = new_topic; + increment(term, new_topic); + + if (m_use_topic_pyp) { + F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); + int table_delta = m_document_pyps[document_id].increment(new_topic, p0); + if (table_delta) + m_topic_pyp.increment(new_topic, m_topic_p0); + } + else m_document_pyps[document_id].increment(new_topic, m_topic_p0); + } + if (document_id && document_id % 10000 == 0) { + std::cerr << "."; std::cerr.flush(); + } + } + std::cerr << " ||| sampled " << processed_terms << " terms."; + + if (curr_sample != 0 && curr_sample % 10 == 0) { + std::cerr << " ||| time=" << (timer.Elapsed() / 10.0) << " sec/sample" << std::endl; + timer.Reset(); + std::cerr << " ... Resampling hyperparameters (" << max_threads << " threads)"; std::cerr.flush(); + + // resample the hyperparamters + F log_p=0.0; + for (std::vector::iterator levelIt=m_word_pyps.begin(); + levelIt != m_word_pyps.end(); ++levelIt) { + for (PYPs::iterator pypIt=levelIt->begin(); + pypIt != levelIt->end(); ++pypIt) { + pypIt->resample_prior(); + log_p += pypIt->log_restaurant_prob(); + } + } + + WorkerPtrVect workers; + for (int i = 0; i < max_threads; ++i) + { + JobReturnsF job = boost::bind(&PYPTopics::hresample_docs, this, max_threads, i); + workers.push_back(new SimpleResampleWorker(job)); + } + + WorkerPtrVect::iterator workerIt; + for (workerIt = workers.begin(); workerIt != workers.end(); ++workerIt) + { + //std::cerr << "Retrieving worker result.."; std::cerr.flush(); + F wresult = workerIt->getResult(); //blocks until worker done + log_p += wresult; + //std::cerr << ".. got " << wresult << std::endl; std::cerr.flush(); + + } + + if (m_use_topic_pyp) { + m_topic_pyp.resample_prior(); + log_p += m_topic_pyp.log_restaurant_prob(); + } + + std::cerr.precision(10); + std::cerr << " ||| LLH=" << log_p << " ||| resampling time=" << timer.Elapsed() << " sec" << std::endl; + timer.Reset(); + + int k=0; + std::cerr << "Topics distribution: "; + std::cerr.precision(2); + for (PYPs::iterator pypIt=m_word_pyps.front().begin(); + pypIt != m_word_pyps.front().end(); ++pypIt, ++k) { + if (k % 5 == 0) std::cerr << std::endl << '\t'; + std::cerr << "<" << k << ":" << pypIt->num_customers() << "," + << pypIt->num_types() << "," << m_topic_pyp.prob(k, m_topic_p0) << "> "; + } + std::cerr.precision(4); + std::cerr << std::endl; + } + } + delete [] randomDocIndices; +} + +PYPTopics::F PYPTopics::hresample_docs(int num_threads, int thread_id) +{ + int resample_counter=0; + F log_p = 0.0; + PYPs::iterator pypIt = m_document_pyps.begin(); + PYPs::iterator end = m_document_pyps.end(); + pypIt += thread_id; +// std::cerr << thread_id << " started " << std::endl; std::cerr.flush(); + + while (pypIt < end) + { + pypIt->resample_prior(); + log_p += pypIt->log_restaurant_prob(); + if (resample_counter++ % 5000 == 0) { + std::cerr << "."; std::cerr.flush(); + } + pypIt += num_threads; + } +// std::cerr << thread_id << " did " << resample_counter << " with answer " << log_p << std::endl; std::cerr.flush(); + + return log_p; +} + +//PYPTopics::F PYPTopics::hresample_topics() +//{ +// F log_p = 0.0; +// for (std::vector::iterator levelIt=m_word_pyps.begin(); +// levelIt != m_word_pyps.end(); ++levelIt) { +// for (PYPs::iterator pypIt=levelIt->begin(); +// pypIt != levelIt->end(); ++pypIt) { +// +// pypIt->resample_prior(); +// log_p += pypIt->log_restaurant_prob(); +// } +// } +// //std::cerr << "topicworker has answer " << log_p << std::endl; std::cerr.flush(); +// +// return log_p; +//} + +void PYPTopics::decrement(const Term& term, int topic, int level) { + //std::cerr << "PYPTopics::decrement(" << term << "," << topic << "," << level << ")" << std::endl; + m_word_pyps.at(level).at(topic).decrement(term); + if (m_backoff.get()) { + Term backoff_term = (*m_backoff)[term]; + if (!m_backoff->is_null(backoff_term)) + decrement(backoff_term, topic, level+1); + } +} + +void PYPTopics::increment(const Term& term, int topic, int level) { + //std::cerr << "PYPTopics::increment(" << term << "," << topic << "," << level << ")" << std::endl; + m_word_pyps.at(level).at(topic).increment(term, word_pyps_p0(term, topic, level)); + + if (m_backoff.get()) { + Term backoff_term = (*m_backoff)[term]; + if (!m_backoff->is_null(backoff_term)) + increment(backoff_term, topic, level+1); + } +} + +int PYPTopics::sample(const DocumentId& doc, const Term& term) { + // First pass: collect probs + F sum=0.0; + std::vector sums; + for (int k=0; kis_null(backoff_term)) { + assert (level < m_backoff->order()); + p0 = (1.0/(double)m_backoff->terms_at_level(level))*prob(backoff_term, topic, level+1); + } + else + p0 = m_term_p0; + } + //for (int i=0; i current_max) { + current_max = prob; + current_topic = k; + } + } + assert(current_topic >= 0); + return current_topic; +} + +int PYPTopics::max(const DocumentId& doc, const Term& term) const { + //std::cerr << "PYPTopics::max(" << doc << "," << term << ")" << std::endl; + // collect probs + F current_max=0.0; + int current_topic=-1; + for (int k=0; k current_max) { + current_max = prob; + current_topic = k; + } + } + assert(current_topic >= 0); + return current_topic; +} + +std::ostream& PYPTopics::print_document_topics(std::ostream& out) const { + for (CorpusTopics::const_iterator corpusIt=m_corpus_topics.begin(); + corpusIt != m_corpus_topics.end(); ++corpusIt) { + int term_index=0; + for (DocumentTopics::const_iterator docIt=corpusIt->begin(); + docIt != corpusIt->end(); ++docIt, ++term_index) { + if (term_index) out << " "; + out << *docIt; + } + out << std::endl; + } + return out; +} + +std::ostream& PYPTopics::print_topic_terms(std::ostream& out) const { + for (PYPs::const_iterator pypsIt=m_word_pyps.front().begin(); + pypsIt != m_word_pyps.front().end(); ++pypsIt) { + int term_index=0; + for (PYP::const_iterator termIt=pypsIt->begin(); + termIt != pypsIt->end(); ++termIt, ++term_index) { + if (term_index) out << " "; + out << termIt->first << ":" << termIt->second; + } + out << std::endl; + } + return out; +} diff --git a/gi/pyp-topics/src/mpi-pyp-topics.hh b/gi/pyp-topics/src/mpi-pyp-topics.hh new file mode 100644 index 00000000..d978c7a1 --- /dev/null +++ b/gi/pyp-topics/src/mpi-pyp-topics.hh @@ -0,0 +1,97 @@ +#ifndef PYP_TOPICS_HH +#define PYP_TOPICS_HH + +#include +#include +#include + +#include +#include +#include + +#include "mpi-pyp.hh" +#include "corpus.hh" +#include "workers.hh" + +class PYPTopics { +public: + typedef std::vector DocumentTopics; + typedef std::vector CorpusTopics; + typedef double F; + +public: + PYPTopics(int num_topics, bool use_topic_pyp=false, unsigned long seed = 0, + int max_threads = 1) + : m_num_topics(num_topics), m_word_pyps(1), + m_topic_pyp(0.5,1.0,seed), m_use_topic_pyp(use_topic_pyp), + m_seed(seed), + uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed), + rnd(rng, uni_dist), max_threads(max_threads) {} + + void sample_corpus(const Corpus& corpus, int samples, + int freq_cutoff_start=0, int freq_cutoff_end=0, + int freq_cutoff_interval=0, + int max_contexts_per_document=0); + + int sample(const DocumentId& doc, const Term& term); + int max(const DocumentId& doc, const Term& term) const; + int max(const DocumentId& doc) const; + int max_topic() const; + + void set_backoff(const std::string& filename) { + m_backoff.reset(new TermBackoff); + m_backoff->read(filename); + m_word_pyps.clear(); + m_word_pyps.resize(m_backoff->order(), PYPs()); + } + void set_backoff(TermBackoffPtr backoff) { + m_backoff = backoff; + m_word_pyps.clear(); + m_word_pyps.resize(m_backoff->order(), PYPs()); + } + + F prob(const Term& term, int topic, int level=0) const; + void decrement(const Term& term, int topic, int level=0); + void increment(const Term& term, int topic, int level=0); + + std::ostream& print_document_topics(std::ostream& out) const; + std::ostream& print_topic_terms(std::ostream& out) const; + +private: + F word_pyps_p0(const Term& term, int topic, int level) const; + + int m_num_topics; + F m_term_p0, m_topic_p0, m_backoff_p0; + + CorpusTopics m_corpus_topics; + typedef boost::ptr_vector< PYP > PYPs; + PYPs m_document_pyps; + std::vector m_word_pyps; + PYP m_topic_pyp; + bool m_use_topic_pyp; + + unsigned long m_seed; + + typedef boost::mt19937 base_generator_type; + typedef boost::uniform_real<> uni_dist_type; + typedef boost::variate_generator gen_type; + + uni_dist_type uni_dist; + base_generator_type rng; //this gets the seed + gen_type rnd; //instantiate: rnd(rng, uni_dist) + //call: rnd() generates uniform on [0,1) + + typedef boost::function JobReturnsF; + typedef SimpleWorker SimpleResampleWorker; + typedef boost::ptr_vector WorkerPtrVect; + + F hresample_docs(int num_threads, int thread_id); + +// F hresample_topics(); + + int max_threads; + + TermBackoffPtr m_backoff; +}; + +#endif // PYP_TOPICS_HH diff --git a/gi/pyp-topics/src/mpi-pyp.hh b/gi/pyp-topics/src/mpi-pyp.hh new file mode 100644 index 00000000..dc47244b --- /dev/null +++ b/gi/pyp-topics/src/mpi-pyp.hh @@ -0,0 +1,552 @@ +#ifndef _pyp_hh +#define _pyp_hh + +#include +#include +#include +//#include + +#include +#include +#include + +#include "log_add.h" +#include "slice-sampler.h" +#include "mt19937ar.h" + +// +// Pitman-Yor process with customer and table tracking +// + +template > +class PYP : protected std::tr1::unordered_map +//class PYP : protected google::sparse_hash_map +{ +public: + using std::tr1::unordered_map::const_iterator; + using std::tr1::unordered_map::iterator; + using std::tr1::unordered_map::begin; + using std::tr1::unordered_map::end; +// using google::sparse_hash_map::const_iterator; +// using google::sparse_hash_map::iterator; +// using google::sparse_hash_map::begin; +// using google::sparse_hash_map::end; + + PYP(double a, double b, unsigned long seed = 0, Hash hash=Hash()); + + int increment(Dish d, double p0); + int decrement(Dish d); + + // lookup functions + int count(Dish d) const; + double prob(Dish dish, double p0) const; + double prob(Dish dish, double dcd, double dca, + double dtd, double dta, double p0) const; + double unnormalised_prob(Dish dish, double p0) const; + + int num_customers() const { return _total_customers; } + int num_types() const { return std::tr1::unordered_map::size(); } + //int num_types() const { return google::sparse_hash_map::size(); } + bool empty() const { return _total_customers == 0; } + + double log_prob(Dish dish, double log_p0) const; + // nb. d* are NOT logs + double log_prob(Dish dish, double dcd, double dca, + double dtd, double dta, double log_p0) const; + + int num_tables(Dish dish) const; + int num_tables() const; + + double a() const { return _a; } + void set_a(double a) { _a = a; } + + double b() const { return _b; } + void set_b(double b) { _b = b; } + + void clear(); + std::ostream& debug_info(std::ostream& os) const; + + double log_restaurant_prob() const; + double log_prior() const; + static double log_prior_a(double a, double beta_a, double beta_b); + static double log_prior_b(double b, double gamma_c, double gamma_s); + + void resample_prior(); + void resample_prior_a(); + void resample_prior_b(); + +private: + double _a, _b; // parameters of the Pitman-Yor distribution + double _a_beta_a, _a_beta_b; // parameters of Beta prior on a + double _b_gamma_s, _b_gamma_c; // parameters of Gamma prior on b + + struct TableCounter + { + TableCounter() : tables(0) {}; + int tables; + std::map table_histogram; // num customers at table -> number tables + }; + typedef std::tr1::unordered_map DishTableType; + //typedef google::sparse_hash_map DishTableType; + DishTableType _dish_tables; + int _total_customers, _total_tables; + + typedef boost::mt19937 base_generator_type; + typedef boost::uniform_real<> uni_dist_type; + typedef boost::variate_generator gen_type; + +// uni_dist_type uni_dist; +// base_generator_type rng; //this gets the seed +// gen_type rnd; //instantiate: rnd(rng, uni_dist) + //call: rnd() generates uniform on [0,1) + + // Function objects for calculating the parts of the log_prob for + // the parameters a and b + struct resample_a_type { + int n, m; double b, a_beta_a, a_beta_b; + const DishTableType& dish_tables; + resample_a_type(int n, int m, double b, double a_beta_a, + double a_beta_b, const DishTableType& dish_tables) + : n(n), m(m), b(b), a_beta_a(a_beta_a), a_beta_b(a_beta_b), dish_tables(dish_tables) {} + + double operator() (double proposed_a) const { + double log_prior = log_prior_a(proposed_a, a_beta_a, a_beta_b); + double log_prob = 0.0; + double lgamma1a = lgamma(1.0 - proposed_a); + for (typename DishTableType::const_iterator dish_it=dish_tables.begin(); dish_it != dish_tables.end(); ++dish_it) + for (std::map::const_iterator table_it=dish_it->second.table_histogram.begin(); + table_it !=dish_it->second.table_histogram.end(); ++table_it) + log_prob += (table_it->second * (lgamma(table_it->first - proposed_a) - lgamma1a)); + + log_prob += (proposed_a == 0.0 ? (m-1.0)*log(b) + : ((m-1.0)*log(proposed_a) + lgamma((m-1.0) + b/proposed_a) - lgamma(b/proposed_a))); + assert(std::isfinite(log_prob)); + return log_prob + log_prior; + } + }; + + struct resample_b_type { + int n, m; double a, b_gamma_c, b_gamma_s; + resample_b_type(int n, int m, double a, double b_gamma_c, double b_gamma_s) + : n(n), m(m), a(a), b_gamma_c(b_gamma_c), b_gamma_s(b_gamma_s) {} + + double operator() (double proposed_b) const { + double log_prior = log_prior_b(proposed_b, b_gamma_c, b_gamma_s); + double log_prob = 0.0; + log_prob += (a == 0.0 ? (m-1.0)*log(proposed_b) + : ((m-1.0)*log(a) + lgamma((m-1.0) + proposed_b/a) - lgamma(proposed_b/a))); + log_prob += (lgamma(1.0+proposed_b) - lgamma(n+proposed_b)); + return log_prob + log_prior; + } + }; + + /* lbetadist() returns the log probability density of x under a Beta(alpha,beta) + * distribution. - copied from Mark Johnson's gammadist.c + */ + static long double lbetadist(long double x, long double alpha, long double beta); + + /* lgammadist() returns the log probability density of x under a Gamma(alpha,beta) + * distribution - copied from Mark Johnson's gammadist.c + */ + static long double lgammadist(long double x, long double alpha, long double beta); + +}; + +template +PYP::PYP(double a, double b, unsigned long seed, Hash) +: std::tr1::unordered_map(10), _a(a), _b(b), +//: google::sparse_hash_map(10), _a(a), _b(b), + _a_beta_a(1), _a_beta_b(1), _b_gamma_s(1), _b_gamma_c(1), + //_a_beta_a(1), _a_beta_b(1), _b_gamma_s(10), _b_gamma_c(0.1), + _total_customers(0), _total_tables(0)//, + //uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed), rnd(rng, uni_dist) +{ +// std::cerr << "\t##PYP::PYP(a=" << _a << ",b=" << _b << ")" << std::endl; + //set_deleted_key(-std::numeric_limits::max()); +} + +template +double +PYP::prob(Dish dish, double p0) const +{ + int c = count(dish), t = num_tables(dish); + double r = num_tables() * _a + _b; + //std::cerr << "\t\t\t\tPYP::prob(" << dish << "," << p0 << ") c=" << c << " r=" << r << std::endl; + if (c > 0) + return (c - _a * t + r * p0) / (num_customers() + _b); + else + return r * p0 / (num_customers() + _b); +} + +template +double +PYP::unnormalised_prob(Dish dish, double p0) const +{ + int c = count(dish), t = num_tables(dish); + double r = num_tables() * _a + _b; + if (c > 0) return (c - _a * t + r * p0); + else return r * p0; +} + +template +double +PYP::prob(Dish dish, double dcd, double dca, + double dtd, double dta, double p0) +const +{ + int c = count(dish) + dcd, t = num_tables(dish) + dtd; + double r = (num_tables() + dta) * _a + _b; + if (c > 0) + return (c - _a * t + r * p0) / (num_customers() + dca + _b); + else + return r * p0 / (num_customers() + dca + _b); +} + +template +double +PYP::log_prob(Dish dish, double log_p0) const +{ + using std::log; + int c = count(dish), t = num_tables(dish); + double r = log(num_tables() * _a + b); + if (c > 0) + return Log::add(log(c - _a * t), r + log_p0) + - log(num_customers() + _b); + else + return r + log_p0 - log(num_customers() + b); +} + +template +double +PYP::log_prob(Dish dish, double dcd, double dca, + double dtd, double dta, double log_p0) +const +{ + using std::log; + int c = count(dish) + dcd, t = num_tables(dish) + dtd; + double r = log((num_tables() + dta) * _a + b); + if (c > 0) + return Log::add(log(c - _a * t), r + log_p0) + - log(num_customers() + dca + _b); + else + return r + log_p0 - log(num_customers() + dca + b); +} + +template +int +PYP::increment(Dish dish, double p0) { + int delta = 0; + TableCounter &tc = _dish_tables[dish]; + + // seated on a new or existing table? + int c = count(dish), t = num_tables(dish), T = num_tables(); + double pshare = (c > 0) ? (c - _a*t) : 0.0; + double pnew = (_b + _a*T) * p0; + assert (pshare >= 0.0); + //assert (pnew > 0.0); + + //if (rnd() < pnew / (pshare + pnew)) { + if (mt_genrand_res53() < pnew / (pshare + pnew)) { + // assign to a new table + tc.tables += 1; + tc.table_histogram[1] += 1; + _total_tables += 1; + delta = 1; + } + else { + // randomly assign to an existing table + // remove constant denominator from inner loop + //double r = rnd() * (c - _a*t); + double r = mt_genrand_res53() * (c - _a*t); + for (std::map::iterator + hit = tc.table_histogram.begin(); + hit != tc.table_histogram.end(); ++hit) { + r -= ((hit->first - _a) * hit->second); + if (r <= 0) { + tc.table_histogram[hit->first+1] += 1; + hit->second -= 1; + if (hit->second == 0) + tc.table_histogram.erase(hit); + break; + } + } + if (r > 0) { + std::cerr << r << " " << c << " " << _a << " " << t << std::endl; + assert(false); + } + delta = 0; + } + + std::tr1::unordered_map::operator[](dish) += 1; + //google::sparse_hash_map::operator[](dish) += 1; + _total_customers += 1; + + return delta; +} + +template +int +PYP::count(Dish dish) const +{ + typename std::tr1::unordered_map::const_iterator + //typename google::sparse_hash_map::const_iterator + dcit = find(dish); + if (dcit != end()) + return dcit->second; + else + return 0; +} + +template +int +PYP::decrement(Dish dish) +{ + typename std::tr1::unordered_map::iterator dcit = find(dish); + //typename google::sparse_hash_map::iterator dcit = find(dish); + if (dcit == end()) { + std::cerr << dish << std::endl; + assert(false); + } + + int delta = 0; + + typename std::tr1::unordered_map::iterator dtit = _dish_tables.find(dish); + //typename google::sparse_hash_map::iterator dtit = _dish_tables.find(dish); + if (dtit == _dish_tables.end()) { + std::cerr << dish << std::endl; + assert(false); + } + TableCounter &tc = dtit->second; + + //std::cerr << "\tdecrement for " << dish << "\n"; + //std::cerr << "\tBEFORE histogram: " << tc.table_histogram << " "; + //std::cerr << "count: " << count(dish) << " "; + //std::cerr << "tables: " << tc.tables << "\n"; + + //double r = rnd() * count(dish); + double r = mt_genrand_res53() * count(dish); + for (std::map::iterator hit = tc.table_histogram.begin(); + hit != tc.table_histogram.end(); ++hit) + { + //r -= (hit->first - _a) * hit->second; + r -= (hit->first) * hit->second; + if (r <= 0) + { + if (hit->first > 1) + tc.table_histogram[hit->first-1] += 1; + else + { + delta = -1; + tc.tables -= 1; + _total_tables -= 1; + } + + hit->second -= 1; + if (hit->second == 0) tc.table_histogram.erase(hit); + break; + } + } + if (r > 0) { + std::cerr << r << " " << count(dish) << " " << _a << " " << num_tables(dish) << std::endl; + assert(false); + } + + // remove the customer + dcit->second -= 1; + _total_customers -= 1; + assert(dcit->second >= 0); + if (dcit->second == 0) { + erase(dcit); + _dish_tables.erase(dtit); + //std::cerr << "\tAFTER histogram: Empty\n"; + } + else { + //std::cerr << "\tAFTER histogram: " << _dish_tables[dish].table_histogram << " "; + //std::cerr << "count: " << count(dish) << " "; + //std::cerr << "tables: " << _dish_tables[dish].tables << "\n"; + } + + return delta; +} + +template +int +PYP::num_tables(Dish dish) const +{ + typename std::tr1::unordered_map::const_iterator + //typename google::sparse_hash_map::const_iterator + dtit = _dish_tables.find(dish); + + //assert(dtit != _dish_tables.end()); + if (dtit == _dish_tables.end()) + return 0; + + return dtit->second.tables; +} + +template +int +PYP::num_tables() const +{ + return _total_tables; +} + +template +std::ostream& +PYP::debug_info(std::ostream& os) const +{ + int hists = 0, tables = 0; + for (typename std::tr1::unordered_map::const_iterator + //for (typename google::sparse_hash_map::const_iterator + dtit = _dish_tables.begin(); dtit != _dish_tables.end(); ++dtit) + { + hists += dtit->second.table_histogram.size(); + tables += dtit->second.tables; + + assert(dtit->second.tables > 0); + assert(!dtit->second.table_histogram.empty()); + + for (std::map::const_iterator + hit = dtit->second.table_histogram.begin(); + hit != dtit->second.table_histogram.end(); ++hit) + assert(hit->second > 0); + } + + os << "restaurant has " + << _total_customers << " customers; " + << _total_tables << " tables; " + << tables << " tables'; " + << num_types() << " dishes; " + << _dish_tables.size() << " dishes'; and " + << hists << " histogram entries\n"; + + return os; +} + +template +void +PYP::clear() +{ + this->std::tr1::unordered_map::clear(); + //this->google::sparse_hash_map::clear(); + _dish_tables.clear(); + _total_tables = _total_customers = 0; +} + +// log_restaurant_prob returns the log probability of the PYP table configuration. +// Excludes Hierarchical P0 term which must be calculated separately. +template +double +PYP::log_restaurant_prob() const { + if (_total_customers < 1) + return (double)0.0; + + double log_prob = 0.0; + double lgamma1a = lgamma(1.0-_a); + + //std::cerr << "-------------------\n" << std::endl; + for (typename DishTableType::const_iterator dish_it=_dish_tables.begin(); + dish_it != _dish_tables.end(); ++dish_it) { + for (std::map::const_iterator table_it=dish_it->second.table_histogram.begin(); + table_it !=dish_it->second.table_histogram.end(); ++table_it) { + log_prob += (table_it->second * (lgamma(table_it->first - _a) - lgamma1a)); + //std::cerr << "|" << dish_it->first->parent << " --> " << dish_it->first->rhs << " " << table_it->first << " " << table_it->second << " " << log_prob; + } + } + //std::cerr << std::endl; + + log_prob += (_a == (double)0.0 ? (_total_tables-1.0)*log(_b) : (_total_tables-1.0)*log(_a) + lgamma((_total_tables-1.0) + _b/_a) - lgamma(_b/_a)); + //std::cerr << "\t\t" << log_prob << std::endl; + log_prob += (lgamma(1.0 + _b) - lgamma(_total_customers + _b)); + + //std::cerr << _total_customers << " " << _total_tables << " " << log_prob << " " << log_prior() << std::endl; + //std::cerr << _a << " " << _b << std::endl; + if (!std::isfinite(log_prob)) { + assert(false); + } + //return log_prob; + return log_prob + log_prior(); +} + +template +double +PYP::log_prior() const { + double prior = 0.0; + if (_a_beta_a > 0.0 && _a_beta_b > 0.0 && _a > 0.0) + prior += log_prior_a(_a, _a_beta_a, _a_beta_b); + if (_b_gamma_s > 0.0 && _b_gamma_c > 0.0) + prior += log_prior_b(_b, _b_gamma_c, _b_gamma_s); + + return prior; +} + +template +double +PYP::log_prior_a(double a, double beta_a, double beta_b) { + return lbetadist(a, beta_a, beta_b); +} + +template +double +PYP::log_prior_b(double b, double gamma_c, double gamma_s) { + return lgammadist(b, gamma_c, gamma_s); +} + +template +long double PYP::lbetadist(long double x, long double alpha, long double beta) { + assert(x > 0); + assert(x < 1); + assert(alpha > 0); + assert(beta > 0); + return (alpha-1)*log(x)+(beta-1)*log(1-x)+lgamma(alpha+beta)-lgamma(alpha)-lgamma(beta); +//boost::math::lgamma +} + +template +long double PYP::lgammadist(long double x, long double alpha, long double beta) { + assert(alpha > 0); + assert(beta > 0); + return (alpha-1)*log(x) - alpha*log(beta) - x/beta - lgamma(alpha); +} + + +template +void +PYP::resample_prior() { + for (int num_its=5; num_its >= 0; --num_its) { + resample_prior_b(); + resample_prior_a(); + } + resample_prior_b(); +} + +template +void +PYP::resample_prior_b() { + if (_total_tables == 0) + return; + + int niterations = 10; // number of resampling iterations + //std::cerr << "\n## resample_prior_b(), initial a = " << _a << ", b = " << _b << std::endl; + resample_b_type b_log_prob(_total_customers, _total_tables, _a, _b_gamma_c, _b_gamma_s); + //_b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits::infinity(), + _b = slice_sampler1d(b_log_prob, _b, random, (double) 0.0, std::numeric_limits::infinity(), + (double) 0.0, niterations, 100*niterations); + //std::cerr << "\n## resample_prior_b(), final a = " << _a << ", b = " << _b << std::endl; +} + +template +void +PYP::resample_prior_a() { + if (_total_tables == 0) + return; + + int niterations = 10; + //std::cerr << "\n## Initial a = " << _a << ", b = " << _b << std::endl; + resample_a_type a_log_prob(_total_customers, _total_tables, _b, _a_beta_a, _a_beta_b, _dish_tables); + //_a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits::min(), + _a = slice_sampler1d(a_log_prob, _a, random, std::numeric_limits::min(), + (double) 1.0, (double) 0.0, niterations, 100*niterations); +} + +#endif diff --git a/gi/pyp-topics/src/mpi-train-contexts.cc b/gi/pyp-topics/src/mpi-train-contexts.cc new file mode 100644 index 00000000..6309fe93 --- /dev/null +++ b/gi/pyp-topics/src/mpi-train-contexts.cc @@ -0,0 +1,169 @@ +// STL +#include +#include +#include +#include + +// Boost +#include +#include +#include + +// Local +#include "mpi-pyp-topics.hh" +#include "corpus.hh" +#include "contexts_corpus.hh" +#include "gzstream.hh" + +static const char *REVISION = "$Rev: 170 $"; + +// Namespaces +using namespace boost; +using namespace boost::program_options; +using namespace std; + +int main(int argc, char **argv) +{ + cout << "Pitman Yor topic models: Copyright 2010 Phil Blunsom\n"; + cout << REVISION << '\n' <(), "config file specifying additional command line options") + ; + options_description config_options("Allowed options"); + config_options.add_options() + ("help,h", "print help message") + ("data,d", value(), "file containing the documents and context terms") + ("topics,t", value()->default_value(50), "number of topics") + ("document-topics-out,o", value(), "file to write the document topics to") + ("default-topics-out", value(), "file to write default term topic assignments.") + ("topic-words-out,w", value(), "file to write the topic word distribution to") + ("samples,s", value()->default_value(10), "number of sampling passes through the data") + ("backoff-type", value(), "backoff type: none|simple") +// ("filter-singleton-contexts", "filter singleton contexts") + ("hierarchical-topics", "Use a backoff hierarchical PYP as the P0 for the document topics distribution.") + ("freq-cutoff-start", value()->default_value(0), "initial frequency cutoff.") + ("freq-cutoff-end", value()->default_value(0), "final frequency cutoff.") + ("freq-cutoff-interval", value()->default_value(0), "number of iterations between frequency decrement.") + ("max-threads", value()->default_value(1), "maximum number of simultaneous threads allowed") + ("max-contexts-per-document", value()->default_value(0), "Only sample the n most frequent contexts for a document.") + ; + + cmdline_specific.add(config_options); + + store(parse_command_line(argc, argv, cmdline_specific), vm); + notify(vm); + + if (vm.count("config") > 0) { + ifstream config(vm["config"].as().c_str()); + store(parse_config_file(config, config_options), vm); + } + + if (vm.count("help")) { + cout << cmdline_specific << "\n"; + return 1; + } + } + //////////////////////////////////////////////////////////////////////////////////////////// + + if (!vm.count("data")) { + cerr << "Please specify a file containing the data." << endl; + return 1; + } + + // seed the random number generator: 0 = automatic, specify value otherwise + unsigned long seed = 0; + PYPTopics model(vm["topics"].as(), vm.count("hierarchical-topics"), seed, vm["max-threads"].as()); + + // read the data + BackoffGenerator* backoff_gen=0; + if (vm.count("backoff-type")) { + if (vm["backoff-type"].as() == "none") { + backoff_gen = 0; + } + else if (vm["backoff-type"].as() == "simple") { + backoff_gen = new SimpleBackoffGenerator(); + } + else { + cerr << "Backoff type (--backoff-type) must be one of none|simple." <(), backoff_gen, /*vm.count("filter-singleton-contexts")*/ false); + model.set_backoff(contexts_corpus.backoff_index()); + + if (backoff_gen) + delete backoff_gen; + + // train the sampler + model.sample_corpus(contexts_corpus, vm["samples"].as(), + vm["freq-cutoff-start"].as(), + vm["freq-cutoff-end"].as(), + vm["freq-cutoff-interval"].as(), + vm["max-contexts-per-document"].as()); + + if (vm.count("document-topics-out")) { + ogzstream documents_out(vm["document-topics-out"].as().c_str()); + + int document_id=0; + map all_terms; + for (Corpus::const_iterator corpusIt=contexts_corpus.begin(); + corpusIt != contexts_corpus.end(); ++corpusIt, ++document_id) { + vector unique_terms; + for (Document::const_iterator docIt=corpusIt->begin(); + docIt != corpusIt->end(); ++docIt) { + if (unique_terms.empty() || *docIt != unique_terms.back()) + unique_terms.push_back(*docIt); + // increment this terms frequency + pair::iterator,bool> insert_result = all_terms.insert(make_pair(*docIt,1)); + if (!insert_result.second) + all_terms[*docIt] = all_terms[*docIt] + 1; + //insert_result.first++; + } + documents_out << contexts_corpus.key(document_id) << '\t'; + documents_out << model.max(document_id) << " " << corpusIt->size() << " ||| "; + for (std::vector::const_iterator termIt=unique_terms.begin(); + termIt != unique_terms.end(); ++termIt) { + if (termIt != unique_terms.begin()) + documents_out << " ||| "; + vector strings = contexts_corpus.context2string(*termIt); + copy(strings.begin(), strings.end(),ostream_iterator(documents_out, " ")); + documents_out << "||| C=" << model.max(document_id, *termIt); + + } + documents_out <().c_str()); + default_topics << model.max_topic() <::const_iterator termIt=all_terms.begin(); termIt != all_terms.end(); ++termIt) { + vector strings = contexts_corpus.context2string(termIt->first); + default_topics << model.max(-1, termIt->first) << " ||| " << termIt->second << " ||| "; + copy(strings.begin(), strings.end(),ostream_iterator(default_topics, " ")); + default_topics <().c_str()); + model.print_topic_terms(topics_out); + topics_out.close(); + } + + cout < void PYPTopics::sample_corpus(const Corpus& corpus, int samples, int freq_cutoff_start, int freq_cutoff_end, - int freq_cutoff_interval) { + int freq_cutoff_interval, + int max_contexts_per_document) { Timer timer; if (!m_backoff.get()) { @@ -54,11 +55,12 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, // sample a new_topic //int new_topic = (topic_counter % m_num_topics); int freq = corpus.context_count(term); - int new_topic = (freq > frequency_cutoff ? (document_id % m_num_topics) : -1); + int new_topic = -1; + if (freq > frequency_cutoff + && (!max_contexts_per_document || term_index < max_contexts_per_document)) { + new_topic = document_id % m_num_topics; - // add the new topic to the PYPs - m_corpus_topics[document_id][term_index] = new_topic; - if (freq > frequency_cutoff) { + // add the new topic to the PYPs increment(term, new_topic); if (m_use_topic_pyp) { @@ -69,6 +71,8 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, } else m_document_pyps[document_id].increment(new_topic, m_topic_p0); } + + m_corpus_topics[document_id][term_index] = new_topic; } } std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n"; @@ -94,6 +98,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, // Randomize the corpus indexing array int tmp; + int processed_terms=0; for (int i = corpus.num_documents()-1; i > 0; --i) { //i+1 since j \in [0,i] but rnd() \in [0,1) @@ -106,8 +111,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, // for each document in the corpus int document_id; - for (int i=0; i max_contexts_per_document) + break; + Term term = *docIt; int freq = corpus.context_count(term); if (freq < frequency_cutoff) continue; + processed_terms++; + // remove the prevous topic from the PYPs int current_topic = m_corpus_topics[document_id][term_index]; // a negative label mean that term hasn't been sampled yet @@ -150,6 +159,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, std::cerr << "."; std::cerr.flush(); } } + std::cerr << " ||| sampled " << processed_terms << " terms."; if (curr_sample != 0 && curr_sample % 10 == 0) { std::cerr << " ||| time=" << (timer.Elapsed() / 10.0) << " sec/sample" << std::endl; diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh index 32d2d939..ebe951b1 100644 --- a/gi/pyp-topics/src/pyp-topics.hh +++ b/gi/pyp-topics/src/pyp-topics.hh @@ -30,7 +30,8 @@ public: void sample_corpus(const Corpus& corpus, int samples, int freq_cutoff_start=0, int freq_cutoff_end=0, - int freq_cutoff_interval=0); + int freq_cutoff_interval=0, + int max_contexts_per_document=0); int sample(const DocumentId& doc, const Term& term); std::pair max(const DocumentId& doc, const Term& term) const; diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh index 7a520d6a..dc47244b 100644 --- a/gi/pyp-topics/src/pyp.hh +++ b/gi/pyp-topics/src/pyp.hh @@ -4,6 +4,7 @@ #include #include #include +//#include #include #include @@ -11,6 +12,7 @@ #include "log_add.h" #include "slice-sampler.h" +#include "mt19937ar.h" // // Pitman-Yor process with customer and table tracking @@ -18,12 +20,17 @@ template > class PYP : protected std::tr1::unordered_map +//class PYP : protected google::sparse_hash_map { public: using std::tr1::unordered_map::const_iterator; using std::tr1::unordered_map::iterator; using std::tr1::unordered_map::begin; using std::tr1::unordered_map::end; +// using google::sparse_hash_map::const_iterator; +// using google::sparse_hash_map::iterator; +// using google::sparse_hash_map::begin; +// using google::sparse_hash_map::end; PYP(double a, double b, unsigned long seed = 0, Hash hash=Hash()); @@ -39,6 +46,7 @@ public: int num_customers() const { return _total_customers; } int num_types() const { return std::tr1::unordered_map::size(); } + //int num_types() const { return google::sparse_hash_map::size(); } bool empty() const { return _total_customers == 0; } double log_prob(Dish dish, double log_p0) const; @@ -79,6 +87,7 @@ private: std::map table_histogram; // num customers at table -> number tables }; typedef std::tr1::unordered_map DishTableType; + //typedef google::sparse_hash_map DishTableType; DishTableType _dish_tables; int _total_customers, _total_tables; @@ -86,11 +95,10 @@ private: typedef boost::uniform_real<> uni_dist_type; typedef boost::variate_generator gen_type; - uni_dist_type uni_dist; - base_generator_type rng; //this gets the seed - gen_type rnd; //instantiate: rnd(rng, uni_dist) +// uni_dist_type uni_dist; +// base_generator_type rng; //this gets the seed +// gen_type rnd; //instantiate: rnd(rng, uni_dist) //call: rnd() generates uniform on [0,1) - // Function objects for calculating the parts of the log_prob for // the parameters a and b @@ -132,12 +140,12 @@ private: } }; - /* lbetadist() returns the log probability density of x under a Beta(alpha,beta) + /* lbetadist() returns the log probability density of x under a Beta(alpha,beta) * distribution. - copied from Mark Johnson's gammadist.c */ - static long double lbetadist(long double x, long double alpha, long double beta); + static long double lbetadist(long double x, long double alpha, long double beta); - /* lgammadist() returns the log probability density of x under a Gamma(alpha,beta) + /* lgammadist() returns the log probability density of x under a Gamma(alpha,beta) * distribution - copied from Mark Johnson's gammadist.c */ static long double lgammadist(long double x, long double alpha, long double beta); @@ -146,13 +154,15 @@ private: template PYP::PYP(double a, double b, unsigned long seed, Hash) -: std::tr1::unordered_map(), _a(a), _b(b), +: std::tr1::unordered_map(10), _a(a), _b(b), +//: google::sparse_hash_map(10), _a(a), _b(b), _a_beta_a(1), _a_beta_b(1), _b_gamma_s(1), _b_gamma_c(1), //_a_beta_a(1), _a_beta_b(1), _b_gamma_s(10), _b_gamma_c(0.1), - _total_customers(0), _total_tables(0), - uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed), rnd(rng, uni_dist) + _total_customers(0), _total_tables(0)//, + //uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed), rnd(rng, uni_dist) { // std::cerr << "\t##PYP::PYP(a=" << _a << ",b=" << _b << ")" << std::endl; + //set_deleted_key(-std::numeric_limits::max()); } template @@ -235,7 +245,8 @@ PYP::increment(Dish dish, double p0) { assert (pshare >= 0.0); //assert (pnew > 0.0); - if (rnd() < pnew / (pshare + pnew)) { + //if (rnd() < pnew / (pshare + pnew)) { + if (mt_genrand_res53() < pnew / (pshare + pnew)) { // assign to a new table tc.tables += 1; tc.table_histogram[1] += 1; @@ -245,7 +256,8 @@ PYP::increment(Dish dish, double p0) { else { // randomly assign to an existing table // remove constant denominator from inner loop - double r = rnd() * (c - _a*t); + //double r = rnd() * (c - _a*t); + double r = mt_genrand_res53() * (c - _a*t); for (std::map::iterator hit = tc.table_histogram.begin(); hit != tc.table_histogram.end(); ++hit) { @@ -266,6 +278,7 @@ PYP::increment(Dish dish, double p0) { } std::tr1::unordered_map::operator[](dish) += 1; + //google::sparse_hash_map::operator[](dish) += 1; _total_customers += 1; return delta; @@ -276,6 +289,7 @@ int PYP::count(Dish dish) const { typename std::tr1::unordered_map::const_iterator + //typename google::sparse_hash_map::const_iterator dcit = find(dish); if (dcit != end()) return dcit->second; @@ -288,6 +302,7 @@ int PYP::decrement(Dish dish) { typename std::tr1::unordered_map::iterator dcit = find(dish); + //typename google::sparse_hash_map::iterator dcit = find(dish); if (dcit == end()) { std::cerr << dish << std::endl; assert(false); @@ -296,6 +311,7 @@ PYP::decrement(Dish dish) int delta = 0; typename std::tr1::unordered_map::iterator dtit = _dish_tables.find(dish); + //typename google::sparse_hash_map::iterator dtit = _dish_tables.find(dish); if (dtit == _dish_tables.end()) { std::cerr << dish << std::endl; assert(false); @@ -307,7 +323,8 @@ PYP::decrement(Dish dish) //std::cerr << "count: " << count(dish) << " "; //std::cerr << "tables: " << tc.tables << "\n"; - double r = rnd() * count(dish); + //double r = rnd() * count(dish); + double r = mt_genrand_res53() * count(dish); for (std::map::iterator hit = tc.table_histogram.begin(); hit != tc.table_histogram.end(); ++hit) { @@ -357,6 +374,7 @@ int PYP::num_tables(Dish dish) const { typename std::tr1::unordered_map::const_iterator + //typename google::sparse_hash_map::const_iterator dtit = _dish_tables.find(dish); //assert(dtit != _dish_tables.end()); @@ -379,6 +397,7 @@ PYP::debug_info(std::ostream& os) const { int hists = 0, tables = 0; for (typename std::tr1::unordered_map::const_iterator + //for (typename google::sparse_hash_map::const_iterator dtit = _dish_tables.begin(); dtit != _dish_tables.end(); ++dtit) { hists += dtit->second.table_histogram.size(); @@ -409,6 +428,7 @@ void PYP::clear() { this->std::tr1::unordered_map::clear(); + //this->google::sparse_hash_map::clear(); _dish_tables.clear(); _total_tables = _total_customers = 0; } @@ -509,7 +529,8 @@ PYP::resample_prior_b() { int niterations = 10; // number of resampling iterations //std::cerr << "\n## resample_prior_b(), initial a = " << _a << ", b = " << _b << std::endl; resample_b_type b_log_prob(_total_customers, _total_tables, _a, _b_gamma_c, _b_gamma_s); - _b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits::infinity(), + //_b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits::infinity(), + _b = slice_sampler1d(b_log_prob, _b, random, (double) 0.0, std::numeric_limits::infinity(), (double) 0.0, niterations, 100*niterations); //std::cerr << "\n## resample_prior_b(), final a = " << _a << ", b = " << _b << std::endl; } @@ -523,7 +544,8 @@ PYP::resample_prior_a() { int niterations = 10; //std::cerr << "\n## Initial a = " << _a << ", b = " << _b << std::endl; resample_a_type a_log_prob(_total_customers, _total_tables, _b, _a_beta_a, _a_beta_b, _dish_tables); - _a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits::min(), + //_a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits::min(), + _a = slice_sampler1d(a_log_prob, _a, random, std::numeric_limits::min(), (double) 1.0, (double) 0.0, niterations, 100*niterations); } diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc index 0a48d3d9..5e98d02f 100644 --- a/gi/pyp-topics/src/train-contexts.cc +++ b/gi/pyp-topics/src/train-contexts.cc @@ -54,6 +54,7 @@ int main(int argc, char **argv) ("freq-cutoff-end", value()->default_value(0), "final frequency cutoff.") ("freq-cutoff-interval", value()->default_value(0), "number of iterations between frequency decrement.") ("max-threads", value()->default_value(1), "maximum number of simultaneous threads allowed") + ("max-contexts-per-document", value()->default_value(0), "Only sample the n most frequent contexts for a document.") ("num-jobs", value()->default_value(1), "allows finer control over parallelization") ; @@ -110,7 +111,8 @@ int main(int argc, char **argv) model.sample_corpus(contexts_corpus, vm["samples"].as(), vm["freq-cutoff-start"].as(), vm["freq-cutoff-end"].as(), - vm["freq-cutoff-interval"].as()); + vm["freq-cutoff-interval"].as(), + vm["max-contexts-per-document"].as()); if (vm.count("document-topics-out")) { ogzstream documents_out(vm["document-topics-out"].as().c_str()); -- cgit v1.2.3