From 52c65e78485613b24d84a7d96f4d440c347c2028 Mon Sep 17 00:00:00 2001 From: "philblunsom@gmail.com" Date: Thu, 1 Jul 2010 04:11:26 +0000 Subject: Added hierarchical topics. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@87 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/pyp-topics/src/pyp-topics.cc | 59 +++++++++++++++++++++++++++++++------ gi/pyp-topics/src/pyp-topics.hh | 6 +++- gi/pyp-topics/src/train-contexts.cc | 3 +- 3 files changed, 57 insertions(+), 11 deletions(-) (limited to 'gi') diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc index a4ec2463..51511b3a 100644 --- a/gi/pyp-topics/src/pyp-topics.cc +++ b/gi/pyp-topics/src/pyp-topics.cc @@ -60,7 +60,14 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { // add the new topic to the PYPs m_corpus_topics[document_id][term_index] = new_topic; increment(term, new_topic); - m_document_pyps[document_id].increment(new_topic, m_topic_p0); + + if (m_use_topic_pyp) { + F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); + int table_delta = m_document_pyps[document_id].increment(new_topic, p0); + if (table_delta) + m_topic_pyp.increment(new_topic, m_topic_p0); + } + else m_document_pyps[document_id].increment(new_topic, m_topic_p0); } } std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n"; @@ -99,7 +106,10 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { // remove the prevous topic from the PYPs int current_topic = m_corpus_topics[document_id][term_index]; decrement(term, current_topic); - m_document_pyps[document_id].decrement(current_topic); + + int table_delta = m_document_pyps[document_id].decrement(current_topic); + if (m_use_topic_pyp && table_delta < 0) + m_topic_pyp.decrement(current_topic); // sample a new_topic int new_topic = sample(document_id, term); @@ -107,7 +117,14 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { // add the new topic to the PYPs m_corpus_topics[document_id][term_index] = new_topic; increment(term, new_topic); - m_document_pyps[document_id].increment(new_topic, m_topic_p0); + + if (m_use_topic_pyp) { + F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); + int table_delta = m_document_pyps[document_id].increment(new_topic, p0); + if (table_delta) + m_topic_pyp.increment(new_topic, m_topic_p0); + } + else m_document_pyps[document_id].increment(new_topic, m_topic_p0); } if (document_id && document_id % 10000 == 0) { std::cerr << "."; std::cerr.flush(); @@ -126,19 +143,35 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { pypIt != levelIt->end(); ++pypIt) { pypIt->resample_prior(); log_p += pypIt->log_restaurant_prob(); - if (resample_counter++ % 100 == 0) { - std::cerr << "."; std::cerr.flush(); - } } } + resample_counter=0; for (PYPs::iterator pypIt=m_document_pyps.begin(); - pypIt != m_document_pyps.end(); ++pypIt) { + pypIt != m_document_pyps.end(); ++pypIt, ++resample_counter) { pypIt->resample_prior(); log_p += pypIt->log_restaurant_prob(); + if (resample_counter++ % 10000 == 0) { + std::cerr << "."; std::cerr.flush(); + } + } + if (m_use_topic_pyp) { + m_topic_pyp.resample_prior(); + log_p += m_topic_pyp.log_restaurant_prob(); } + std::cerr << " ||| LLH=" << log_p << " ||| resampling time=" << timer.Elapsed() << " sec" << std::endl; timer.Reset(); + + int k=0; + std::cerr << "Topics distribution: "; + for (PYPs::iterator pypIt=m_word_pyps.front().begin(); + pypIt != m_word_pyps.front().end(); ++pypIt, ++k) { + std::cerr << "<" << k << ":" << pypIt->num_customers() << "," + << pypIt->num_types() << "," << m_topic_pyp.count(k) << "> "; + if (k % 5 == 0) std::cerr << std::endl << '\t'; + } + std::cerr << std::endl; } } delete [] randomDocIndices; @@ -171,7 +204,11 @@ int PYPTopics::sample(const DocumentId& doc, const Term& term) { std::vector sums; for (int k=0; k current_max) { current_max = prob; diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh index 47207d65..db0f7468 100644 --- a/gi/pyp-topics/src/pyp-topics.hh +++ b/gi/pyp-topics/src/pyp-topics.hh @@ -15,7 +15,9 @@ public: typedef double F; public: - PYPTopics(int num_topics) : m_num_topics(num_topics), m_word_pyps(1) {} + PYPTopics(int num_topics, bool use_topic_pyp=false) + : m_num_topics(num_topics), m_word_pyps(1), + m_topic_pyp(0.5,1.0), m_use_topic_pyp(use_topic_pyp) {} void sample(const Corpus& corpus, int samples); int sample(const DocumentId& doc, const Term& term); @@ -50,6 +52,8 @@ private: typedef std::vector< PYP > PYPs; PYPs m_document_pyps; std::vector m_word_pyps; + PYP m_topic_pyp; + bool m_use_topic_pyp; TermBackoffPtr m_backoff; }; diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc index 833565cd..02bb7b76 100644 --- a/gi/pyp-topics/src/train-contexts.cc +++ b/gi/pyp-topics/src/train-contexts.cc @@ -44,6 +44,7 @@ int main(int argc, char **argv) ("samples,s", value()->default_value(10), "number of sampling passes through the data") ("backoff-type", value(), "backoff type: none|simple") ("filter-singleton-contexts", "filter singleton contexts") + ("hierarchical-topics", "Use a backoff hierarchical PYP as the P0 for the document topics distribution.") ; store(parse_command_line(argc, argv, cmdline_options), vm); notify(vm); @@ -63,7 +64,7 @@ int main(int argc, char **argv) // seed the random number generator //mt_init_genrand(time(0)); - PYPTopics model(vm["topics"].as()); + PYPTopics model(vm["topics"].as(), vm.count("hierarchical-topics")); // read the data BackoffGenerator* backoff_gen=0; -- cgit v1.2.3