Diffstat (limited to 'gi/pyp-topics')
-rw-r--r--  gi/pyp-topics/src/contexts_corpus.hh |   2
-rw-r--r--  gi/pyp-topics/src/pyp-topics.cc      | 103
-rw-r--r--  gi/pyp-topics/src/pyp-topics.hh      |   9
-rw-r--r--  gi/pyp-topics/src/pyp.hh             |   4
-rw-r--r--  gi/pyp-topics/src/train-contexts.cc  |   5
5 files changed, 97 insertions, 26 deletions
diff --git a/gi/pyp-topics/src/contexts_corpus.hh b/gi/pyp-topics/src/contexts_corpus.hh
index b2d235cb..2527f655 100644
--- a/gi/pyp-topics/src/contexts_corpus.hh
+++ b/gi/pyp-topics/src/contexts_corpus.hh
@@ -78,6 +78,8 @@ public:
return m_keys.at(i);
}
+ const Dict& dict() const { return m_dict; }
+
protected:
TermBackoffPtr m_backoff;
Dict m_dict;
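
The only substantive change in contexts_corpus.hh is a read-only accessor for the corpus dictionary; the (commented-out) traces added to pyp-topics.cc below use it to print surface strings for term ids. A minimal sketch of that intended use, assuming Dict::Convert(int) returns a term's string form as elsewhere in this codebase (the helper itself is hypothetical):

    // hypothetical debug helper built on the new dict() accessor
    #include <iostream>
    #include "contexts_corpus.hh"

    void print_term(const ContextsCorpus& corpus, int term) {
      // Dict::Convert(int) maps a term id back to its surface string
      std::cerr << corpus.dict().Convert(term) << " (" << term << ")" << std::endl;
    }
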
diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc
index 4c777f0c..16cc9588 100644
--- a/gi/pyp-topics/src/pyp-topics.cc
+++ b/gi/pyp-topics/src/pyp-topics.cc
@@ -1,12 +1,17 @@
#include "timing.h"
#include "pyp-topics.hh"
+#include "contexts_corpus.hh"
+
+//Dict const *dict;
//#include <boost/date_time/posix_time/posix_time_types.hpp>
void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
int freq_cutoff_start, int freq_cutoff_end,
int freq_cutoff_interval,
- int max_contexts_per_document) {
+ int max_contexts_per_document,
+ F temp_start, F temp_end) {
Timer timer;
+ //dict = &((ContextsCorpus*) &corpus)->dict();
if (!m_backoff.get()) {
m_word_pyps.clear();
@@ -21,16 +26,17 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
{
m_word_pyps.at(i).reserve(m_num_topics);
for (int j=0; j<m_num_topics; ++j)
- m_word_pyps.at(i).push_back(new PYP<int>(0.5, 1.0, m_seed));
+ m_word_pyps.at(i).push_back(new PYP<int>(0.01, 1.0, m_seed));
}
std::cerr << std::endl;
m_document_pyps.reserve(corpus.num_documents());
for (int j=0; j<corpus.num_documents(); ++j)
- m_document_pyps.push_back(new PYP<int>(0.5, 1.0, m_seed));
+ m_document_pyps.push_back(new PYP<int>(0.01, 1.0, m_seed));
m_topic_p0 = 1.0/m_num_topics;
- m_term_p0 = 1.0/corpus.num_types();
+ m_term_p0 = 1.0/(F)m_backoff->terms_at_level(m_word_pyps.size()-1);
+ //m_term_p0 = 1.0/corpus.num_types();
m_backoff_p0 = 1.0/corpus.num_documents();
std::cerr << " Documents: " << corpus.num_documents() << " Terms: "
@@ -58,8 +64,9 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
int new_topic = -1;
if (freq > frequency_cutoff
&& (!max_contexts_per_document || term_index < max_contexts_per_document)) {
- new_topic = sample(document_id, term);
+ //new_topic = sample(document_id, term);
//new_topic = document_id % m_num_topics;
+ new_topic = (int) (rnd() * m_num_topics);
// add the new topic to the PYPs
increment(term, new_topic);
@@ -95,11 +102,13 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
std::cerr << "\n Context frequency cutoff set to " << frequency_cutoff << std::endl;
}
- std::cerr << "\n -- Sample " << curr_sample << " "; std::cerr.flush();
+ F temp = 1.0 / (temp_start - curr_sample*(temp_start-temp_end)/samples);
+ std::cerr << "\n -- Sample " << curr_sample << " (T=" << temp << ") "; std::cerr.flush();
// Randomize the corpus indexing array
int tmp;
int processed_terms=0;
+ /*
for (int i = corpus.num_documents()-1; i > 0; --i)
{
//i+1 since j \in [0,i] but rnd() \in [0,1)
@@ -109,6 +118,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
randomDocIndices[i] = randomDocIndices[j];
randomDocIndices[j] = tmp;
}
+ */
// for each document in the corpus
int document_id;
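
The temp variable introduced above is really an inverse temperature: it rises from 1/temp_start at the first sample to 1/temp_end at the last, linear in temperature, and sample() below raises each topic weight to this power. A standalone sketch of the schedule, with illustrative values (10 -> 1 over 100 samples) that are not taken from this diff:

    // sketch of the annealing schedule used in sample_corpus;
    // temp_start/temp_end values here are illustrative only
    #include <cstdio>

    int main() {
      const long double temp_start = 10.0, temp_end = 1.0;
      const int samples = 100;
      for (int s = 0; s <= samples; s += 25) {
        // same formula as above: anneals 1/temp_start -> 1/temp_end
        long double inv_temp =
            1.0L / (temp_start - s * (temp_start - temp_end) / samples);
        std::printf("sample %3d: inv_temp = %.3Lf\n", s, inv_temp);
      }
      return 0;  // prints 0.100, 0.129, 0.182, 0.308, 1.000
    }
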
@@ -124,6 +134,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
break;
Term term = *docIt;
+
int freq = corpus.context_count(term);
if (freq < frequency_cutoff)
continue;
@@ -142,7 +153,9 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
}
// sample a new_topic
- int new_topic = sample(document_id, term);
+ int new_topic = sample(document_id, term, temp);
+ //std::cerr << "TERM: " << dict->Convert(term) << " (" << term << ") " << " Old Topic: "
+ // << current_topic << " New Topic: " << new_topic << "\n" << std::endl;
// add the new topic to the PYPs
m_corpus_topics[document_id][term_index] = new_topic;
@@ -160,9 +173,10 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
std::cerr << "."; std::cerr.flush();
}
}
- std::cerr << " ||| sampled " << processed_terms << " terms.";
+ std::cerr << " ||| LLH= " << log_likelihood();
if (curr_sample != 0 && curr_sample % 10 == 0) {
+ //if (true) {
std::cerr << " ||| time=" << (timer.Elapsed() / 10.0) << " sec/sample" << std::endl;
timer.Reset();
std::cerr << " ... Resampling hyperparameters (";
@@ -201,12 +215,12 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
}
if (m_use_topic_pyp) {
- m_topic_pyp.resample_prior(rnd);
+ //m_topic_pyp.resample_prior(rnd);
log_p += m_topic_pyp.log_restaurant_prob();
}
std::cerr.precision(10);
- std::cerr << " ||| LLH=" << log_p << " ||| resampling time=" << timer.Elapsed() << " sec" << std::endl;
+ std::cerr << " ||| LLH=" << log_likelihood() << " ||| resampling time=" << timer.Elapsed() << " sec" << std::endl;
timer.Reset();
int k=0;
@@ -218,7 +232,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
std::cerr << "<" << k << ":" << pypIt->num_customers() << ","
<< pypIt->num_types() << "," << m_topic_pyp.prob(k, m_topic_p0) << "> ";
}
- std::cerr.precision(4);
+ std::cerr.precision(10);
std::cerr << std::endl;
}
}
@@ -234,7 +248,7 @@ PYPTopics::F PYPTopics::hresample_docs(int start, int end)
assert(start <= end);
for (int i=start; i < end; ++i)
{
- m_document_pyps[i].resample_prior(rnd);
+ //m_document_pyps[i].resample_prior(rnd);
log_p += m_document_pyps[i].log_restaurant_prob();
if (resample_counter++ % 5000 == 0) {
std::cerr << "."; std::cerr.flush();
@@ -251,13 +265,47 @@ PYPTopics::F PYPTopics::hresample_topics()
for (PYPs::iterator pypIt=levelIt->begin();
pypIt != levelIt->end(); ++pypIt) {
- pypIt->resample_prior(rnd);
+ //pypIt->resample_prior(rnd);
log_p += pypIt->log_restaurant_prob();
}
+ std::cerr << log_p << std::endl;
}
return log_p;
}
+PYPTopics::F PYPTopics::log_likelihood() const
+{
+ F log_p = 0.0;
+
+ // LLH of topic term distribution
+ size_t i=0;
+ for (std::vector<PYPs>::const_iterator levelIt=m_word_pyps.begin();
+ levelIt != m_word_pyps.end(); ++levelIt, ++i) {
+ for (PYPs::const_iterator pypIt=levelIt->begin();
+ pypIt != levelIt->end(); ++pypIt) {
+ log_p += pypIt->log_restaurant_prob();
+
+ if (i == m_word_pyps.size()-1)
+ log_p += (pypIt->num_tables() * -log(m_backoff->terms_at_level(i)));
+ else
+ log_p += (pypIt->num_tables() * log(m_term_p0));
+ }
+ }
+ std::cerr << " TERM LLH: " << log_p << " "; //std::endl;
+
+ // LLH of document topic distribution
+ for (size_t i=0; i < m_document_pyps.size(); ++i) {
+ log_p += m_document_pyps[i].log_restaurant_prob();
+ if (!m_use_topic_pyp) log_p += (m_document_pyps[i].num_tables() * log(m_topic_p0));
+ }
+ if (m_use_topic_pyp) {
+ log_p += m_topic_pyp.log_restaurant_prob();
+ log_p += (m_topic_pyp.num_tables() * log(m_topic_p0));
+ }
+
+ return log_p;
+}
+
void PYPTopics::decrement(const Term& term, int topic, int level) {
//std::cerr << "PYPTopics::decrement(" << term << "," << topic << "," << level << ")" << std::endl;
int table_delta = m_word_pyps.at(level).at(topic).decrement(term);
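
The new log_likelihood() scores the model as a sum over all restaurants: each PYP contributes the log probability of its seating arrangement plus one factor of its base distribution per table, since every table draws its label from the base measure. Schematically, with t_r tables in restaurant r and base measure p_0^{(r)} (a reconstruction from the code above, not a formula stated in the source),

    \log P = \sum_{r \in \text{restaurants}} \left( \log P_r(\text{seating}) + t_r \log p_0^{(r)} \right)

which is why the deepest word level adds -log(terms_at_level(i)), the log of a uniform base over that level's vocabulary, while the document restaurants add log(m_topic_p0) per table.
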
@@ -279,7 +327,7 @@ void PYPTopics::increment(const Term& term, int topic, int level) {
}
}
-int PYPTopics::sample(const DocumentId& doc, const Term& term) {
+int PYPTopics::sample(const DocumentId& doc, const Term& term, F inv_temp) {
// First pass: collect probs
F sum=0.0;
std::vector<F> sums;
@@ -292,7 +340,14 @@ int PYPTopics::sample(const DocumentId& doc, const Term& term) {
//F p_k_d = m_document_pyps[doc].prob(k, topic_prob);
F p_k_d = m_document_pyps[doc].unnormalised_prob(k, topic_prob);
- sum += (p_w_k*p_k_d);
+ F prob = p_w_k*p_k_d;
+ /*
+ if (prob < 0.0) { std::cerr << "\n\n" << prob << " " << p_w_k << " " << p_k_d << std::endl; assert(false); }
+ if (prob > 1.0) { std::cerr << "\n\n" << prob << " " << p_w_k << " " << p_k_d << std::endl; assert(false); }
+ assert (pow(prob, inv_temp) >= 0.0);
+ assert (pow(prob, inv_temp) <= 1.0);
+ */
+ sum += pow(prob, inv_temp);
sums.push_back(sum);
}
// Second pass: sample a topic
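
sample() now exponentiates each unnormalised topic weight by inv_temp before accumulating, so the first pass builds an annealed running CDF and the second pass inverts it with a uniform draw. A self-contained sketch of the two-pass scheme (the function name and the use of std::mt19937 are ours, not the repo's):

    // minimal sketch of annealed two-pass multinomial sampling
    #include <cmath>
    #include <random>
    #include <vector>

    int sample_annealed(const std::vector<long double>& weights,
                        long double inv_temp, std::mt19937& gen) {
      // first pass: running CDF of the annealed weights w^inv_temp;
      // inv_temp < 1 flattens the distribution, inv_temp = 1 leaves it unchanged
      std::vector<long double> sums;
      long double sum = 0.0;
      for (size_t k = 0; k < weights.size(); ++k) {
        sum += std::pow(weights[k], inv_temp);
        sums.push_back(sum);
      }
      // second pass: invert the CDF with a uniform draw in [0, sum)
      std::uniform_real_distribution<double> u(0.0, 1.0);
      long double cutoff = u(gen) * sum;
      for (size_t k = 0; k < sums.size(); ++k)
        if (cutoff <= sums[k]) return (int)k;
      return (int)sums.size() - 1;  // guard against floating-point rounding
    }
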
@@ -313,13 +368,16 @@ PYPTopics::F PYPTopics::word_pyps_p0(const Term& term, int topic, int level) con
//static F fudge=m_backoff_p0; // TODO
Term backoff_term = (*m_backoff)[term];
+ //std::cerr << "T: " << term << " BO: " << backoff_term << std::endl;
if (!m_backoff->is_null(backoff_term)) {
assert (level < m_backoff->order());
- //p0 = (1.0/(double)m_backoff->terms_at_level(level))*prob(backoff_term, topic, level+1);
+ //p0 = (1.0/(F)m_backoff->terms_at_level(level))*prob(backoff_term, topic, level+1);
+ //p0 = m_term_p0*prob(backoff_term, topic, level+1);
p0 = prob(backoff_term, topic, level+1);
}
else
- p0 = m_term_p0;
+ p0 = (1.0/(F) m_backoff->terms_at_level(level));
+ //p0 = m_term_p0;
}
//for (int i=0; i<level+1; ++i) std::cerr << " ";
//std::cerr << "PYPTopics::word_pyps_p0(" << term << "," << topic << "," << level << ") = " << p0 << std::endl;
@@ -328,14 +386,17 @@ PYPTopics::F PYPTopics::word_pyps_p0(const Term& term, int topic, int level) con
PYPTopics::F PYPTopics::prob(const Term& term, int topic, int level) const {
//for (int i=0; i<level+1; ++i) std::cerr << " ";
- //std::cerr << "PYPTopics::prob(" << term << "," << topic << "," << level << " " << factor << ")" << std::endl;
+ //std::cerr << "PYPTopics::prob(" << dict->Convert(term) << "," << topic << "," << level << ")" << std::endl;
F p0 = word_pyps_p0(term, topic, level);
F p_w_k = m_word_pyps.at(level).at(topic).prob(term, p0);
- //for (int i=0; i<level+1; ++i) std::cerr << " ";
- //std::cerr << "PYPTopics::prob(" << term << "," << topic << "," << level << ") = " << p_w_k << std::endl;
-
+ /*
+ for (int i=0; i<level+1; ++i) std::cerr << " ";
+ std::cerr << "PYPTopics::prob(" << dict->Convert(term) << "," << topic << "," << level << ") = " << p_w_k << std::endl;
+ for (int i=0; i<level+1; ++i) std::cerr << " ";
+ m_word_pyps.at(level).at(topic).debug_info(std::cerr);
+ */
return p_w_k;
}
diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh
index ebe951b1..3a910540 100644
--- a/gi/pyp-topics/src/pyp-topics.hh
+++ b/gi/pyp-topics/src/pyp-topics.hh
@@ -17,7 +17,7 @@ class PYPTopics {
public:
typedef std::vector<int> DocumentTopics;
typedef std::vector<DocumentTopics> CorpusTopics;
- typedef double F;
+ typedef long double F;
public:
PYPTopics(int num_topics, bool use_topic_pyp=false, unsigned long seed = 0,
@@ -31,9 +31,10 @@ public:
void sample_corpus(const Corpus& corpus, int samples,
int freq_cutoff_start=0, int freq_cutoff_end=0,
int freq_cutoff_interval=0,
- int max_contexts_per_document=0);
+ int max_contexts_per_document=0,
+ F temp_start=1.0, F temp_end=1.0);
- int sample(const DocumentId& doc, const Term& term);
+ int sample(const DocumentId& doc, const Term& term, F inv_temp=1.0);
std::pair<int,F> max(const DocumentId& doc, const Term& term) const;
std::pair<int,F> max(const DocumentId& doc) const;
int max_topic() const;
@@ -54,6 +55,8 @@ public:
void decrement(const Term& term, int topic, int level=0);
void increment(const Term& term, int topic, int level=0);
+ F log_likelihood() const;
+
std::ostream& print_document_topics(std::ostream& out) const;
std::ostream& print_topic_terms(std::ostream& out) const;
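
The other notable change in this header is widening F from double to long double, presumably for headroom when annealed weights p^inv_temp become very small. A quick illustration of the extra exponent range (on x86/glibc, where long double is 80-bit extended; on ABIs where long double aliases double, both results underflow to 0):

    #include <cmath>
    #include <cstdio>

    int main() {
      double d       = std::pow(1e-200, 2.0);    // 1e-400 underflows double to 0
      long double ld = std::pow(1e-200L, 2.0L);  // still representable in 80-bit extended
      std::printf("double: %g, long double: %Lg\n", d, ld);
      return 0;
    }
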
diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh
index 19cd6be8..b1cb62be 100644
--- a/gi/pyp-topics/src/pyp.hh
+++ b/gi/pyp-topics/src/pyp.hh
@@ -472,7 +472,9 @@ PYP<Dish,Hash>::log_restaurant_prob() const {
assert(false);
}
//return log_prob;
- return log_prob + log_prior();
+ if (log_prob > 0.0)
+ std::cerr << log_prob << std::endl;
+ return log_prob;  // + log_prior();
}
template <typename Dish, typename Hash>
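
With log_prior() dropped, log_restaurant_prob() now returns only the log seating probability. For a Pitman-Yor CRP with discount a, strength \theta, n customers, and T tables holding n_t customers each, that quantity is (the standard identity, assumed here rather than read off the diff)

    \log P(\text{seating}) = \sum_{t=1}^{T-1}\log(\theta + t\,a)
                           - \sum_{i=1}^{n-1}\log(\theta + i)
                           + \sum_{t=1}^{T}\sum_{c=1}^{n_t-1}\log(c - a)

Being a log probability it can never exceed zero, which is what the new if (log_prob > 0.0) trace exploits: any positive value signals a bookkeeping or numerical error.
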
diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc
index da2c2b32..9463f9fc 100644
--- a/gi/pyp-topics/src/train-contexts.cc
+++ b/gi/pyp-topics/src/train-contexts.cc
@@ -55,6 +55,8 @@ int main(int argc, char **argv)
("max-threads", value<int>()->default_value(1), "maximum number of simultaneous threads allowed")
("max-contexts-per-document", value<int>()->default_value(0), "Only sample the n most frequent contexts for a document.")
("num-jobs", value<int>()->default_value(1), "allows finer control over parallelization")
+ ("temp-start", value<double>()->default_value(1.0), "starting annealing temperature.")
+ ("temp-end", value<double>()->default_value(1.0), "end annealing temperature.")
;
cmdline_specific.add(config_options);
@@ -111,7 +113,8 @@ int main(int argc, char **argv)
vm["freq-cutoff-start"].as<int>(),
vm["freq-cutoff-end"].as<int>(),
vm["freq-cutoff-interval"].as<int>(),
- vm["max-contexts-per-document"].as<int>());
+ vm["max-contexts-per-document"].as<int>(),
+ vm["temp-start"].as<double>(), vm["temp-end"].as<double>());
if (vm.count("document-topics-out")) {
ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());
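
With the options wired through to sample_corpus, an annealed run only needs the two new flags on top of an existing command line; the defaults (1.0 and 1.0) reproduce the old un-annealed behaviour. For example (binary name as built from train-contexts.cc; the other required options are elided):

    ./train-contexts ... --temp-start 10 --temp-end 1

This anneals the inverse temperature from 0.1 up to 1.0 over the course of sampling.
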