summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src')
-rw-r--r--gi/pyp-topics/src/contexts_corpus.cc34
-rw-r--r--gi/pyp-topics/src/contexts_corpus.hh7
-rw-r--r--gi/pyp-topics/src/corpus.hh5
-rw-r--r--gi/pyp-topics/src/pyp-topics.cc53
-rw-r--r--gi/pyp-topics/src/pyp-topics.hh5
-rw-r--r--gi/pyp-topics/src/train-contexts.cc33
-rw-r--r--gi/pyp-topics/src/train.cc2
7 files changed, 102 insertions, 37 deletions
diff --git a/gi/pyp-topics/src/contexts_corpus.cc b/gi/pyp-topics/src/contexts_corpus.cc
index f01d352a..280b2976 100644
--- a/gi/pyp-topics/src/contexts_corpus.cc
+++ b/gi/pyp-topics/src/contexts_corpus.cc
@@ -23,7 +23,7 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void*
ContextsCorpus* corpus_ptr = extra_pair->get<0>();
BackoffGenerator* backoff_gen = extra_pair->get<1>();
- map<string,int>* counts = extra_pair->get<2>();
+ //map<string,int>* counts = extra_pair->get<2>();
Document* doc(new Document());
@@ -33,11 +33,11 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void*
string context_str = corpus_ptr->m_dict.toString(new_contexts.contexts[i]);
// filter out singleton contexts
- if (!counts->empty()) {
- map<string,int>::const_iterator find_it = counts->find(context_str);
- if (find_it == counts->end() || find_it->second < 2)
- continue;
- }
+ //if (!counts->empty()) {
+ // map<string,int>::const_iterator find_it = counts->find(context_str);
+ // if (find_it == counts->end() || find_it->second < 2)
+ // continue;
+ //}
WordID id = corpus_ptr->m_dict.Convert(context_str);
if (cache_word_count != corpus_ptr->m_dict.max()) {
@@ -85,10 +85,10 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void*
}
//cout << endl;
- if (!doc->empty()) {
+ //if (!doc->empty()) {
corpus_ptr->m_documents.push_back(doc);
corpus_ptr->m_keys.push_back(new_contexts.phrase);
- }
+ //}
}
void filter_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* extra) {
@@ -108,10 +108,12 @@ void filter_callback(const ContextsLexer::PhraseContextsType& new_contexts, void
unsigned ContextsCorpus::read_contexts(const string &filename,
BackoffGenerator* backoff_gen_ptr,
- bool filter_singeltons) {
+ bool /*filter_singeltons*/) {
map<string,int> counts;
- if (filter_singeltons) {
- cerr << "--- Filtering singleton contexts ---" << endl;
+ //if (filter_singeltons)
+ {
+ // cerr << "--- Filtering singleton contexts ---" << endl;
+
igzstream in(filename.c_str());
ContextsLexer::ReadContexts(&in, filter_callback, &counts);
}
@@ -128,7 +130,15 @@ unsigned ContextsCorpus::read_contexts(const string &filename,
cerr << "Read backoff with order " << m_backoff->order() << "\n";
for (int o=0; o<m_backoff->order(); o++)
cerr << " Terms at " << o << " = " << m_backoff->terms_at_level(o) << endl;
- cerr << endl;
+ //cerr << endl;
+
+ int i=0; double av_freq=0;
+ for (map<string,int>::const_iterator it=counts.begin(); it != counts.end(); ++it, ++i) {
+ WordID id = m_dict.Convert(it->first);
+ m_context_counts[id] = it->second;
+ av_freq += it->second;
+ }
+ cerr << " Average term frequency = " << av_freq / (double) i << endl;
return m_documents.size();
}
diff --git a/gi/pyp-topics/src/contexts_corpus.hh b/gi/pyp-topics/src/contexts_corpus.hh
index 891e3a6b..66b71783 100644
--- a/gi/pyp-topics/src/contexts_corpus.hh
+++ b/gi/pyp-topics/src/contexts_corpus.hh
@@ -4,6 +4,7 @@
#include <vector>
#include <string>
#include <map>
+#include <tr1/unordered_map>
#include <boost/ptr_container/ptr_vector.hpp>
@@ -66,6 +67,11 @@ public:
return res;
}
+ virtual int context_count(const WordID& id) const {
+ return m_context_counts.find(id)->second;
+ }
+
+
const std::string& key(const int& i) const {
return m_keys.at(i);
}
@@ -74,6 +80,7 @@ private:
TermBackoffPtr m_backoff;
Dict m_dict;
std::vector<std::string> m_keys;
+ std::tr1::unordered_map<int,int> m_context_counts;
};
#endif // _CONTEXTS_CORPUS_HH
diff --git a/gi/pyp-topics/src/corpus.hh b/gi/pyp-topics/src/corpus.hh
index c2f37130..24981946 100644
--- a/gi/pyp-topics/src/corpus.hh
+++ b/gi/pyp-topics/src/corpus.hh
@@ -4,6 +4,7 @@
#include <vector>
#include <string>
#include <map>
+#include <limits>
#include <boost/shared_ptr.hpp>
#include <boost/ptr_container/ptr_vector.hpp>
@@ -35,6 +36,10 @@ public:
int num_terms() const { return m_num_terms; }
int num_types() const { return m_num_types; }
+ virtual int context_count(const int&) const {
+ return std::numeric_limits<int>::max();
+ }
+
protected:
int m_num_terms, m_num_types;
boost::ptr_vector<Document> m_documents;
diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc
index 4fb75caa..0ac1b709 100644
--- a/gi/pyp-topics/src/pyp-topics.cc
+++ b/gi/pyp-topics/src/pyp-topics.cc
@@ -29,7 +29,9 @@ struct Timer {
timespec start_t;
};
-void PYPTopics::sample(const Corpus& corpus, int samples) {
+void PYPTopics::sample_corpus(const Corpus& corpus, int samples,
+ int freq_cutoff_start, int freq_cutoff_end,
+ int freq_cutoff_interval) {
Timer timer;
if (!m_backoff.get()) {
@@ -37,7 +39,7 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
m_word_pyps.push_back(PYPs());
}
- std::cerr << " Training with " << m_word_pyps.size()-1 << " backoff level"
+ std::cerr << "\n Training with " << m_word_pyps.size()-1 << " backoff level"
<< (m_word_pyps.size()==2 ? ":" : "s:") << std::endl;
for (int i=0; i<(int)m_word_pyps.size(); ++i)
@@ -53,6 +55,9 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
std::cerr << " Documents: " << corpus.num_documents() << " Terms: "
<< corpus.num_types() << std::endl;
+ int frequency_cutoff = freq_cutoff_start;
+ std::cerr << " Context frequency cutoff set to " << frequency_cutoff << std::endl;
+
timer.Reset();
// Initialisation pass
int document_id=0, topic_counter=0;
@@ -68,19 +73,22 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
// sample a new_topic
//int new_topic = (topic_counter % m_num_topics);
- int new_topic = (document_id % m_num_topics);
+ int freq = corpus.context_count(term);
+ int new_topic = (freq > frequency_cutoff ? (document_id % m_num_topics) : -1);
// add the new topic to the PYPs
m_corpus_topics[document_id][term_index] = new_topic;
- increment(term, new_topic);
+ if (freq > frequency_cutoff) {
+ increment(term, new_topic);
- if (m_use_topic_pyp) {
- F p0 = m_topic_pyp.prob(new_topic, m_topic_p0);
- int table_delta = m_document_pyps[document_id].increment(new_topic, p0);
- if (table_delta)
- m_topic_pyp.increment(new_topic, m_topic_p0);
+ if (m_use_topic_pyp) {
+ F p0 = m_topic_pyp.prob(new_topic, m_topic_p0);
+ int table_delta = m_document_pyps[document_id].increment(new_topic, p0);
+ if (table_delta)
+ m_topic_pyp.increment(new_topic, m_topic_p0);
+ }
+ else m_document_pyps[document_id].increment(new_topic, m_topic_p0);
}
- else m_document_pyps[document_id].increment(new_topic, m_topic_p0);
}
}
std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n";
@@ -91,6 +99,13 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
// Sampling phase
for (int curr_sample=0; curr_sample < samples; ++curr_sample) {
+ if (freq_cutoff_interval > 0 && curr_sample != 1
+ && curr_sample % freq_cutoff_interval == 1
+ && frequency_cutoff > freq_cutoff_end) {
+ frequency_cutoff--;
+ std::cerr << "\n Context frequency cutoff set to " << frequency_cutoff << std::endl;
+ }
+
std::cerr << "\n -- Sample " << curr_sample << " "; std::cerr.flush();
// Randomize the corpus indexing array
@@ -115,14 +130,20 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
for (Document::const_iterator docIt=corpus.at(document_id).begin();
docIt != docEnd; ++docIt, ++term_index) {
Term term = *docIt;
+ int freq = corpus.context_count(term);
+ if (freq < frequency_cutoff)
+ continue;
// remove the prevous topic from the PYPs
int current_topic = m_corpus_topics[document_id][term_index];
- decrement(term, current_topic);
+ // a negative label mean that term hasn't been sampled yet
+ if (current_topic >= 0) {
+ decrement(term, current_topic);
- int table_delta = m_document_pyps[document_id].decrement(current_topic);
- if (m_use_topic_pyp && table_delta < 0)
- m_topic_pyp.decrement(current_topic);
+ int table_delta = m_document_pyps[document_id].decrement(current_topic);
+ if (m_use_topic_pyp && table_delta < 0)
+ m_topic_pyp.decrement(current_topic);
+ }
// sample a new_topic
int new_topic = sample(document_id, term);
@@ -182,9 +203,9 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
std::cerr.precision(2);
for (PYPs::iterator pypIt=m_word_pyps.front().begin();
pypIt != m_word_pyps.front().end(); ++pypIt, ++k) {
- std::cerr << "<" << k << ":" << pypIt->num_customers() << ","
- << pypIt->num_types() << "," << m_topic_pyp.prob(k, m_topic_p0) << "> ";
if (k % 5 == 0) std::cerr << std::endl << '\t';
+ std::cerr << "<" << k << ":" << pypIt->num_customers() << ","
+ << pypIt->num_types() << "," << m_topic_pyp.prob(k, m_topic_p0) << "> ";
}
std::cerr.precision(4);
std::cerr << std::endl;
diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh
index c35645aa..d4d87440 100644
--- a/gi/pyp-topics/src/pyp-topics.hh
+++ b/gi/pyp-topics/src/pyp-topics.hh
@@ -19,7 +19,10 @@ public:
: m_num_topics(num_topics), m_word_pyps(1),
m_topic_pyp(0.5,1.0), m_use_topic_pyp(use_topic_pyp) {}
- void sample(const Corpus& corpus, int samples);
+ void sample_corpus(const Corpus& corpus, int samples,
+ int freq_cutoff_start=0, int freq_cutoff_end=0,
+ int freq_cutoff_interval=0);
+
int sample(const DocumentId& doc, const Term& term);
int max(const DocumentId& doc, const Term& term) const;
int max(const DocumentId& doc) const;
diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc
index 7e2100f8..481f8926 100644
--- a/gi/pyp-topics/src/train-contexts.cc
+++ b/gi/pyp-topics/src/train-contexts.cc
@@ -34,8 +34,13 @@ int main(int argc, char **argv)
// Command line processing
{
- options_description cmdline_options("Allowed options");
- cmdline_options.add_options()
+ options_description cmdline_specific("Command line specific options");
+ cmdline_specific.add_options()
+ ("help,h", "print help message")
+ ("config,c", value<string>(), "config file specifying additional command line options")
+ ;
+ options_description config_options("Allowed options");
+ config_options.add_options()
("help,h", "print help message")
("data,d", value<string>(), "file containing the documents and context terms")
("topics,t", value<int>()->default_value(50), "number of topics")
@@ -44,14 +49,25 @@ int main(int argc, char **argv)
("topic-words-out,w", value<string>(), "file to write the topic word distribution to")
("samples,s", value<int>()->default_value(10), "number of sampling passes through the data")
("backoff-type", value<string>(), "backoff type: none|simple")
- ("filter-singleton-contexts", "filter singleton contexts")
+// ("filter-singleton-contexts", "filter singleton contexts")
("hierarchical-topics", "Use a backoff hierarchical PYP as the P0 for the document topics distribution.")
+ ("freq-cutoff-start", value<int>()->default_value(0), "initial frequency cutoff.")
+ ("freq-cutoff-end", value<int>()->default_value(0), "final frequency cutoff.")
+ ("freq-cutoff-interval", value<int>()->default_value(0), "number of iterations between frequency decrement.")
;
- store(parse_command_line(argc, argv, cmdline_options), vm);
+
+ cmdline_specific.add(config_options);
+
+ store(parse_command_line(argc, argv, cmdline_specific), vm);
notify(vm);
+ if (vm.count("config") > 0) {
+ ifstream config(vm["config"].as<string>().c_str());
+ store(parse_config_file(config, config_options), vm);
+ }
+
if (vm.count("help")) {
- cout << cmdline_options << "\n";
+ cout << cmdline_specific << "\n";
return 1;
}
}
@@ -83,14 +99,17 @@ int main(int argc, char **argv)
}
ContextsCorpus contexts_corpus;
- contexts_corpus.read_contexts(vm["data"].as<string>(), backoff_gen, vm.count("filter-singleton-contexts"));
+ contexts_corpus.read_contexts(vm["data"].as<string>(), backoff_gen, /*vm.count("filter-singleton-contexts")*/ false);
model.set_backoff(contexts_corpus.backoff_index());
if (backoff_gen)
delete backoff_gen;
// train the sampler
- model.sample(contexts_corpus, vm["samples"].as<int>());
+ model.sample_corpus(contexts_corpus, vm["samples"].as<int>(),
+ vm["freq-cutoff-start"].as<int>(),
+ vm["freq-cutoff-end"].as<int>(),
+ vm["freq-cutoff-interval"].as<int>());
if (vm.count("document-topics-out")) {
ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());
diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc
index f7b01af0..c94010f2 100644
--- a/gi/pyp-topics/src/train.cc
+++ b/gi/pyp-topics/src/train.cc
@@ -83,7 +83,7 @@ int main(int argc, char **argv)
model.set_backoff(vm["backoff-paths"].as<string>());
// train the sampler
- model.sample(corpus, vm["samples"].as<int>());
+ model.sample_corpus(corpus, vm["samples"].as<int>());
if (vm.count("document-topics-out")) {
ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());