diff options
author | philblunsom <philblunsom@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-23 18:19:53 +0000 |
---|---|---|
committer | philblunsom <philblunsom@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-23 18:19:53 +0000 |
commit | 8c2bc3d36e4773844fa829f568b2524345aba3be (patch) | |
tree | 0bdb11f272ab0aa754ae78c584e112da8afc8694 | |
parent | 485344fc8bceacaeec7272347b1eb2923738014b (diff) |
fixed llh and changed to random initialiser.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@614 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | gi/pyp-topics/src/contexts_corpus.hh | 2 | ||||
-rw-r--r-- | gi/pyp-topics/src/pyp-topics.cc | 103 | ||||
-rw-r--r-- | gi/pyp-topics/src/pyp-topics.hh | 9 | ||||
-rw-r--r-- | gi/pyp-topics/src/pyp.hh | 4 | ||||
-rw-r--r-- | gi/pyp-topics/src/train-contexts.cc | 5 |
5 files changed, 97 insertions, 26 deletions
diff --git a/gi/pyp-topics/src/contexts_corpus.hh b/gi/pyp-topics/src/contexts_corpus.hh index b2d235cb..2527f655 100644 --- a/gi/pyp-topics/src/contexts_corpus.hh +++ b/gi/pyp-topics/src/contexts_corpus.hh @@ -78,6 +78,8 @@ public: return m_keys.at(i); } + const Dict& dict() const { return m_dict; } + protected: TermBackoffPtr m_backoff; Dict m_dict; diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc index 4c777f0c..16cc9588 100644 --- a/gi/pyp-topics/src/pyp-topics.cc +++ b/gi/pyp-topics/src/pyp-topics.cc @@ -1,12 +1,17 @@ #include "timing.h" #include "pyp-topics.hh" +#include "contexts_corpus.hh" + +//Dict const *dict; //#include <boost/date_time/posix_time/posix_time_types.hpp> void PYPTopics::sample_corpus(const Corpus& corpus, int samples, int freq_cutoff_start, int freq_cutoff_end, int freq_cutoff_interval, - int max_contexts_per_document) { + int max_contexts_per_document, + F temp_start, F temp_end) { Timer timer; + //dict = &((ContextsCorpus*) &corpus)->dict(); if (!m_backoff.get()) { m_word_pyps.clear(); @@ -21,16 +26,17 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, { m_word_pyps.at(i).reserve(m_num_topics); for (int j=0; j<m_num_topics; ++j) - m_word_pyps.at(i).push_back(new PYP<int>(0.5, 1.0, m_seed)); + m_word_pyps.at(i).push_back(new PYP<int>(0.01, 1.0, m_seed)); } std::cerr << std::endl; m_document_pyps.reserve(corpus.num_documents()); for (int j=0; j<corpus.num_documents(); ++j) - m_document_pyps.push_back(new PYP<int>(0.5, 1.0, m_seed)); + m_document_pyps.push_back(new PYP<int>(0.01, 1.0, m_seed)); m_topic_p0 = 1.0/m_num_topics; - m_term_p0 = 1.0/corpus.num_types(); + m_term_p0 = 1.0/(F)m_backoff->terms_at_level(m_word_pyps.size()-1); + //m_term_p0 = 1.0/corpus.num_types(); m_backoff_p0 = 1.0/corpus.num_documents(); std::cerr << " Documents: " << corpus.num_documents() << " Terms: " @@ -58,8 +64,9 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, int new_topic = -1; if (freq > frequency_cutoff && (!max_contexts_per_document || term_index < max_contexts_per_document)) { - new_topic = sample(document_id, term); + //new_topic = sample(document_id, term); //new_topic = document_id % m_num_topics; + new_topic = (int) (rnd() * m_num_topics); // add the new topic to the PYPs increment(term, new_topic); @@ -95,11 +102,13 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, std::cerr << "\n Context frequency cutoff set to " << frequency_cutoff << std::endl; } - std::cerr << "\n -- Sample " << curr_sample << " "; std::cerr.flush(); + F temp = 1.0 / (temp_start - curr_sample*(temp_start-temp_end)/samples); + std::cerr << "\n -- Sample " << curr_sample << " (T=" << temp << ") "; std::cerr.flush(); // Randomize the corpus indexing array int tmp; int processed_terms=0; + /* for (int i = corpus.num_documents()-1; i > 0; --i) { //i+1 since j \in [0,i] but rnd() \in [0,1) @@ -109,6 +118,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, randomDocIndices[i] = randomDocIndices[j]; randomDocIndices[j] = tmp; } + */ // for each document in the corpus int document_id; @@ -124,6 +134,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, break; Term term = *docIt; + int freq = corpus.context_count(term); if (freq < frequency_cutoff) continue; @@ -142,7 +153,9 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, } // sample a new_topic - int new_topic = sample(document_id, term); + int new_topic = sample(document_id, term, temp); + //std::cerr << "TERM: " << dict->Convert(term) << " (" << term << ") " << " Old Topic: " + // << current_topic << " New Topic: " << new_topic << "\n" << std::endl; // add the new topic to the PYPs m_corpus_topics[document_id][term_index] = new_topic; @@ -160,9 +173,10 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, std::cerr << "."; std::cerr.flush(); } } - std::cerr << " ||| sampled " << processed_terms << " terms."; + std::cerr << " ||| LLH= " << log_likelihood(); if (curr_sample != 0 && curr_sample % 10 == 0) { + //if (true) { std::cerr << " ||| time=" << (timer.Elapsed() / 10.0) << " sec/sample" << std::endl; timer.Reset(); std::cerr << " ... Resampling hyperparameters ("; @@ -201,12 +215,12 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, } if (m_use_topic_pyp) { - m_topic_pyp.resample_prior(rnd); + //m_topic_pyp.resample_prior(rnd); log_p += m_topic_pyp.log_restaurant_prob(); } std::cerr.precision(10); - std::cerr << " ||| LLH=" << log_p << " ||| resampling time=" << timer.Elapsed() << " sec" << std::endl; + std::cerr << " ||| LLH=" << log_likelihood() << " ||| resampling time=" << timer.Elapsed() << " sec" << std::endl; timer.Reset(); int k=0; @@ -218,7 +232,7 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, std::cerr << "<" << k << ":" << pypIt->num_customers() << "," << pypIt->num_types() << "," << m_topic_pyp.prob(k, m_topic_p0) << "> "; } - std::cerr.precision(4); + std::cerr.precision(10); std::cerr << std::endl; } } @@ -234,7 +248,7 @@ PYPTopics::F PYPTopics::hresample_docs(int start, int end) assert(start <= end); for (int i=start; i < end; ++i) { - m_document_pyps[i].resample_prior(rnd); + //m_document_pyps[i].resample_prior(rnd); log_p += m_document_pyps[i].log_restaurant_prob(); if (resample_counter++ % 5000 == 0) { std::cerr << "."; std::cerr.flush(); @@ -251,13 +265,47 @@ PYPTopics::F PYPTopics::hresample_topics() for (PYPs::iterator pypIt=levelIt->begin(); pypIt != levelIt->end(); ++pypIt) { - pypIt->resample_prior(rnd); + //pypIt->resample_prior(rnd); log_p += pypIt->log_restaurant_prob(); } + std::cerr << log_p << std::endl; } return log_p; } +PYPTopics::F PYPTopics::log_likelihood() const +{ + F log_p = 0.0; + + // LLH of topic term distribution + size_t i=0; + for (std::vector<PYPs>::const_iterator levelIt=m_word_pyps.begin(); + levelIt != m_word_pyps.end(); ++levelIt, ++i) { + for (PYPs::const_iterator pypIt=levelIt->begin(); + pypIt != levelIt->end(); ++pypIt, ++i) { + log_p += pypIt->log_restaurant_prob(); + + if (i == m_word_pyps.size()-1) + log_p += (pypIt->num_tables() * -log(m_backoff->terms_at_level(i))); + else + log_p += (pypIt->num_tables() * log(m_term_p0)); + } + } + std::cerr << " TERM LLH: " << log_p << " "; //std::endl; + + // LLH of document topic distribution + for (size_t i=0; i < m_document_pyps.size(); ++i) { + log_p += m_document_pyps[i].log_restaurant_prob(); + if (!m_use_topic_pyp) log_p += (m_document_pyps[i].num_tables() * m_topic_p0); + } + if (m_use_topic_pyp) { + log_p += m_topic_pyp.log_restaurant_prob(); + log_p += (m_topic_pyp.num_tables() * log(m_topic_p0)); + } + + return log_p; +} + void PYPTopics::decrement(const Term& term, int topic, int level) { //std::cerr << "PYPTopics::decrement(" << term << "," << topic << "," << level << ")" << std::endl; int table_delta = m_word_pyps.at(level).at(topic).decrement(term); @@ -279,7 +327,7 @@ void PYPTopics::increment(const Term& term, int topic, int level) { } } -int PYPTopics::sample(const DocumentId& doc, const Term& term) { +int PYPTopics::sample(const DocumentId& doc, const Term& term, F inv_temp) { // First pass: collect probs F sum=0.0; std::vector<F> sums; @@ -292,7 +340,14 @@ int PYPTopics::sample(const DocumentId& doc, const Term& term) { //F p_k_d = m_document_pyps[doc].prob(k, topic_prob); F p_k_d = m_document_pyps[doc].unnormalised_prob(k, topic_prob); - sum += (p_w_k*p_k_d); + F prob = p_w_k*p_k_d; + /* + if (prob < 0.0) { std::cerr << "\n\n" << prob << " " << p_w_k << " " << p_k_d << std::endl; assert(false); } + if (prob > 1.0) { std::cerr << "\n\n" << prob << " " << p_w_k << " " << p_k_d << std::endl; assert(false); } + assert (pow(prob, inv_temp) >= 0.0); + assert (pow(prob, inv_temp) <= 1.0); + */ + sum += pow(prob, inv_temp); sums.push_back(sum); } // Second pass: sample a topic @@ -313,13 +368,16 @@ PYPTopics::F PYPTopics::word_pyps_p0(const Term& term, int topic, int level) con //static F fudge=m_backoff_p0; // TODO Term backoff_term = (*m_backoff)[term]; + //std::cerr << "T: " << term << " BO: " << backoff_term << std::endl; if (!m_backoff->is_null(backoff_term)) { assert (level < m_backoff->order()); - //p0 = (1.0/(double)m_backoff->terms_at_level(level))*prob(backoff_term, topic, level+1); + //p0 = (1.0/(F)m_backoff->terms_at_level(level))*prob(backoff_term, topic, level+1); + p0 = m_term_p0*prob(backoff_term, topic, level+1); p0 = prob(backoff_term, topic, level+1); } else - p0 = m_term_p0; + p0 = (1.0/(F) m_backoff->terms_at_level(level)); + //p0 = m_term_p0; } //for (int i=0; i<level+1; ++i) std::cerr << " "; //std::cerr << "PYPTopics::word_pyps_p0(" << term << "," << topic << "," << level << ") = " << p0 << std::endl; @@ -328,14 +386,17 @@ PYPTopics::F PYPTopics::word_pyps_p0(const Term& term, int topic, int level) con PYPTopics::F PYPTopics::prob(const Term& term, int topic, int level) const { //for (int i=0; i<level+1; ++i) std::cerr << " "; - //std::cerr << "PYPTopics::prob(" << term << "," << topic << "," << level << " " << factor << ")" << std::endl; + //std::cerr << "PYPTopics::prob(" << dict->Convert(term) << "," << topic << "," << level << ")" << std::endl; F p0 = word_pyps_p0(term, topic, level); F p_w_k = m_word_pyps.at(level).at(topic).prob(term, p0); - //for (int i=0; i<level+1; ++i) std::cerr << " "; - //std::cerr << "PYPTopics::prob(" << term << "," << topic << "," << level << ") = " << p_w_k << std::endl; - + /* + for (int i=0; i<level+1; ++i) std::cerr << " "; + std::cerr << "PYPTopics::prob(" << dict->Convert(term) << "," << topic << "," << level << ") = " << p_w_k << std::endl; + for (int i=0; i<level+1; ++i) std::cerr << " "; + m_word_pyps.at(level).at(topic).debug_info(std::cerr); + */ return p_w_k; } diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh index ebe951b1..3a910540 100644 --- a/gi/pyp-topics/src/pyp-topics.hh +++ b/gi/pyp-topics/src/pyp-topics.hh @@ -17,7 +17,7 @@ class PYPTopics { public: typedef std::vector<int> DocumentTopics; typedef std::vector<DocumentTopics> CorpusTopics; - typedef double F; + typedef long double F; public: PYPTopics(int num_topics, bool use_topic_pyp=false, unsigned long seed = 0, @@ -31,9 +31,10 @@ public: void sample_corpus(const Corpus& corpus, int samples, int freq_cutoff_start=0, int freq_cutoff_end=0, int freq_cutoff_interval=0, - int max_contexts_per_document=0); + int max_contexts_per_document=0, + F temp_start=1.0, F temp_end=1.0); - int sample(const DocumentId& doc, const Term& term); + int sample(const DocumentId& doc, const Term& term, F inv_temp=1.0); std::pair<int,F> max(const DocumentId& doc, const Term& term) const; std::pair<int,F> max(const DocumentId& doc) const; int max_topic() const; @@ -54,6 +55,8 @@ public: void decrement(const Term& term, int topic, int level=0); void increment(const Term& term, int topic, int level=0); + F log_likelihood() const; + std::ostream& print_document_topics(std::ostream& out) const; std::ostream& print_topic_terms(std::ostream& out) const; diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh index 19cd6be8..b1cb62be 100644 --- a/gi/pyp-topics/src/pyp.hh +++ b/gi/pyp-topics/src/pyp.hh @@ -472,7 +472,9 @@ PYP<Dish,Hash>::log_restaurant_prob() const { assert(false); } //return log_prob; - return log_prob + log_prior(); + if (log_prob > 0.0) + std::cerr << log_prob << std::endl; + return log_prob;// + log_prior(); } template <typename Dish, typename Hash> diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc index da2c2b32..9463f9fc 100644 --- a/gi/pyp-topics/src/train-contexts.cc +++ b/gi/pyp-topics/src/train-contexts.cc @@ -55,6 +55,8 @@ int main(int argc, char **argv) ("max-threads", value<int>()->default_value(1), "maximum number of simultaneous threads allowed") ("max-contexts-per-document", value<int>()->default_value(0), "Only sample the n most frequent contexts for a document.") ("num-jobs", value<int>()->default_value(1), "allows finer control over parallelization") + ("temp-start", value<double>()->default_value(1.0), "starting annealing temperature.") + ("temp-end", value<double>()->default_value(1.0), "end annealing temperature.") ; cmdline_specific.add(config_options); @@ -111,7 +113,8 @@ int main(int argc, char **argv) vm["freq-cutoff-start"].as<int>(), vm["freq-cutoff-end"].as<int>(), vm["freq-cutoff-interval"].as<int>(), - vm["max-contexts-per-document"].as<int>()); + vm["max-contexts-per-document"].as<int>(), + vm["temp-start"].as<double>(), vm["temp-end"].as<double>()); if (vm.count("document-topics-out")) { ogzstream documents_out(vm["document-topics-out"].as<string>().c_str()); |