diff options
Diffstat (limited to 'gi/pyp-topics/src')
-rw-r--r-- | gi/pyp-topics/src/contexts_corpus.cc | 34 | ||||
-rw-r--r-- | gi/pyp-topics/src/contexts_corpus.hh | 7 | ||||
-rw-r--r-- | gi/pyp-topics/src/corpus.hh | 5 | ||||
-rw-r--r-- | gi/pyp-topics/src/pyp-topics.cc | 53 | ||||
-rw-r--r-- | gi/pyp-topics/src/pyp-topics.hh | 5 | ||||
-rw-r--r-- | gi/pyp-topics/src/train-contexts.cc | 33 | ||||
-rw-r--r-- | gi/pyp-topics/src/train.cc | 2 |
7 files changed, 102 insertions, 37 deletions
diff --git a/gi/pyp-topics/src/contexts_corpus.cc b/gi/pyp-topics/src/contexts_corpus.cc index f01d352a..280b2976 100644 --- a/gi/pyp-topics/src/contexts_corpus.cc +++ b/gi/pyp-topics/src/contexts_corpus.cc @@ -23,7 +23,7 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* ContextsCorpus* corpus_ptr = extra_pair->get<0>(); BackoffGenerator* backoff_gen = extra_pair->get<1>(); - map<string,int>* counts = extra_pair->get<2>(); + //map<string,int>* counts = extra_pair->get<2>(); Document* doc(new Document()); @@ -33,11 +33,11 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* string context_str = corpus_ptr->m_dict.toString(new_contexts.contexts[i]); // filter out singleton contexts - if (!counts->empty()) { - map<string,int>::const_iterator find_it = counts->find(context_str); - if (find_it == counts->end() || find_it->second < 2) - continue; - } + //if (!counts->empty()) { + // map<string,int>::const_iterator find_it = counts->find(context_str); + // if (find_it == counts->end() || find_it->second < 2) + // continue; + //} WordID id = corpus_ptr->m_dict.Convert(context_str); if (cache_word_count != corpus_ptr->m_dict.max()) { @@ -85,10 +85,10 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* } //cout << endl; - if (!doc->empty()) { + //if (!doc->empty()) { corpus_ptr->m_documents.push_back(doc); corpus_ptr->m_keys.push_back(new_contexts.phrase); - } + //} } void filter_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* extra) { @@ -108,10 +108,12 @@ void filter_callback(const ContextsLexer::PhraseContextsType& new_contexts, void unsigned ContextsCorpus::read_contexts(const string &filename, BackoffGenerator* backoff_gen_ptr, - bool filter_singeltons) { + bool /*filter_singeltons*/) { map<string,int> counts; - if (filter_singeltons) { - cerr << "--- Filtering singleton contexts ---" << endl; + //if (filter_singeltons) + { + // cerr << "--- Filtering singleton contexts ---" << endl; + igzstream in(filename.c_str()); ContextsLexer::ReadContexts(&in, filter_callback, &counts); } @@ -128,7 +130,15 @@ unsigned ContextsCorpus::read_contexts(const string &filename, cerr << "Read backoff with order " << m_backoff->order() << "\n"; for (int o=0; o<m_backoff->order(); o++) cerr << " Terms at " << o << " = " << m_backoff->terms_at_level(o) << endl; - cerr << endl; + //cerr << endl; + + int i=0; double av_freq=0; + for (map<string,int>::const_iterator it=counts.begin(); it != counts.end(); ++it, ++i) { + WordID id = m_dict.Convert(it->first); + m_context_counts[id] = it->second; + av_freq += it->second; + } + cerr << " Average term frequency = " << av_freq / (double) i << endl; return m_documents.size(); } diff --git a/gi/pyp-topics/src/contexts_corpus.hh b/gi/pyp-topics/src/contexts_corpus.hh index 891e3a6b..66b71783 100644 --- a/gi/pyp-topics/src/contexts_corpus.hh +++ b/gi/pyp-topics/src/contexts_corpus.hh @@ -4,6 +4,7 @@ #include <vector> #include <string> #include <map> +#include <tr1/unordered_map> #include <boost/ptr_container/ptr_vector.hpp> @@ -66,6 +67,11 @@ public: return res; } + virtual int context_count(const WordID& id) const { + return m_context_counts.find(id)->second; + } + + const std::string& key(const int& i) const { return m_keys.at(i); } @@ -74,6 +80,7 @@ private: TermBackoffPtr m_backoff; Dict m_dict; std::vector<std::string> m_keys; + std::tr1::unordered_map<int,int> m_context_counts; }; #endif // _CONTEXTS_CORPUS_HH diff --git a/gi/pyp-topics/src/corpus.hh b/gi/pyp-topics/src/corpus.hh index c2f37130..24981946 100644 --- a/gi/pyp-topics/src/corpus.hh +++ b/gi/pyp-topics/src/corpus.hh @@ -4,6 +4,7 @@ #include <vector> #include <string> #include <map> +#include <limits> #include <boost/shared_ptr.hpp> #include <boost/ptr_container/ptr_vector.hpp> @@ -35,6 +36,10 @@ public: int num_terms() const { return m_num_terms; } int num_types() const { return m_num_types; } + virtual int context_count(const int&) const { + return std::numeric_limits<int>::max(); + } + protected: int m_num_terms, m_num_types; boost::ptr_vector<Document> m_documents; diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc index 4fb75caa..0ac1b709 100644 --- a/gi/pyp-topics/src/pyp-topics.cc +++ b/gi/pyp-topics/src/pyp-topics.cc @@ -29,7 +29,9 @@ struct Timer { timespec start_t; }; -void PYPTopics::sample(const Corpus& corpus, int samples) { +void PYPTopics::sample_corpus(const Corpus& corpus, int samples, + int freq_cutoff_start, int freq_cutoff_end, + int freq_cutoff_interval) { Timer timer; if (!m_backoff.get()) { @@ -37,7 +39,7 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { m_word_pyps.push_back(PYPs()); } - std::cerr << " Training with " << m_word_pyps.size()-1 << " backoff level" + std::cerr << "\n Training with " << m_word_pyps.size()-1 << " backoff level" << (m_word_pyps.size()==2 ? ":" : "s:") << std::endl; for (int i=0; i<(int)m_word_pyps.size(); ++i) @@ -53,6 +55,9 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { std::cerr << " Documents: " << corpus.num_documents() << " Terms: " << corpus.num_types() << std::endl; + int frequency_cutoff = freq_cutoff_start; + std::cerr << " Context frequency cutoff set to " << frequency_cutoff << std::endl; + timer.Reset(); // Initialisation pass int document_id=0, topic_counter=0; @@ -68,19 +73,22 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { // sample a new_topic //int new_topic = (topic_counter % m_num_topics); - int new_topic = (document_id % m_num_topics); + int freq = corpus.context_count(term); + int new_topic = (freq > frequency_cutoff ? (document_id % m_num_topics) : -1); // add the new topic to the PYPs m_corpus_topics[document_id][term_index] = new_topic; - increment(term, new_topic); + if (freq > frequency_cutoff) { + increment(term, new_topic); - if (m_use_topic_pyp) { - F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); - int table_delta = m_document_pyps[document_id].increment(new_topic, p0); - if (table_delta) - m_topic_pyp.increment(new_topic, m_topic_p0); + if (m_use_topic_pyp) { + F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); + int table_delta = m_document_pyps[document_id].increment(new_topic, p0); + if (table_delta) + m_topic_pyp.increment(new_topic, m_topic_p0); + } + else m_document_pyps[document_id].increment(new_topic, m_topic_p0); } - else m_document_pyps[document_id].increment(new_topic, m_topic_p0); } } std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n"; @@ -91,6 +99,13 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { // Sampling phase for (int curr_sample=0; curr_sample < samples; ++curr_sample) { + if (freq_cutoff_interval > 0 && curr_sample != 1 + && curr_sample % freq_cutoff_interval == 1 + && frequency_cutoff > freq_cutoff_end) { + frequency_cutoff--; + std::cerr << "\n Context frequency cutoff set to " << frequency_cutoff << std::endl; + } + std::cerr << "\n -- Sample " << curr_sample << " "; std::cerr.flush(); // Randomize the corpus indexing array @@ -115,14 +130,20 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { for (Document::const_iterator docIt=corpus.at(document_id).begin(); docIt != docEnd; ++docIt, ++term_index) { Term term = *docIt; + int freq = corpus.context_count(term); + if (freq < frequency_cutoff) + continue; // remove the prevous topic from the PYPs int current_topic = m_corpus_topics[document_id][term_index]; - decrement(term, current_topic); + // a negative label mean that term hasn't been sampled yet + if (current_topic >= 0) { + decrement(term, current_topic); - int table_delta = m_document_pyps[document_id].decrement(current_topic); - if (m_use_topic_pyp && table_delta < 0) - m_topic_pyp.decrement(current_topic); + int table_delta = m_document_pyps[document_id].decrement(current_topic); + if (m_use_topic_pyp && table_delta < 0) + m_topic_pyp.decrement(current_topic); + } // sample a new_topic int new_topic = sample(document_id, term); @@ -182,9 +203,9 @@ void PYPTopics::sample(const Corpus& corpus, int samples) { std::cerr.precision(2); for (PYPs::iterator pypIt=m_word_pyps.front().begin(); pypIt != m_word_pyps.front().end(); ++pypIt, ++k) { - std::cerr << "<" << k << ":" << pypIt->num_customers() << "," - << pypIt->num_types() << "," << m_topic_pyp.prob(k, m_topic_p0) << "> "; if (k % 5 == 0) std::cerr << std::endl << '\t'; + std::cerr << "<" << k << ":" << pypIt->num_customers() << "," + << pypIt->num_types() << "," << m_topic_pyp.prob(k, m_topic_p0) << "> "; } std::cerr.precision(4); std::cerr << std::endl; diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh index c35645aa..d4d87440 100644 --- a/gi/pyp-topics/src/pyp-topics.hh +++ b/gi/pyp-topics/src/pyp-topics.hh @@ -19,7 +19,10 @@ public: : m_num_topics(num_topics), m_word_pyps(1), m_topic_pyp(0.5,1.0), m_use_topic_pyp(use_topic_pyp) {} - void sample(const Corpus& corpus, int samples); + void sample_corpus(const Corpus& corpus, int samples, + int freq_cutoff_start=0, int freq_cutoff_end=0, + int freq_cutoff_interval=0); + int sample(const DocumentId& doc, const Term& term); int max(const DocumentId& doc, const Term& term) const; int max(const DocumentId& doc) const; diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc index 7e2100f8..481f8926 100644 --- a/gi/pyp-topics/src/train-contexts.cc +++ b/gi/pyp-topics/src/train-contexts.cc @@ -34,8 +34,13 @@ int main(int argc, char **argv) // Command line processing { - options_description cmdline_options("Allowed options"); - cmdline_options.add_options() + options_description cmdline_specific("Command line specific options"); + cmdline_specific.add_options() + ("help,h", "print help message") + ("config,c", value<string>(), "config file specifying additional command line options") + ; + options_description config_options("Allowed options"); + config_options.add_options() ("help,h", "print help message") ("data,d", value<string>(), "file containing the documents and context terms") ("topics,t", value<int>()->default_value(50), "number of topics") @@ -44,14 +49,25 @@ int main(int argc, char **argv) ("topic-words-out,w", value<string>(), "file to write the topic word distribution to") ("samples,s", value<int>()->default_value(10), "number of sampling passes through the data") ("backoff-type", value<string>(), "backoff type: none|simple") - ("filter-singleton-contexts", "filter singleton contexts") +// ("filter-singleton-contexts", "filter singleton contexts") ("hierarchical-topics", "Use a backoff hierarchical PYP as the P0 for the document topics distribution.") + ("freq-cutoff-start", value<int>()->default_value(0), "initial frequency cutoff.") + ("freq-cutoff-end", value<int>()->default_value(0), "final frequency cutoff.") + ("freq-cutoff-interval", value<int>()->default_value(0), "number of iterations between frequency decrement.") ; - store(parse_command_line(argc, argv, cmdline_options), vm); + + cmdline_specific.add(config_options); + + store(parse_command_line(argc, argv, cmdline_specific), vm); notify(vm); + if (vm.count("config") > 0) { + ifstream config(vm["config"].as<string>().c_str()); + store(parse_config_file(config, config_options), vm); + } + if (vm.count("help")) { - cout << cmdline_options << "\n"; + cout << cmdline_specific << "\n"; return 1; } } @@ -83,14 +99,17 @@ int main(int argc, char **argv) } ContextsCorpus contexts_corpus; - contexts_corpus.read_contexts(vm["data"].as<string>(), backoff_gen, vm.count("filter-singleton-contexts")); + contexts_corpus.read_contexts(vm["data"].as<string>(), backoff_gen, /*vm.count("filter-singleton-contexts")*/ false); model.set_backoff(contexts_corpus.backoff_index()); if (backoff_gen) delete backoff_gen; // train the sampler - model.sample(contexts_corpus, vm["samples"].as<int>()); + model.sample_corpus(contexts_corpus, vm["samples"].as<int>(), + vm["freq-cutoff-start"].as<int>(), + vm["freq-cutoff-end"].as<int>(), + vm["freq-cutoff-interval"].as<int>()); if (vm.count("document-topics-out")) { ogzstream documents_out(vm["document-topics-out"].as<string>().c_str()); diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc index f7b01af0..c94010f2 100644 --- a/gi/pyp-topics/src/train.cc +++ b/gi/pyp-topics/src/train.cc @@ -83,7 +83,7 @@ int main(int argc, char **argv) model.set_backoff(vm["backoff-paths"].as<string>()); // train the sampler - model.sample(corpus, vm["samples"].as<int>()); + model.sample_corpus(corpus, vm["samples"].as<int>()); if (vm.count("document-topics-out")) { ogzstream documents_out(vm["document-topics-out"].as<string>().c_str()); |