author     philblunsom <philblunsom@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-07-14 22:42:35 +0000
committer  philblunsom <philblunsom@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-07-14 22:42:35 +0000
commit     dc6e2c9c453a76f0bb3dfbca4471e763cc8af1e7
tree       5a67b276a6f7936f1c6b414b554397cc88064de8 /gi/pyp-topics/src
parent     851207fcbd93c4a0857e0d7719007abe9c82dae1

Starting an MPI version.

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@253 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/pyp-topics/src')
-rw-r--r--  gi/pyp-topics/src/Makefile.am            |   9
-rw-r--r--  gi/pyp-topics/src/contexts_corpus.cc     |  22
-rw-r--r--  gi/pyp-topics/src/contexts_lexer.h       |   2
-rw-r--r--  gi/pyp-topics/src/contexts_lexer.l       |   5
-rw-r--r--  gi/pyp-topics/src/mpi-pyp-topics.cc      | 431
-rw-r--r--  gi/pyp-topics/src/mpi-pyp-topics.hh      |  97
-rw-r--r--  gi/pyp-topics/src/mpi-pyp.hh             | 552
-rw-r--r--  gi/pyp-topics/src/mpi-train-contexts.cc  | 169
-rw-r--r--  gi/pyp-topics/src/pyp-topics.cc          |  24
-rw-r--r--  gi/pyp-topics/src/pyp-topics.hh          |   3
-rw-r--r--  gi/pyp-topics/src/pyp.hh                 |  52
-rw-r--r--  gi/pyp-topics/src/train-contexts.cc      |   4
12 files changed, 1334 insertions, 36 deletions
diff --git a/gi/pyp-topics/src/Makefile.am b/gi/pyp-topics/src/Makefile.am
index abfc95ac..a3a30acd 100644
--- a/gi/pyp-topics/src/Makefile.am
+++ b/gi/pyp-topics/src/Makefile.am
@@ -1,13 +1,16 @@
-bin_PROGRAMS = pyp-topics-train pyp-contexts-train
+bin_PROGRAMS = pyp-topics-train pyp-contexts-train mpi-pyp-contexts-train
contexts_lexer.cc: contexts_lexer.l
$(LEX) -s -CF -8 -o$@ $<
-pyp_topics_train_SOURCES = corpus.cc gzstream.cc pyp-topics.cc train.cc contexts_lexer.cc contexts_corpus.cc
+pyp_topics_train_SOURCES = mt19937ar.c corpus.cc gzstream.cc pyp-topics.cc train.cc contexts_lexer.cc contexts_corpus.cc
pyp_topics_train_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
-pyp_contexts_train_SOURCES = corpus.cc gzstream.cc pyp-topics.cc contexts_lexer.cc contexts_corpus.cc train-contexts.cc
+pyp_contexts_train_SOURCES = mt19937ar.c corpus.cc gzstream.cc pyp-topics.cc contexts_lexer.cc contexts_corpus.cc train-contexts.cc
pyp_contexts_train_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
+mpi_pyp_contexts_train_SOURCES = mt19937ar.c corpus.cc gzstream.cc mpi-pyp-topics.cc contexts_lexer.cc contexts_corpus.cc mpi-train-contexts.cc
+mpi_pyp_contexts_train_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
+
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -funroll-loops
diff --git a/gi/pyp-topics/src/contexts_corpus.cc b/gi/pyp-topics/src/contexts_corpus.cc
index 280b2976..26d5718a 100644
--- a/gi/pyp-topics/src/contexts_corpus.cc
+++ b/gi/pyp-topics/src/contexts_corpus.cc
@@ -28,9 +28,12 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void*
Document* doc(new Document());
//cout << "READ: " << new_contexts.phrase << "\t";
- for (int i=0; i < new_contexts.contexts.size(); ++i) {
+ for (int i=0; i < new_contexts.counts.size(); ++i) {
int cache_word_count = corpus_ptr->m_dict.max();
- string context_str = corpus_ptr->m_dict.toString(new_contexts.contexts[i]);
+
+ //string context_str = corpus_ptr->m_dict.toString(new_contexts.contexts[i]);
+ int context_index = new_contexts.counts.at(i).first;
+ string context_str = corpus_ptr->m_dict.toString(new_contexts.contexts[context_index]);
// filter out singleton contexts
//if (!counts->empty()) {
@@ -45,7 +48,8 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void*
corpus_ptr->m_num_types++;
}
- int count = new_contexts.counts[i];
+ //int count = new_contexts.counts[i];
+ int count = new_contexts.counts.at(i).second;
for (int j=0; j<count; ++j)
doc->push_back(id);
corpus_ptr->m_num_terms += count;
@@ -54,7 +58,8 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void*
if (backoff_gen) {
int order = 1;
WordID backoff_id = id;
- ContextsLexer::Context backedoff_context = new_contexts.contexts[i];
+ //ContextsLexer::Context backedoff_context = new_contexts.contexts[i];
+ ContextsLexer::Context backedoff_context = new_contexts.contexts[context_index];
while (true) {
if (!corpus_ptr->m_backoff->has_backoff(backoff_id)) {
//cerr << "Backing off from " << corpus_ptr->m_dict.Convert(backoff_id) << " to ";
@@ -96,10 +101,13 @@ void filter_callback(const ContextsLexer::PhraseContextsType& new_contexts, void*
map<string,int>* context_counts = (static_cast<map<string,int>*>(extra));
- for (int i=0; i < new_contexts.contexts.size(); ++i) {
- int count = new_contexts.counts[i];
+ for (int i=0; i < new_contexts.counts.size(); ++i) {
+ int context_index = new_contexts.counts.at(i).first;
+ int count = new_contexts.counts.at(i).second;
+ //int count = new_contexts.counts[i];
pair<map<string,int>::iterator,bool> result
- = context_counts->insert(make_pair(Dict::toString(new_contexts.contexts[i]),count));
+ = context_counts->insert(make_pair(Dict::toString(new_contexts.contexts[context_index]),count));
+ //= context_counts->insert(make_pair(Dict::toString(new_contexts.contexts[i]),count));
if (!result.second)
result.first->second += count;
}
diff --git a/gi/pyp-topics/src/contexts_lexer.h b/gi/pyp-topics/src/contexts_lexer.h
index f9a1b21c..1b79c6fd 100644
--- a/gi/pyp-topics/src/contexts_lexer.h
+++ b/gi/pyp-topics/src/contexts_lexer.h
@@ -12,7 +12,7 @@ struct ContextsLexer {
struct PhraseContextsType {
std::string phrase;
std::vector<Context> contexts;
- std::vector<int> counts;
+ std::vector< std::pair<int,int> > counts;
};
typedef void (*ContextsCallback)(const PhraseContextsType& new_contexts, void* extra);
diff --git a/gi/pyp-topics/src/contexts_lexer.l b/gi/pyp-topics/src/contexts_lexer.l
index 61189a73..7a5d9460 100644
--- a/gi/pyp-topics/src/contexts_lexer.l
+++ b/gi/pyp-topics/src/contexts_lexer.l
@@ -6,6 +6,7 @@
#include <sstream>
#include <cstring>
#include <cassert>
+#include <algorithm>
int lex_line = 0;
std::istream* contextslex_stream = NULL;
@@ -69,7 +70,7 @@ INT [\-+]?[0-9]+|inf|[\-+]inf
<COUNT>[ \t]+ { ; }
<COUNT>C={INT} {
- current_contexts.counts.push_back(atoi(yytext+2));
+ current_contexts.counts.push_back(std::make_pair(current_contexts.counts.size(), atoi(yytext+2)));
BEGIN(COUNT_END);
}
<COUNT>. {
@@ -84,6 +85,8 @@ INT [\-+]?[0-9]+|inf|[\-+]inf
<COUNT_END>\n {
//std::cerr << "READ:" << current_contexts.phrase << " with " << current_contexts.contexts.size()
// << " contexts, and " << current_contexts.counts.size() << " counts." << std::endl;
+ std::sort(current_contexts.counts.rbegin(), current_contexts.counts.rend());
+
contexts_callback(current_contexts, contexts_callback_extra);
current_contexts.phrase.clear();
current_contexts.contexts.clear();
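
Note: the lexer now records each count together with its context index and sorts before invoking the callback. Worth flagging: std::pair's default operator< compares .first before .second, so on (index, count) pairs the rbegin/rend sort above orders by descending index, not by descending count; a most-frequent-first order, as the --max-contexts-per-document option added below suggests, would need the count in the leading slot or an explicit comparator. A minimal standalone sketch (hypothetical data, not part of the commit) of the comparator variant:

    // Sort (context index, count) pairs most-frequent-first.
    #include <algorithm>
    #include <iostream>
    #include <utility>
    #include <vector>

    static bool by_count_desc(const std::pair<int,int>& a,
                              const std::pair<int,int>& b) {
      return a.second > b.second;  // larger counts first
    }

    int main() {
      std::vector< std::pair<int,int> > counts;  // (context index, count)
      counts.push_back(std::make_pair(0, 2));
      counts.push_back(std::make_pair(1, 7));
      counts.push_back(std::make_pair(2, 4));

      std::sort(counts.begin(), counts.end(), by_count_desc);

      for (size_t i = 0; i < counts.size(); ++i)   // prints 1:7 2:4 0:2
        std::cout << counts[i].first << ":" << counts[i].second << " ";
      std::cout << std::endl;
      return 0;
    }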
diff --git a/gi/pyp-topics/src/mpi-pyp-topics.cc b/gi/pyp-topics/src/mpi-pyp-topics.cc
new file mode 100644
index 00000000..d2daad4f
--- /dev/null
+++ b/gi/pyp-topics/src/mpi-pyp-topics.cc
@@ -0,0 +1,431 @@
+#include "timing.h"
+#include "mpi-pyp-topics.hh"
+
+//#include <boost/date_time/posix_time/posix_time_types.hpp>
+void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
+ int freq_cutoff_start, int freq_cutoff_end,
+ int freq_cutoff_interval,
+ int max_contexts_per_document) {
+ Timer timer;
+
+ if (!m_backoff.get()) {
+ m_word_pyps.clear();
+ m_word_pyps.push_back(PYPs());
+ }
+
+ std::cerr << "\n Training with " << m_word_pyps.size()-1 << " backoff level"
+ << (m_word_pyps.size()==2 ? ":" : "s:") << std::endl;
+
+ for (int i=0; i<(int)m_word_pyps.size(); ++i)
+ {
+ m_word_pyps.at(i).reserve(m_num_topics);
+ for (int j=0; j<m_num_topics; ++j)
+ m_word_pyps.at(i).push_back(new PYP<int>(0.5, 1.0, m_seed));
+ }
+ std::cerr << std::endl;
+
+ m_document_pyps.reserve(corpus.num_documents());
+ for (int j=0; j<corpus.num_documents(); ++j)
+ m_document_pyps.push_back(new PYP<int>(0.5, 1.0, m_seed));
+
+ m_topic_p0 = 1.0/m_num_topics;
+ m_term_p0 = 1.0/corpus.num_types();
+ m_backoff_p0 = 1.0/corpus.num_documents();
+
+ std::cerr << " Documents: " << corpus.num_documents() << " Terms: "
+ << corpus.num_types() << std::endl;
+
+ int frequency_cutoff = freq_cutoff_start;
+ std::cerr << " Context frequency cutoff set to " << frequency_cutoff << std::endl;
+
+ timer.Reset();
+ // Initialisation pass
+ int document_id=0, topic_counter=0;
+ for (Corpus::const_iterator corpusIt=corpus.begin();
+ corpusIt != corpus.end(); ++corpusIt, ++document_id) {
+ m_corpus_topics.push_back(DocumentTopics(corpusIt->size(), 0));
+
+ int term_index=0;
+ for (Document::const_iterator docIt=corpusIt->begin();
+ docIt != corpusIt->end(); ++docIt, ++term_index) {
+ topic_counter++;
+ Term term = *docIt;
+
+ // sample a new_topic
+ //int new_topic = (topic_counter % m_num_topics);
+ int freq = corpus.context_count(term);
+ int new_topic = -1;
+ if (freq > frequency_cutoff
+ && (!max_contexts_per_document || term_index < max_contexts_per_document)) {
+ new_topic = document_id % m_num_topics;
+
+ // add the new topic to the PYPs
+ increment(term, new_topic);
+
+ if (m_use_topic_pyp) {
+ F p0 = m_topic_pyp.prob(new_topic, m_topic_p0);
+ int table_delta = m_document_pyps[document_id].increment(new_topic, p0);
+ if (table_delta)
+ m_topic_pyp.increment(new_topic, m_topic_p0);
+ }
+ else m_document_pyps[document_id].increment(new_topic, m_topic_p0);
+ }
+
+ m_corpus_topics[document_id][term_index] = new_topic;
+ }
+ }
+ std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n";
+
+ int* randomDocIndices = new int[corpus.num_documents()];
+ for (int i = 0; i < corpus.num_documents(); ++i)
+ randomDocIndices[i] = i;
+
+ // Sampling phase
+ for (int curr_sample=0; curr_sample < samples; ++curr_sample) {
+ if (freq_cutoff_interval > 0 && curr_sample != 1
+ && curr_sample % freq_cutoff_interval == 1
+ && frequency_cutoff > freq_cutoff_end) {
+ frequency_cutoff--;
+ std::cerr << "\n Context frequency cutoff set to " << frequency_cutoff << std::endl;
+ }
+
+ std::cerr << "\n -- Sample " << curr_sample << " "; std::cerr.flush();
+
+ // Randomize the corpus indexing array
+ int tmp;
+ int processed_terms=0;
+ for (int i = corpus.num_documents()-1; i > 0; --i)
+ {
+ //i+1 since j \in [0,i] but rnd() \in [0,1)
+ int j = (int)(rnd() * (i+1));
+ assert(j >= 0 && j <= i);
+ tmp = randomDocIndices[i];
+ randomDocIndices[i] = randomDocIndices[j];
+ randomDocIndices[j] = tmp;
+ }
+
+ // for each document in the corpus
+ int document_id;
+ for (int i=0; i<corpus.num_documents(); ++i) {
+ document_id = randomDocIndices[i];
+
+ // for each term in the document
+ int term_index=0;
+ Document::const_iterator docEnd = corpus.at(document_id).end();
+ for (Document::const_iterator docIt=corpus.at(document_id).begin();
+ docIt != docEnd; ++docIt, ++term_index) {
+ if (max_contexts_per_document && term_index >= max_contexts_per_document)
+ break;
+
+ Term term = *docIt;
+ int freq = corpus.context_count(term);
+ if (freq < frequency_cutoff)
+ continue;
+
+ processed_terms++;
+
+ // remove the previous topic from the PYPs
+ int current_topic = m_corpus_topics[document_id][term_index];
+ // a negative label means that the term hasn't been sampled yet
+ if (current_topic >= 0) {
+ decrement(term, current_topic);
+
+ int table_delta = m_document_pyps[document_id].decrement(current_topic);
+ if (m_use_topic_pyp && table_delta < 0)
+ m_topic_pyp.decrement(current_topic);
+ }
+
+ // sample a new_topic
+ int new_topic = sample(document_id, term);
+
+ // add the new topic to the PYPs
+ m_corpus_topics[document_id][term_index] = new_topic;
+ increment(term, new_topic);
+
+ if (m_use_topic_pyp) {
+ F p0 = m_topic_pyp.prob(new_topic, m_topic_p0);
+ int table_delta = m_document_pyps[document_id].increment(new_topic, p0);
+ if (table_delta)
+ m_topic_pyp.increment(new_topic, m_topic_p0);
+ }
+ else m_document_pyps[document_id].increment(new_topic, m_topic_p0);
+ }
+ if (document_id && document_id % 10000 == 0) {
+ std::cerr << "."; std::cerr.flush();
+ }
+ }
+ std::cerr << " ||| sampled " << processed_terms << " terms.";
+
+ if (curr_sample != 0 && curr_sample % 10 == 0) {
+ std::cerr << " ||| time=" << (timer.Elapsed() / 10.0) << " sec/sample" << std::endl;
+ timer.Reset();
+ std::cerr << " ... Resampling hyperparameters (" << max_threads << " threads)"; std::cerr.flush();
+
+ // resample the hyperparameters
+ F log_p=0.0;
+ for (std::vector<PYPs>::iterator levelIt=m_word_pyps.begin();
+ levelIt != m_word_pyps.end(); ++levelIt) {
+ for (PYPs::iterator pypIt=levelIt->begin();
+ pypIt != levelIt->end(); ++pypIt) {
+ pypIt->resample_prior();
+ log_p += pypIt->log_restaurant_prob();
+ }
+ }
+
+ WorkerPtrVect workers;
+ for (int i = 0; i < max_threads; ++i)
+ {
+ JobReturnsF job = boost::bind(&PYPTopics::hresample_docs, this, max_threads, i);
+ workers.push_back(new SimpleResampleWorker(job));
+ }
+
+ WorkerPtrVect::iterator workerIt;
+ for (workerIt = workers.begin(); workerIt != workers.end(); ++workerIt)
+ {
+ //std::cerr << "Retrieving worker result.."; std::cerr.flush();
+ F wresult = workerIt->getResult(); //blocks until worker done
+ log_p += wresult;
+ //std::cerr << ".. got " << wresult << std::endl; std::cerr.flush();
+
+ }
+
+ if (m_use_topic_pyp) {
+ m_topic_pyp.resample_prior();
+ log_p += m_topic_pyp.log_restaurant_prob();
+ }
+
+ std::cerr.precision(10);
+ std::cerr << " ||| LLH=" << log_p << " ||| resampling time=" << timer.Elapsed() << " sec" << std::endl;
+ timer.Reset();
+
+ int k=0;
+ std::cerr << "Topics distribution: ";
+ std::cerr.precision(2);
+ for (PYPs::iterator pypIt=m_word_pyps.front().begin();
+ pypIt != m_word_pyps.front().end(); ++pypIt, ++k) {
+ if (k % 5 == 0) std::cerr << std::endl << '\t';
+ std::cerr << "<" << k << ":" << pypIt->num_customers() << ","
+ << pypIt->num_types() << "," << m_topic_pyp.prob(k, m_topic_p0) << "> ";
+ }
+ std::cerr.precision(4);
+ std::cerr << std::endl;
+ }
+ }
+ delete [] randomDocIndices;
+}
+
+PYPTopics::F PYPTopics::hresample_docs(int num_threads, int thread_id)
+{
+ int resample_counter=0;
+ F log_p = 0.0;
+ PYPs::iterator pypIt = m_document_pyps.begin();
+ PYPs::iterator end = m_document_pyps.end();
+ pypIt += thread_id;
+// std::cerr << thread_id << " started " << std::endl; std::cerr.flush();
+
+ while (pypIt < end)
+ {
+ pypIt->resample_prior();
+ log_p += pypIt->log_restaurant_prob();
+ if (resample_counter++ % 5000 == 0) {
+ std::cerr << "."; std::cerr.flush();
+ }
+ pypIt += num_threads;
+ }
+// std::cerr << thread_id << " did " << resample_counter << " with answer " << log_p << std::endl; std::cerr.flush();
+
+ return log_p;
+}
+
+//PYPTopics::F PYPTopics::hresample_topics()
+//{
+// F log_p = 0.0;
+// for (std::vector<PYPs>::iterator levelIt=m_word_pyps.begin();
+// levelIt != m_word_pyps.end(); ++levelIt) {
+// for (PYPs::iterator pypIt=levelIt->begin();
+// pypIt != levelIt->end(); ++pypIt) {
+//
+// pypIt->resample_prior();
+// log_p += pypIt->log_restaurant_prob();
+// }
+// }
+// //std::cerr << "topicworker has answer " << log_p << std::endl; std::cerr.flush();
+//
+// return log_p;
+//}
+
+void PYPTopics::decrement(const Term& term, int topic, int level) {
+ //std::cerr << "PYPTopics::decrement(" << term << "," << topic << "," << level << ")" << std::endl;
+ m_word_pyps.at(level).at(topic).decrement(term);
+ if (m_backoff.get()) {
+ Term backoff_term = (*m_backoff)[term];
+ if (!m_backoff->is_null(backoff_term))
+ decrement(backoff_term, topic, level+1);
+ }
+}
+
+void PYPTopics::increment(const Term& term, int topic, int level) {
+ //std::cerr << "PYPTopics::increment(" << term << "," << topic << "," << level << ")" << std::endl;
+ m_word_pyps.at(level).at(topic).increment(term, word_pyps_p0(term, topic, level));
+
+ if (m_backoff.get()) {
+ Term backoff_term = (*m_backoff)[term];
+ if (!m_backoff->is_null(backoff_term))
+ increment(backoff_term, topic, level+1);
+ }
+}
+
+int PYPTopics::sample(const DocumentId& doc, const Term& term) {
+ // First pass: collect probs
+ F sum=0.0;
+ std::vector<F> sums;
+ for (int k=0; k<m_num_topics; ++k) {
+ F p_w_k = prob(term, k);
+
+ F topic_prob = m_topic_p0;
+ if (m_use_topic_pyp) topic_prob = m_topic_pyp.prob(k, m_topic_p0);
+
+ //F p_k_d = m_document_pyps[doc].prob(k, topic_prob);
+ F p_k_d = m_document_pyps[doc].unnormalised_prob(k, topic_prob);
+
+ sum += (p_w_k*p_k_d);
+ sums.push_back(sum);
+ }
+ // Second pass: sample a topic
+ F cutoff = rnd() * sum;
+ for (int k=0; k<m_num_topics; ++k) {
+ if (cutoff <= sums[k])
+ return k;
+ }
+ assert(false);
+ return -1; // not reached: cutoff < sums.back() by construction
+}
+
+PYPTopics::F PYPTopics::word_pyps_p0(const Term& term, int topic, int level) const {
+ //for (int i=0; i<level+1; ++i) std::cerr << " ";
+ //std::cerr << "PYPTopics::word_pyps_p0(" << term << "," << topic << "," << level << ")" << std::endl;
+
+ F p0 = m_term_p0;
+ if (m_backoff.get()) {
+ //static F fudge=m_backoff_p0; // TODO
+
+ Term backoff_term = (*m_backoff)[term];
+ if (!m_backoff->is_null(backoff_term)) {
+ assert (level < m_backoff->order());
+ p0 = (1.0/(double)m_backoff->terms_at_level(level))*prob(backoff_term, topic, level+1);
+ }
+ else
+ p0 = m_term_p0;
+ }
+ //for (int i=0; i<level+1; ++i) std::cerr << " ";
+ //std::cerr << "PYPTopics::word_pyps_p0(" << term << "," << topic << "," << level << ") = " << p0 << std::endl;
+ return p0;
+}
+
+PYPTopics::F PYPTopics::prob(const Term& term, int topic, int level) const {
+ //for (int i=0; i<level+1; ++i) std::cerr << " ";
+ //std::cerr << "PYPTopics::prob(" << term << "," << topic << "," << level << " " << factor << ")" << std::endl;
+
+ F p0 = word_pyps_p0(term, topic, level);
+ F p_w_k = m_word_pyps.at(level).at(topic).prob(term, p0);
+
+ //for (int i=0; i<level+1; ++i) std::cerr << " ";
+ //std::cerr << "PYPTopics::prob(" << term << "," << topic << "," << level << ") = " << p_w_k << std::endl;
+
+ return p_w_k;
+}
+
+int PYPTopics::max_topic() const {
+ if (!m_use_topic_pyp)
+ return -1;
+
+ F current_max=0.0;
+ int current_topic=-1;
+ for (int k=0; k<m_num_topics; ++k) {
+ F prob = m_topic_pyp.prob(k, m_topic_p0);
+ if (prob > current_max) {
+ current_max = prob;
+ current_topic = k;
+ }
+ }
+ assert(current_topic >= 0);
+ return current_topic;
+}
+
+int PYPTopics::max(const DocumentId& doc) const {
+ //std::cerr << "PYPTopics::max(" << doc << "," << term << ")" << std::endl;
+ // collect probs
+ F current_max=0.0;
+ int current_topic=-1;
+ for (int k=0; k<m_num_topics; ++k) {
+ //F p_w_k = prob(term, k);
+
+ F topic_prob = m_topic_p0;
+ if (m_use_topic_pyp)
+ topic_prob = m_topic_pyp.prob(k, m_topic_p0);
+
+ F prob = 0;
+ if (doc < 0) prob = topic_prob;
+ else prob = m_document_pyps[doc].prob(k, topic_prob);
+
+ if (prob > current_max) {
+ current_max = prob;
+ current_topic = k;
+ }
+ }
+ assert(current_topic >= 0);
+ return current_topic;
+}
+
+int PYPTopics::max(const DocumentId& doc, const Term& term) const {
+ //std::cerr << "PYPTopics::max(" << doc << "," << term << ")" << std::endl;
+ // collect probs
+ F current_max=0.0;
+ int current_topic=-1;
+ for (int k=0; k<m_num_topics; ++k) {
+ F p_w_k = prob(term, k);
+
+ F topic_prob = m_topic_p0;
+ if (m_use_topic_pyp)
+ topic_prob = m_topic_pyp.prob(k, m_topic_p0);
+
+ F p_k_d = 0;
+ if (doc < 0) p_k_d = topic_prob;
+ else p_k_d = m_document_pyps[doc].prob(k, topic_prob);
+
+ F prob = (p_w_k*p_k_d);
+ if (prob > current_max) {
+ current_max = prob;
+ current_topic = k;
+ }
+ }
+ assert(current_topic >= 0);
+ return current_topic;
+}
+
+std::ostream& PYPTopics::print_document_topics(std::ostream& out) const {
+ for (CorpusTopics::const_iterator corpusIt=m_corpus_topics.begin();
+ corpusIt != m_corpus_topics.end(); ++corpusIt) {
+ int term_index=0;
+ for (DocumentTopics::const_iterator docIt=corpusIt->begin();
+ docIt != corpusIt->end(); ++docIt, ++term_index) {
+ if (term_index) out << " ";
+ out << *docIt;
+ }
+ out << std::endl;
+ }
+ return out;
+}
+
+std::ostream& PYPTopics::print_topic_terms(std::ostream& out) const {
+ for (PYPs::const_iterator pypsIt=m_word_pyps.front().begin();
+ pypsIt != m_word_pyps.front().end(); ++pypsIt) {
+ int term_index=0;
+ for (PYP<int>::const_iterator termIt=pypsIt->begin();
+ termIt != pypsIt->end(); ++termIt, ++term_index) {
+ if (term_index) out << " ";
+ out << termIt->first << ":" << termIt->second;
+ }
+ out << std::endl;
+ }
+ return out;
+}
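
For reference, sample() above draws the new topic by inverting the cumulative distribution built in its first pass: with p_w_k from prob() and p_k_d from the document PYP, and K = m_num_topics,

    P(z_{d,i} = k \mid w) \;\propto\; P(w \mid k)\, P(k \mid d),
    \qquad S_k = \sum_{j=0}^{k} P(w \mid j)\, P(j \mid d),
    \qquad k^{*} = \min\{\, k : u\, S_{K-1} \le S_k \,\}, \quad u \sim \mathrm{Uniform}[0,1).

Using unnormalised_prob() for p_k_d is safe here because the document PYP's normaliser (num_customers + b) is constant in k and cancels out of the proportionality.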
diff --git a/gi/pyp-topics/src/mpi-pyp-topics.hh b/gi/pyp-topics/src/mpi-pyp-topics.hh
new file mode 100644
index 00000000..d978c7a1
--- /dev/null
+++ b/gi/pyp-topics/src/mpi-pyp-topics.hh
@@ -0,0 +1,97 @@
+#ifndef MPI_PYP_TOPICS_HH
+#define MPI_PYP_TOPICS_HH
+
+#include <vector>
+#include <iostream>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <boost/random/uniform_real.hpp>
+#include <boost/random/variate_generator.hpp>
+#include <boost/random/mersenne_twister.hpp>
+
+#include "mpi-pyp.hh"
+#include "corpus.hh"
+#include "workers.hh"
+
+class PYPTopics {
+public:
+ typedef std::vector<int> DocumentTopics;
+ typedef std::vector<DocumentTopics> CorpusTopics;
+ typedef double F;
+
+public:
+ PYPTopics(int num_topics, bool use_topic_pyp=false, unsigned long seed = 0,
+ int max_threads = 1)
+ : m_num_topics(num_topics), m_word_pyps(1),
+ m_topic_pyp(0.5,1.0,seed), m_use_topic_pyp(use_topic_pyp),
+ m_seed(seed),
+ uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed),
+ rnd(rng, uni_dist), max_threads(max_threads) {}
+
+ void sample_corpus(const Corpus& corpus, int samples,
+ int freq_cutoff_start=0, int freq_cutoff_end=0,
+ int freq_cutoff_interval=0,
+ int max_contexts_per_document=0);
+
+ int sample(const DocumentId& doc, const Term& term);
+ int max(const DocumentId& doc, const Term& term) const;
+ int max(const DocumentId& doc) const;
+ int max_topic() const;
+
+ void set_backoff(const std::string& filename) {
+ m_backoff.reset(new TermBackoff);
+ m_backoff->read(filename);
+ m_word_pyps.clear();
+ m_word_pyps.resize(m_backoff->order(), PYPs());
+ }
+ void set_backoff(TermBackoffPtr backoff) {
+ m_backoff = backoff;
+ m_word_pyps.clear();
+ m_word_pyps.resize(m_backoff->order(), PYPs());
+ }
+
+ F prob(const Term& term, int topic, int level=0) const;
+ void decrement(const Term& term, int topic, int level=0);
+ void increment(const Term& term, int topic, int level=0);
+
+ std::ostream& print_document_topics(std::ostream& out) const;
+ std::ostream& print_topic_terms(std::ostream& out) const;
+
+private:
+ F word_pyps_p0(const Term& term, int topic, int level) const;
+
+ int m_num_topics;
+ F m_term_p0, m_topic_p0, m_backoff_p0;
+
+ CorpusTopics m_corpus_topics;
+ typedef boost::ptr_vector< PYP<int> > PYPs;
+ PYPs m_document_pyps;
+ std::vector<PYPs> m_word_pyps;
+ PYP<int> m_topic_pyp;
+ bool m_use_topic_pyp;
+
+ unsigned long m_seed;
+
+ typedef boost::mt19937 base_generator_type;
+ typedef boost::uniform_real<> uni_dist_type;
+ typedef boost::variate_generator<base_generator_type&, uni_dist_type> gen_type;
+
+ uni_dist_type uni_dist;
+ base_generator_type rng; //this gets the seed
+ gen_type rnd; //instantiate: rnd(rng, uni_dist)
+ //call: rnd() generates uniform on [0,1)
+
+ typedef boost::function<F()> JobReturnsF;
+ typedef SimpleWorker<JobReturnsF, F> SimpleResampleWorker;
+ typedef boost::ptr_vector<SimpleResampleWorker> WorkerPtrVect;
+
+ F hresample_docs(int num_threads, int thread_id);
+
+// F hresample_topics();
+
+ int max_threads;
+
+ TermBackoffPtr m_backoff;
+};
+
+#endif // MPI_PYP_TOPICS_HH
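
The RNG plumbing above (uni_dist, rng, rnd) follows the usual boost.random pattern: a uniform_real distribution bound to a mersenne_twister engine through a variate_generator. A minimal self-contained sketch of the same wiring, with an arbitrary seed:

    #include <iostream>
    #include <boost/random/mersenne_twister.hpp>
    #include <boost/random/uniform_real.hpp>
    #include <boost/random/variate_generator.hpp>

    int main() {
      boost::mt19937 rng(42u);                   // seeded engine
      boost::uniform_real<> uni_dist(0, 1);      // distribution on [0,1)
      boost::variate_generator<boost::mt19937&, boost::uniform_real<> >
          rnd(rng, uni_dist);                    // bind engine + distribution

      for (int i = 0; i < 3; ++i)
        std::cout << rnd() << std::endl;         // uniform draws on [0,1)
      return 0;
    }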
diff --git a/gi/pyp-topics/src/mpi-pyp.hh b/gi/pyp-topics/src/mpi-pyp.hh
new file mode 100644
index 00000000..dc47244b
--- /dev/null
+++ b/gi/pyp-topics/src/mpi-pyp.hh
@@ -0,0 +1,552 @@
+#ifndef _mpi_pyp_hh
+#define _mpi_pyp_hh
+
+#include <math.h>
+#include <map>
+#include <tr1/unordered_map>
+//#include <google/sparse_hash_map>
+
+#include <boost/random/uniform_real.hpp>
+#include <boost/random/variate_generator.hpp>
+#include <boost/random/mersenne_twister.hpp>
+
+#include "log_add.h"
+#include "slice-sampler.h"
+#include "mt19937ar.h"
+
+//
+// Pitman-Yor process with customer and table tracking
+//
+
+template <typename Dish, typename Hash=std::tr1::hash<Dish> >
+class PYP : protected std::tr1::unordered_map<Dish, int, Hash>
+//class PYP : protected google::sparse_hash_map<Dish, int, Hash>
+{
+public:
+ using std::tr1::unordered_map<Dish,int>::const_iterator;
+ using std::tr1::unordered_map<Dish,int>::iterator;
+ using std::tr1::unordered_map<Dish,int>::begin;
+ using std::tr1::unordered_map<Dish,int>::end;
+// using google::sparse_hash_map<Dish,int>::const_iterator;
+// using google::sparse_hash_map<Dish,int>::iterator;
+// using google::sparse_hash_map<Dish,int>::begin;
+// using google::sparse_hash_map<Dish,int>::end;
+
+ PYP(double a, double b, unsigned long seed = 0, Hash hash=Hash());
+
+ int increment(Dish d, double p0);
+ int decrement(Dish d);
+
+ // lookup functions
+ int count(Dish d) const;
+ double prob(Dish dish, double p0) const;
+ double prob(Dish dish, double dcd, double dca,
+ double dtd, double dta, double p0) const;
+ double unnormalised_prob(Dish dish, double p0) const;
+
+ int num_customers() const { return _total_customers; }
+ int num_types() const { return std::tr1::unordered_map<Dish,int>::size(); }
+ //int num_types() const { return google::sparse_hash_map<Dish,int>::size(); }
+ bool empty() const { return _total_customers == 0; }
+
+ double log_prob(Dish dish, double log_p0) const;
+ // nb. d* are NOT logs
+ double log_prob(Dish dish, double dcd, double dca,
+ double dtd, double dta, double log_p0) const;
+
+ int num_tables(Dish dish) const;
+ int num_tables() const;
+
+ double a() const { return _a; }
+ void set_a(double a) { _a = a; }
+
+ double b() const { return _b; }
+ void set_b(double b) { _b = b; }
+
+ void clear();
+ std::ostream& debug_info(std::ostream& os) const;
+
+ double log_restaurant_prob() const;
+ double log_prior() const;
+ static double log_prior_a(double a, double beta_a, double beta_b);
+ static double log_prior_b(double b, double gamma_c, double gamma_s);
+
+ void resample_prior();
+ void resample_prior_a();
+ void resample_prior_b();
+
+private:
+ double _a, _b; // parameters of the Pitman-Yor distribution
+ double _a_beta_a, _a_beta_b; // parameters of Beta prior on a
+ double _b_gamma_s, _b_gamma_c; // parameters of Gamma prior on b
+
+ struct TableCounter
+ {
+ TableCounter() : tables(0) {};
+ int tables;
+ std::map<int, int> table_histogram; // num customers at table -> number tables
+ };
+ typedef std::tr1::unordered_map<Dish, TableCounter, Hash> DishTableType;
+ //typedef google::sparse_hash_map<Dish, TableCounter, Hash> DishTableType;
+ DishTableType _dish_tables;
+ int _total_customers, _total_tables;
+
+ typedef boost::mt19937 base_generator_type;
+ typedef boost::uniform_real<> uni_dist_type;
+ typedef boost::variate_generator<base_generator_type&, uni_dist_type> gen_type;
+
+// uni_dist_type uni_dist;
+// base_generator_type rng; //this gets the seed
+// gen_type rnd; //instantiate: rnd(rng, uni_dist)
+ //call: rnd() generates uniform on [0,1)
+
+ // Function objects for calculating the parts of the log_prob for
+ // the parameters a and b
+ struct resample_a_type {
+ int n, m; double b, a_beta_a, a_beta_b;
+ const DishTableType& dish_tables;
+ resample_a_type(int n, int m, double b, double a_beta_a,
+ double a_beta_b, const DishTableType& dish_tables)
+ : n(n), m(m), b(b), a_beta_a(a_beta_a), a_beta_b(a_beta_b), dish_tables(dish_tables) {}
+
+ double operator() (double proposed_a) const {
+ double log_prior = log_prior_a(proposed_a, a_beta_a, a_beta_b);
+ double log_prob = 0.0;
+ double lgamma1a = lgamma(1.0 - proposed_a);
+ for (typename DishTableType::const_iterator dish_it=dish_tables.begin(); dish_it != dish_tables.end(); ++dish_it)
+ for (std::map<int, int>::const_iterator table_it=dish_it->second.table_histogram.begin();
+ table_it !=dish_it->second.table_histogram.end(); ++table_it)
+ log_prob += (table_it->second * (lgamma(table_it->first - proposed_a) - lgamma1a));
+
+ log_prob += (proposed_a == 0.0 ? (m-1.0)*log(b)
+ : ((m-1.0)*log(proposed_a) + lgamma((m-1.0) + b/proposed_a) - lgamma(b/proposed_a)));
+ assert(std::isfinite(log_prob));
+ return log_prob + log_prior;
+ }
+ };
+
+ struct resample_b_type {
+ int n, m; double a, b_gamma_c, b_gamma_s;
+ resample_b_type(int n, int m, double a, double b_gamma_c, double b_gamma_s)
+ : n(n), m(m), a(a), b_gamma_c(b_gamma_c), b_gamma_s(b_gamma_s) {}
+
+ double operator() (double proposed_b) const {
+ double log_prior = log_prior_b(proposed_b, b_gamma_c, b_gamma_s);
+ double log_prob = 0.0;
+ log_prob += (a == 0.0 ? (m-1.0)*log(proposed_b)
+ : ((m-1.0)*log(a) + lgamma((m-1.0) + proposed_b/a) - lgamma(proposed_b/a)));
+ log_prob += (lgamma(1.0+proposed_b) - lgamma(n+proposed_b));
+ return log_prob + log_prior;
+ }
+ };
+
+ /* lbetadist() returns the log probability density of x under a Beta(alpha,beta)
+ * distribution. - copied from Mark Johnson's gammadist.c
+ */
+ static long double lbetadist(long double x, long double alpha, long double beta);
+
+ /* lgammadist() returns the log probability density of x under a Gamma(alpha,beta)
+ * distribution - copied from Mark Johnson's gammadist.c
+ */
+ static long double lgammadist(long double x, long double alpha, long double beta);
+
+};
+
+template <typename Dish, typename Hash>
+PYP<Dish,Hash>::PYP(double a, double b, unsigned long seed, Hash)
+: std::tr1::unordered_map<Dish, int, Hash>(10), _a(a), _b(b),
+//: google::sparse_hash_map<Dish, int, Hash>(10), _a(a), _b(b),
+ _a_beta_a(1), _a_beta_b(1), _b_gamma_s(1), _b_gamma_c(1),
+ //_a_beta_a(1), _a_beta_b(1), _b_gamma_s(10), _b_gamma_c(0.1),
+ _total_customers(0), _total_tables(0)//,
+ //uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed), rnd(rng, uni_dist)
+{
+// std::cerr << "\t##PYP<Dish,Hash>::PYP(a=" << _a << ",b=" << _b << ")" << std::endl;
+ //set_deleted_key(-std::numeric_limits<Dish>::max());
+}
+
+template <typename Dish, typename Hash>
+double
+PYP<Dish,Hash>::prob(Dish dish, double p0) const
+{
+ int c = count(dish), t = num_tables(dish);
+ double r = num_tables() * _a + _b;
+ //std::cerr << "\t\t\t\tPYP<Dish,Hash>::prob(" << dish << "," << p0 << ") c=" << c << " r=" << r << std::endl;
+ if (c > 0)
+ return (c - _a * t + r * p0) / (num_customers() + _b);
+ else
+ return r * p0 / (num_customers() + _b);
+}
+
+template <typename Dish, typename Hash>
+double
+PYP<Dish,Hash>::unnormalised_prob(Dish dish, double p0) const
+{
+ int c = count(dish), t = num_tables(dish);
+ double r = num_tables() * _a + _b;
+ if (c > 0) return (c - _a * t + r * p0);
+ else return r * p0;
+}
+
+template <typename Dish, typename Hash>
+double
+PYP<Dish,Hash>::prob(Dish dish, double dcd, double dca,
+ double dtd, double dta, double p0)
+const
+{
+ int c = count(dish) + dcd, t = num_tables(dish) + dtd;
+ double r = (num_tables() + dta) * _a + _b;
+ if (c > 0)
+ return (c - _a * t + r * p0) / (num_customers() + dca + _b);
+ else
+ return r * p0 / (num_customers() + dca + _b);
+}
+
+template <typename Dish, typename Hash>
+double
+PYP<Dish,Hash>::log_prob(Dish dish, double log_p0) const
+{
+ using std::log;
+ int c = count(dish), t = num_tables(dish);
+ double r = log(num_tables() * _a + _b);
+ if (c > 0)
+ return Log<double>::add(log(c - _a * t), r + log_p0)
+ - log(num_customers() + _b);
+ else
+ return r + log_p0 - log(num_customers() + _b);
+}
+
+template <typename Dish, typename Hash>
+double
+PYP<Dish,Hash>::log_prob(Dish dish, double dcd, double dca,
+ double dtd, double dta, double log_p0)
+const
+{
+ using std::log;
+ int c = count(dish) + dcd, t = num_tables(dish) + dtd;
+ double r = log((num_tables() + dta) * _a + _b);
+ if (c > 0)
+ return Log<double>::add(log(c - _a * t), r + log_p0)
+ - log(num_customers() + dca + _b);
+ else
+ return r + log_p0 - log(num_customers() + dca + _b);
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::increment(Dish dish, double p0) {
+ int delta = 0;
+ TableCounter &tc = _dish_tables[dish];
+
+ // seated on a new or existing table?
+ int c = count(dish), t = num_tables(dish), T = num_tables();
+ double pshare = (c > 0) ? (c - _a*t) : 0.0;
+ double pnew = (_b + _a*T) * p0;
+ assert (pshare >= 0.0);
+ //assert (pnew > 0.0);
+
+ //if (rnd() < pnew / (pshare + pnew)) {
+ if (mt_genrand_res53() < pnew / (pshare + pnew)) {
+ // assign to a new table
+ tc.tables += 1;
+ tc.table_histogram[1] += 1;
+ _total_tables += 1;
+ delta = 1;
+ }
+ else {
+ // randomly assign to an existing table
+ // remove constant denominator from inner loop
+ //double r = rnd() * (c - _a*t);
+ double r = mt_genrand_res53() * (c - _a*t);
+ for (std::map<int,int>::iterator
+ hit = tc.table_histogram.begin();
+ hit != tc.table_histogram.end(); ++hit) {
+ r -= ((hit->first - _a) * hit->second);
+ if (r <= 0) {
+ tc.table_histogram[hit->first+1] += 1;
+ hit->second -= 1;
+ if (hit->second == 0)
+ tc.table_histogram.erase(hit);
+ break;
+ }
+ }
+ if (r > 0) {
+ std::cerr << r << " " << c << " " << _a << " " << t << std::endl;
+ assert(false);
+ }
+ delta = 0;
+ }
+
+ std::tr1::unordered_map<Dish,int,Hash>::operator[](dish) += 1;
+ //google::sparse_hash_map<Dish,int,Hash>::operator[](dish) += 1;
+ _total_customers += 1;
+
+ return delta;
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::count(Dish dish) const
+{
+ typename std::tr1::unordered_map<Dish, int>::const_iterator
+ //typename google::sparse_hash_map<Dish, int>::const_iterator
+ dcit = find(dish);
+ if (dcit != end())
+ return dcit->second;
+ else
+ return 0;
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::decrement(Dish dish)
+{
+ typename std::tr1::unordered_map<Dish, int>::iterator dcit = find(dish);
+ //typename google::sparse_hash_map<Dish, int>::iterator dcit = find(dish);
+ if (dcit == end()) {
+ std::cerr << dish << std::endl;
+ assert(false);
+ }
+
+ int delta = 0;
+
+ typename std::tr1::unordered_map<Dish, TableCounter>::iterator dtit = _dish_tables.find(dish);
+ //typename google::sparse_hash_map<Dish, TableCounter>::iterator dtit = _dish_tables.find(dish);
+ if (dtit == _dish_tables.end()) {
+ std::cerr << dish << std::endl;
+ assert(false);
+ }
+ TableCounter &tc = dtit->second;
+
+ //std::cerr << "\tdecrement for " << dish << "\n";
+ //std::cerr << "\tBEFORE histogram: " << tc.table_histogram << " ";
+ //std::cerr << "count: " << count(dish) << " ";
+ //std::cerr << "tables: " << tc.tables << "\n";
+
+ //double r = rnd() * count(dish);
+ double r = mt_genrand_res53() * count(dish);
+ for (std::map<int,int>::iterator hit = tc.table_histogram.begin();
+ hit != tc.table_histogram.end(); ++hit)
+ {
+ //r -= (hit->first - _a) * hit->second;
+ r -= (hit->first) * hit->second;
+ if (r <= 0)
+ {
+ if (hit->first > 1)
+ tc.table_histogram[hit->first-1] += 1;
+ else
+ {
+ delta = -1;
+ tc.tables -= 1;
+ _total_tables -= 1;
+ }
+
+ hit->second -= 1;
+ if (hit->second == 0) tc.table_histogram.erase(hit);
+ break;
+ }
+ }
+ if (r > 0) {
+ std::cerr << r << " " << count(dish) << " " << _a << " " << num_tables(dish) << std::endl;
+ assert(false);
+ }
+
+ // remove the customer
+ dcit->second -= 1;
+ _total_customers -= 1;
+ assert(dcit->second >= 0);
+ if (dcit->second == 0) {
+ erase(dcit);
+ _dish_tables.erase(dtit);
+ //std::cerr << "\tAFTER histogram: Empty\n";
+ }
+ else {
+ //std::cerr << "\tAFTER histogram: " << _dish_tables[dish].table_histogram << " ";
+ //std::cerr << "count: " << count(dish) << " ";
+ //std::cerr << "tables: " << _dish_tables[dish].tables << "\n";
+ }
+
+ return delta;
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::num_tables(Dish dish) const
+{
+ typename std::tr1::unordered_map<Dish, TableCounter, Hash>::const_iterator
+ //typename google::sparse_hash_map<Dish, TableCounter, Hash>::const_iterator
+ dtit = _dish_tables.find(dish);
+
+ //assert(dtit != _dish_tables.end());
+ if (dtit == _dish_tables.end())
+ return 0;
+
+ return dtit->second.tables;
+}
+
+template <typename Dish, typename Hash>
+int
+PYP<Dish,Hash>::num_tables() const
+{
+ return _total_tables;
+}
+
+template <typename Dish, typename Hash>
+std::ostream&
+PYP<Dish,Hash>::debug_info(std::ostream& os) const
+{
+ int hists = 0, tables = 0;
+ for (typename std::tr1::unordered_map<Dish, TableCounter, Hash>::const_iterator
+ //for (typename google::sparse_hash_map<Dish, TableCounter, Hash>::const_iterator
+ dtit = _dish_tables.begin(); dtit != _dish_tables.end(); ++dtit)
+ {
+ hists += dtit->second.table_histogram.size();
+ tables += dtit->second.tables;
+
+ assert(dtit->second.tables > 0);
+ assert(!dtit->second.table_histogram.empty());
+
+ for (std::map<int,int>::const_iterator
+ hit = dtit->second.table_histogram.begin();
+ hit != dtit->second.table_histogram.end(); ++hit)
+ assert(hit->second > 0);
+ }
+
+ os << "restaurant has "
+ << _total_customers << " customers; "
+ << _total_tables << " tables; "
+ << tables << " tables'; "
+ << num_types() << " dishes; "
+ << _dish_tables.size() << " dishes'; and "
+ << hists << " histogram entries\n";
+
+ return os;
+}
+
+template <typename Dish, typename Hash>
+void
+PYP<Dish,Hash>::clear()
+{
+ this->std::tr1::unordered_map<Dish,int,Hash>::clear();
+ //this->google::sparse_hash_map<Dish,int,Hash>::clear();
+ _dish_tables.clear();
+ _total_tables = _total_customers = 0;
+}
+
+// log_restaurant_prob returns the log probability of the PYP table configuration.
+// Excludes Hierarchical P0 term which must be calculated separately.
+template <typename Dish, typename Hash>
+double
+PYP<Dish,Hash>::log_restaurant_prob() const {
+ if (_total_customers < 1)
+ return (double)0.0;
+
+ double log_prob = 0.0;
+ double lgamma1a = lgamma(1.0-_a);
+
+ //std::cerr << "-------------------\n" << std::endl;
+ for (typename DishTableType::const_iterator dish_it=_dish_tables.begin();
+ dish_it != _dish_tables.end(); ++dish_it) {
+ for (std::map<int, int>::const_iterator table_it=dish_it->second.table_histogram.begin();
+ table_it !=dish_it->second.table_histogram.end(); ++table_it) {
+ log_prob += (table_it->second * (lgamma(table_it->first - _a) - lgamma1a));
+ //std::cerr << "|" << dish_it->first->parent << " --> " << dish_it->first->rhs << " " << table_it->first << " " << table_it->second << " " << log_prob;
+ }
+ }
+ //std::cerr << std::endl;
+
+ log_prob += (_a == (double)0.0 ? (_total_tables-1.0)*log(_b) : (_total_tables-1.0)*log(_a) + lgamma((_total_tables-1.0) + _b/_a) - lgamma(_b/_a));
+ //std::cerr << "\t\t" << log_prob << std::endl;
+ log_prob += (lgamma(1.0 + _b) - lgamma(_total_customers + _b));
+
+ //std::cerr << _total_customers << " " << _total_tables << " " << log_prob << " " << log_prior() << std::endl;
+ //std::cerr << _a << " " << _b << std::endl;
+ if (!std::isfinite(log_prob)) {
+ assert(false);
+ }
+ //return log_prob;
+ return log_prob + log_prior();
+}
+
+template <typename Dish, typename Hash>
+double
+PYP<Dish,Hash>::log_prior() const {
+ double prior = 0.0;
+ if (_a_beta_a > 0.0 && _a_beta_b > 0.0 && _a > 0.0)
+ prior += log_prior_a(_a, _a_beta_a, _a_beta_b);
+ if (_b_gamma_s > 0.0 && _b_gamma_c > 0.0)
+ prior += log_prior_b(_b, _b_gamma_c, _b_gamma_s);
+
+ return prior;
+}
+
+template <typename Dish, typename Hash>
+double
+PYP<Dish,Hash>::log_prior_a(double a, double beta_a, double beta_b) {
+ return lbetadist(a, beta_a, beta_b);
+}
+
+template <typename Dish, typename Hash>
+double
+PYP<Dish,Hash>::log_prior_b(double b, double gamma_c, double gamma_s) {
+ return lgammadist(b, gamma_c, gamma_s);
+}
+
+template <typename Dish, typename Hash>
+long double PYP<Dish,Hash>::lbetadist(long double x, long double alpha, long double beta) {
+ assert(x > 0);
+ assert(x < 1);
+ assert(alpha > 0);
+ assert(beta > 0);
+ return (alpha-1)*log(x)+(beta-1)*log(1-x)+lgamma(alpha+beta)-lgamma(alpha)-lgamma(beta);
+//boost::math::lgamma
+}
+
+template <typename Dish, typename Hash>
+long double PYP<Dish,Hash>::lgammadist(long double x, long double alpha, long double beta) {
+ assert(alpha > 0);
+ assert(beta > 0);
+ return (alpha-1)*log(x) - alpha*log(beta) - x/beta - lgamma(alpha);
+}
+
+
+template <typename Dish, typename Hash>
+void
+PYP<Dish,Hash>::resample_prior() {
+ for (int num_its=5; num_its >= 0; --num_its) {
+ resample_prior_b();
+ resample_prior_a();
+ }
+ resample_prior_b();
+}
+
+template <typename Dish, typename Hash>
+void
+PYP<Dish,Hash>::resample_prior_b() {
+ if (_total_tables == 0)
+ return;
+
+ int niterations = 10; // number of resampling iterations
+ //std::cerr << "\n## resample_prior_b(), initial a = " << _a << ", b = " << _b << std::endl;
+ resample_b_type b_log_prob(_total_customers, _total_tables, _a, _b_gamma_c, _b_gamma_s);
+ //_b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits<double>::infinity(),
+ _b = slice_sampler1d(b_log_prob, _b, random, (double) 0.0, std::numeric_limits<double>::infinity(),
+ (double) 0.0, niterations, 100*niterations);
+ //std::cerr << "\n## resample_prior_b(), final a = " << _a << ", b = " << _b << std::endl;
+}
+
+template <typename Dish, typename Hash>
+void
+PYP<Dish,Hash>::resample_prior_a() {
+ if (_total_tables == 0)
+ return;
+
+ int niterations = 10;
+ //std::cerr << "\n## Initial a = " << _a << ", b = " << _b << std::endl;
+ resample_a_type a_log_prob(_total_customers, _total_tables, _b, _a_beta_a, _a_beta_b, _dish_tables);
+ //_a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits<double>::min(),
+ _a = slice_sampler1d(a_log_prob, _a, random, std::numeric_limits<double>::min(),
+ (double) 1.0, (double) 0.0, niterations, 100*niterations);
+}
+
+#endif
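
For reference, prob() above is the standard Pitman-Yor predictive probability under the Chinese-restaurant representation: with c_w customers eating dish w at t_w tables, T tables and n customers in total, discount a, concentration b, and base probability p_0(w),

    P(w \mid \text{seating}) \;=\; \frac{c_w - a\,t_w \;+\; (a\,T + b)\,p_0(w)}{n + b}.

unnormalised_prob() returns just the numerator, and increment() seats a new customer at a fresh table with probability proportional to (a\,T + b)\,p_0(w) against c_w - a\,t_w for joining an existing one.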
diff --git a/gi/pyp-topics/src/mpi-train-contexts.cc b/gi/pyp-topics/src/mpi-train-contexts.cc
new file mode 100644
index 00000000..6309fe93
--- /dev/null
+++ b/gi/pyp-topics/src/mpi-train-contexts.cc
@@ -0,0 +1,169 @@
+// STL
+#include <iostream>
+#include <fstream>
+#include <algorithm>
+#include <iterator>
+
+// Boost
+#include <boost/program_options/parsers.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <boost/scoped_ptr.hpp>
+
+// Local
+#include "mpi-pyp-topics.hh"
+#include "corpus.hh"
+#include "contexts_corpus.hh"
+#include "gzstream.hh"
+
+static const char *REVISION = "$Rev: 170 $";
+
+// Namespaces
+using namespace boost;
+using namespace boost::program_options;
+using namespace std;
+
+int main(int argc, char **argv)
+{
+ cout << "Pitman Yor topic models: Copyright 2010 Phil Blunsom\n";
+ cout << REVISION << '\n' <<endl;
+
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // Command line processing
+ variables_map vm;
+
+ // Command line processing
+ {
+ options_description cmdline_specific("Command line specific options");
+ cmdline_specific.add_options()
+ ("help,h", "print help message")
+ ("config,c", value<string>(), "config file specifying additional command line options")
+ ;
+ options_description config_options("Allowed options");
+ config_options.add_options()
+ ("help,h", "print help message")
+ ("data,d", value<string>(), "file containing the documents and context terms")
+ ("topics,t", value<int>()->default_value(50), "number of topics")
+ ("document-topics-out,o", value<string>(), "file to write the document topics to")
+ ("default-topics-out", value<string>(), "file to write default term topic assignments.")
+ ("topic-words-out,w", value<string>(), "file to write the topic word distribution to")
+ ("samples,s", value<int>()->default_value(10), "number of sampling passes through the data")
+ ("backoff-type", value<string>(), "backoff type: none|simple")
+// ("filter-singleton-contexts", "filter singleton contexts")
+ ("hierarchical-topics", "Use a backoff hierarchical PYP as the P0 for the document topics distribution.")
+ ("freq-cutoff-start", value<int>()->default_value(0), "initial frequency cutoff.")
+ ("freq-cutoff-end", value<int>()->default_value(0), "final frequency cutoff.")
+ ("freq-cutoff-interval", value<int>()->default_value(0), "number of iterations between frequency decrement.")
+ ("max-threads", value<int>()->default_value(1), "maximum number of simultaneous threads allowed")
+ ("max-contexts-per-document", value<int>()->default_value(0), "Only sample the n most frequent contexts for a document.")
+ ;
+
+ cmdline_specific.add(config_options);
+
+ store(parse_command_line(argc, argv, cmdline_specific), vm);
+ notify(vm);
+
+ if (vm.count("config") > 0) {
+ ifstream config(vm["config"].as<string>().c_str());
+ store(parse_config_file(config, config_options), vm);
+ }
+
+ if (vm.count("help")) {
+ cout << cmdline_specific << "\n";
+ return 1;
+ }
+ }
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
+ if (!vm.count("data")) {
+ cerr << "Please specify a file containing the data." << endl;
+ return 1;
+ }
+
+ // seed the random number generator: 0 = automatic, specify value otherwise
+ unsigned long seed = 0;
+ PYPTopics model(vm["topics"].as<int>(), vm.count("hierarchical-topics"), seed, vm["max-threads"].as<int>());
+
+ // read the data
+ BackoffGenerator* backoff_gen=0;
+ if (vm.count("backoff-type")) {
+ if (vm["backoff-type"].as<std::string>() == "none") {
+ backoff_gen = 0;
+ }
+ else if (vm["backoff-type"].as<std::string>() == "simple") {
+ backoff_gen = new SimpleBackoffGenerator();
+ }
+ else {
+ cerr << "Backoff type (--backoff-type) must be one of none|simple." <<endl;
+ return(1);
+ }
+ }
+
+ ContextsCorpus contexts_corpus;
+ contexts_corpus.read_contexts(vm["data"].as<string>(), backoff_gen, /*vm.count("filter-singleton-contexts")*/ false);
+ model.set_backoff(contexts_corpus.backoff_index());
+
+ if (backoff_gen)
+ delete backoff_gen;
+
+ // train the sampler
+ model.sample_corpus(contexts_corpus, vm["samples"].as<int>(),
+ vm["freq-cutoff-start"].as<int>(),
+ vm["freq-cutoff-end"].as<int>(),
+ vm["freq-cutoff-interval"].as<int>(),
+ vm["max-contexts-per-document"].as<int>());
+
+ if (vm.count("document-topics-out")) {
+ ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());
+
+ int document_id=0;
+ map<int,int> all_terms;
+ for (Corpus::const_iterator corpusIt=contexts_corpus.begin();
+ corpusIt != contexts_corpus.end(); ++corpusIt, ++document_id) {
+ vector<int> unique_terms;
+ for (Document::const_iterator docIt=corpusIt->begin();
+ docIt != corpusIt->end(); ++docIt) {
+ if (unique_terms.empty() || *docIt != unique_terms.back())
+ unique_terms.push_back(*docIt);
+ // increment this term's frequency
+ pair<map<int,int>::iterator,bool> insert_result = all_terms.insert(make_pair(*docIt,1));
+ if (!insert_result.second)
+ all_terms[*docIt] = all_terms[*docIt] + 1;
+ //insert_result.first++;
+ }
+ documents_out << contexts_corpus.key(document_id) << '\t';
+ documents_out << model.max(document_id) << " " << corpusIt->size() << " ||| ";
+ for (std::vector<int>::const_iterator termIt=unique_terms.begin();
+ termIt != unique_terms.end(); ++termIt) {
+ if (termIt != unique_terms.begin())
+ documents_out << " ||| ";
+ vector<std::string> strings = contexts_corpus.context2string(*termIt);
+ copy(strings.begin(), strings.end(),ostream_iterator<std::string>(documents_out, " "));
+ documents_out << "||| C=" << model.max(document_id, *termIt);
+
+ }
+ documents_out <<endl;
+ }
+ documents_out.close();
+
+ if (vm.count("default-topics-out")) {
+ ofstream default_topics(vm["default-topics-out"].as<string>().c_str());
+ default_topics << model.max_topic() <<endl;
+ for (std::map<int,int>::const_iterator termIt=all_terms.begin(); termIt != all_terms.end(); ++termIt) {
+ vector<std::string> strings = contexts_corpus.context2string(termIt->first);
+ default_topics << model.max(-1, termIt->first) << " ||| " << termIt->second << " ||| ";
+ copy(strings.begin(), strings.end(),ostream_iterator<std::string>(default_topics, " "));
+ default_topics <<endl;
+ }
+ }
+ }
+
+ if (vm.count("topic-words-out")) {
+ ogzstream topics_out(vm["topic-words-out"].as<string>().c_str());
+ model.print_topic_terms(topics_out);
+ topics_out.close();
+ }
+
+ cout <<endl;
+
+ return 0;
+}
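
A hypothetical invocation of the new binary (the program name comes from Makefile.am above; the flags are the ones registered in main(), and the file names are illustrative only):

    mpi-pyp-contexts-train --data contexts.txt.gz --topics 50 --samples 100 \
        --backoff-type simple --max-contexts-per-document 100 \
        --document-topics-out doc-topics.txt.gz --topic-words-out topic-words.txt.gz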
diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc
index 3614fb22..2cc1fc79 100644
--- a/gi/pyp-topics/src/pyp-topics.cc
+++ b/gi/pyp-topics/src/pyp-topics.cc
@@ -4,7 +4,8 @@
//#include <boost/date_time/posix_time/posix_time_types.hpp>
void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
int freq_cutoff_start, int freq_cutoff_end,
- int freq_cutoff_interval) {
+ int freq_cutoff_interval,
+ int max_contexts_per_document) {
Timer timer;
if (!m_backoff.get()) {
@@ -54,11 +55,12 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
// sample a new_topic
//int new_topic = (topic_counter % m_num_topics);
int freq = corpus.context_count(term);
- int new_topic = (freq > frequency_cutoff ? (document_id % m_num_topics) : -1);
+ int new_topic = -1;
+ if (freq > frequency_cutoff
+ && (!max_contexts_per_document || term_index < max_contexts_per_document)) {
+ new_topic = document_id % m_num_topics;
- // add the new topic to the PYPs
- m_corpus_topics[document_id][term_index] = new_topic;
- if (freq > frequency_cutoff) {
+ // add the new topic to the PYPs
increment(term, new_topic);
if (m_use_topic_pyp) {
@@ -69,6 +71,8 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
}
else m_document_pyps[document_id].increment(new_topic, m_topic_p0);
}
+
+ m_corpus_topics[document_id][term_index] = new_topic;
}
}
std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n";
@@ -94,6 +98,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
// Randomize the corpus indexing array
int tmp;
+ int processed_terms=0;
for (int i = corpus.num_documents()-1; i > 0; --i)
{
//i+1 since j \in [0,i] but rnd() \in [0,1)
@@ -106,8 +111,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
// for each document in the corpus
int document_id;
- for (int i=0; i<corpus.num_documents(); ++i)
- {
+ for (int i=0; i<corpus.num_documents(); ++i) {
document_id = randomDocIndices[i];
// for each term in the document
@@ -115,11 +119,16 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
Document::const_iterator docEnd = corpus.at(document_id).end();
for (Document::const_iterator docIt=corpus.at(document_id).begin();
docIt != docEnd; ++docIt, ++term_index) {
+ if (max_contexts_per_document && term_index >= max_contexts_per_document)
+ break;
+
Term term = *docIt;
int freq = corpus.context_count(term);
if (freq < frequency_cutoff)
continue;
+ processed_terms++;
+
// remove the previous topic from the PYPs
int current_topic = m_corpus_topics[document_id][term_index];
// a negative label means that the term hasn't been sampled yet
@@ -150,6 +159,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
std::cerr << "."; std::cerr.flush();
}
}
+ std::cerr << " ||| sampled " << processed_terms << " terms.";
if (curr_sample != 0 && curr_sample % 10 == 0) {
std::cerr << " ||| time=" << (timer.Elapsed() / 10.0) << " sec/sample" << std::endl;
diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh
index 32d2d939..ebe951b1 100644
--- a/gi/pyp-topics/src/pyp-topics.hh
+++ b/gi/pyp-topics/src/pyp-topics.hh
@@ -30,7 +30,8 @@ public:
void sample_corpus(const Corpus& corpus, int samples,
int freq_cutoff_start=0, int freq_cutoff_end=0,
- int freq_cutoff_interval=0);
+ int freq_cutoff_interval=0,
+ int max_contexts_per_document=0);
int sample(const DocumentId& doc, const Term& term);
std::pair<int,F> max(const DocumentId& doc, const Term& term) const;
diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh
index 7a520d6a..dc47244b 100644
--- a/gi/pyp-topics/src/pyp.hh
+++ b/gi/pyp-topics/src/pyp.hh
@@ -4,6 +4,7 @@
#include <math.h>
#include <map>
#include <tr1/unordered_map>
+//#include <google/sparse_hash_map>
#include <boost/random/uniform_real.hpp>
#include <boost/random/variate_generator.hpp>
@@ -11,6 +12,7 @@
#include "log_add.h"
#include "slice-sampler.h"
+#include "mt19937ar.h"
//
// Pitman-Yor process with customer and table tracking
@@ -18,12 +20,17 @@
template <typename Dish, typename Hash=std::tr1::hash<Dish> >
class PYP : protected std::tr1::unordered_map<Dish, int, Hash>
+//class PYP : protected google::sparse_hash_map<Dish, int, Hash>
{
public:
using std::tr1::unordered_map<Dish,int>::const_iterator;
using std::tr1::unordered_map<Dish,int>::iterator;
using std::tr1::unordered_map<Dish,int>::begin;
using std::tr1::unordered_map<Dish,int>::end;
+// using google::sparse_hash_map<Dish,int>::const_iterator;
+// using google::sparse_hash_map<Dish,int>::iterator;
+// using google::sparse_hash_map<Dish,int>::begin;
+// using google::sparse_hash_map<Dish,int>::end;
PYP(double a, double b, unsigned long seed = 0, Hash hash=Hash());
@@ -39,6 +46,7 @@ public:
int num_customers() const { return _total_customers; }
int num_types() const { return std::tr1::unordered_map<Dish,int>::size(); }
+ //int num_types() const { return google::sparse_hash_map<Dish,int>::size(); }
bool empty() const { return _total_customers == 0; }
double log_prob(Dish dish, double log_p0) const;
@@ -79,6 +87,7 @@ private:
std::map<int, int> table_histogram; // num customers at table -> number tables
};
typedef std::tr1::unordered_map<Dish, TableCounter, Hash> DishTableType;
+ //typedef google::sparse_hash_map<Dish, TableCounter, Hash> DishTableType;
DishTableType _dish_tables;
int _total_customers, _total_tables;
@@ -86,11 +95,10 @@ private:
typedef boost::uniform_real<> uni_dist_type;
typedef boost::variate_generator<base_generator_type&, uni_dist_type> gen_type;
- uni_dist_type uni_dist;
- base_generator_type rng; //this gets the seed
- gen_type rnd; //instantiate: rnd(rng, uni_dist)
+// uni_dist_type uni_dist;
+// base_generator_type rng; //this gets the seed
+// gen_type rnd; //instantiate: rnd(rng, uni_dist)
//call: rnd() generates uniform on [0,1)
-
// Function objects for calculating the parts of the log_prob for
// the parameters a and b
@@ -132,12 +140,12 @@ private:
}
};
- /* lbetadist() returns the log probability density of x under a Beta(alpha,beta)
+ /* lbetadist() returns the log probability density of x under a Beta(alpha,beta)
* distribution. - copied from Mark Johnson's gammadist.c
*/
- static long double lbetadist(long double x, long double alpha, long double beta);
+ static long double lbetadist(long double x, long double alpha, long double beta);
- /* lgammadist() returns the log probability density of x under a Gamma(alpha,beta)
+ /* lgammadist() returns the log probability density of x under a Gamma(alpha,beta)
* distribution - copied from Mark Johnson's gammadist.c
*/
static long double lgammadist(long double x, long double alpha, long double beta);
@@ -146,13 +154,15 @@ private:
template <typename Dish, typename Hash>
PYP<Dish,Hash>::PYP(double a, double b, unsigned long seed, Hash)
-: std::tr1::unordered_map<Dish, int, Hash>(), _a(a), _b(b),
+: std::tr1::unordered_map<Dish, int, Hash>(10), _a(a), _b(b),
+//: google::sparse_hash_map<Dish, int, Hash>(10), _a(a), _b(b),
_a_beta_a(1), _a_beta_b(1), _b_gamma_s(1), _b_gamma_c(1),
//_a_beta_a(1), _a_beta_b(1), _b_gamma_s(10), _b_gamma_c(0.1),
- _total_customers(0), _total_tables(0),
- uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed), rnd(rng, uni_dist)
+ _total_customers(0), _total_tables(0)//,
+ //uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed), rnd(rng, uni_dist)
{
// std::cerr << "\t##PYP<Dish,Hash>::PYP(a=" << _a << ",b=" << _b << ")" << std::endl;
+ //set_deleted_key(-std::numeric_limits<Dish>::max());
}
template <typename Dish, typename Hash>
@@ -235,7 +245,8 @@ PYP<Dish,Hash>::increment(Dish dish, double p0) {
assert (pshare >= 0.0);
//assert (pnew > 0.0);
- if (rnd() < pnew / (pshare + pnew)) {
+ //if (rnd() < pnew / (pshare + pnew)) {
+ if (mt_genrand_res53() < pnew / (pshare + pnew)) {
// assign to a new table
tc.tables += 1;
tc.table_histogram[1] += 1;
@@ -245,7 +256,8 @@ PYP<Dish,Hash>::increment(Dish dish, double p0) {
else {
// randomly assign to an existing table
// remove constant denominator from inner loop
- double r = rnd() * (c - _a*t);
+ //double r = rnd() * (c - _a*t);
+ double r = mt_genrand_res53() * (c - _a*t);
for (std::map<int,int>::iterator
hit = tc.table_histogram.begin();
hit != tc.table_histogram.end(); ++hit) {
@@ -266,6 +278,7 @@ PYP<Dish,Hash>::increment(Dish dish, double p0) {
}
std::tr1::unordered_map<Dish,int,Hash>::operator[](dish) += 1;
+ //google::sparse_hash_map<Dish,int,Hash>::operator[](dish) += 1;
_total_customers += 1;
return delta;
@@ -276,6 +289,7 @@ int
PYP<Dish,Hash>::count(Dish dish) const
{
typename std::tr1::unordered_map<Dish, int>::const_iterator
+ //typename google::sparse_hash_map<Dish, int>::const_iterator
dcit = find(dish);
if (dcit != end())
return dcit->second;
@@ -288,6 +302,7 @@ int
PYP<Dish,Hash>::decrement(Dish dish)
{
typename std::tr1::unordered_map<Dish, int>::iterator dcit = find(dish);
+ //typename google::sparse_hash_map<Dish, int>::iterator dcit = find(dish);
if (dcit == end()) {
std::cerr << dish << std::endl;
assert(false);
@@ -296,6 +311,7 @@ PYP<Dish,Hash>::decrement(Dish dish)
int delta = 0;
typename std::tr1::unordered_map<Dish, TableCounter>::iterator dtit = _dish_tables.find(dish);
+ //typename google::sparse_hash_map<Dish, TableCounter>::iterator dtit = _dish_tables.find(dish);
if (dtit == _dish_tables.end()) {
std::cerr << dish << std::endl;
assert(false);
@@ -307,7 +323,8 @@ PYP<Dish,Hash>::decrement(Dish dish)
//std::cerr << "count: " << count(dish) << " ";
//std::cerr << "tables: " << tc.tables << "\n";
- double r = rnd() * count(dish);
+ //double r = rnd() * count(dish);
+ double r = mt_genrand_res53() * count(dish);
for (std::map<int,int>::iterator hit = tc.table_histogram.begin();
hit != tc.table_histogram.end(); ++hit)
{
@@ -357,6 +374,7 @@ int
PYP<Dish,Hash>::num_tables(Dish dish) const
{
typename std::tr1::unordered_map<Dish, TableCounter, Hash>::const_iterator
+ //typename google::sparse_hash_map<Dish, TableCounter, Hash>::const_iterator
dtit = _dish_tables.find(dish);
//assert(dtit != _dish_tables.end());
@@ -379,6 +397,7 @@ PYP<Dish,Hash>::debug_info(std::ostream& os) const
{
int hists = 0, tables = 0;
for (typename std::tr1::unordered_map<Dish, TableCounter, Hash>::const_iterator
+ //for (typename google::sparse_hash_map<Dish, TableCounter, Hash>::const_iterator
dtit = _dish_tables.begin(); dtit != _dish_tables.end(); ++dtit)
{
hists += dtit->second.table_histogram.size();
@@ -409,6 +428,7 @@ void
PYP<Dish,Hash>::clear()
{
this->std::tr1::unordered_map<Dish,int,Hash>::clear();
+ //this->google::sparse_hash_map<Dish,int,Hash>::clear();
_dish_tables.clear();
_total_tables = _total_customers = 0;
}
@@ -509,7 +529,8 @@ PYP<Dish,Hash>::resample_prior_b() {
int niterations = 10; // number of resampling iterations
//std::cerr << "\n## resample_prior_b(), initial a = " << _a << ", b = " << _b << std::endl;
resample_b_type b_log_prob(_total_customers, _total_tables, _a, _b_gamma_c, _b_gamma_s);
- _b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits<double>::infinity(),
+ //_b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits<double>::infinity(),
+ _b = slice_sampler1d(b_log_prob, _b, random, (double) 0.0, std::numeric_limits<double>::infinity(),
(double) 0.0, niterations, 100*niterations);
//std::cerr << "\n## resample_prior_b(), final a = " << _a << ", b = " << _b << std::endl;
}
@@ -523,7 +544,8 @@ PYP<Dish,Hash>::resample_prior_a() {
int niterations = 10;
//std::cerr << "\n## Initial a = " << _a << ", b = " << _b << std::endl;
resample_a_type a_log_prob(_total_customers, _total_tables, _b, _a_beta_a, _a_beta_b, _dish_tables);
- _a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits<double>::min(),
+ //_a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits<double>::min(),
+ _a = slice_sampler1d(a_log_prob, _a, random, std::numeric_limits<double>::min(),
(double) 1.0, (double) 0.0, niterations, 100*niterations);
}
diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc
index 0a48d3d9..5e98d02f 100644
--- a/gi/pyp-topics/src/train-contexts.cc
+++ b/gi/pyp-topics/src/train-contexts.cc
@@ -54,6 +54,7 @@ int main(int argc, char **argv)
("freq-cutoff-end", value<int>()->default_value(0), "final frequency cutoff.")
("freq-cutoff-interval", value<int>()->default_value(0), "number of iterations between frequency decrement.")
("max-threads", value<int>()->default_value(1), "maximum number of simultaneous threads allowed")
+ ("max-contexts-per-document", value<int>()->default_value(0), "Only sample the n most frequent contexts for a document.")
("num-jobs", value<int>()->default_value(1), "allows finer control over parallelization")
;
@@ -110,7 +111,8 @@ int main(int argc, char **argv)
model.sample_corpus(contexts_corpus, vm["samples"].as<int>(),
vm["freq-cutoff-start"].as<int>(),
vm["freq-cutoff-end"].as<int>(),
- vm["freq-cutoff-interval"].as<int>());
+ vm["freq-cutoff-interval"].as<int>(),
+ vm["max-contexts-per-document"].as<int>());
if (vm.count("document-topics-out")) {
ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());