summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorphilblunsom@gmail.com <philblunsom@gmail.com@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-01 04:11:26 +0000
committerphilblunsom@gmail.com <philblunsom@gmail.com@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-01 04:11:26 +0000
commit52c65e78485613b24d84a7d96f4d440c347c2028 (patch)
tree8872865244ab30662af7f3879a69d3833bc3aa1a
parent6ea9bd3aa600bce5224fa97dae79bee6f40699a2 (diff)
Added hierarchical topics.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@87 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--gi/pyp-topics/src/pyp-topics.cc59
-rw-r--r--gi/pyp-topics/src/pyp-topics.hh6
-rw-r--r--gi/pyp-topics/src/train-contexts.cc3
3 files changed, 57 insertions, 11 deletions
diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc
index a4ec2463..51511b3a 100644
--- a/gi/pyp-topics/src/pyp-topics.cc
+++ b/gi/pyp-topics/src/pyp-topics.cc
@@ -60,7 +60,14 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
// add the new topic to the PYPs
m_corpus_topics[document_id][term_index] = new_topic;
increment(term, new_topic);
- m_document_pyps[document_id].increment(new_topic, m_topic_p0);
+
+ if (m_use_topic_pyp) {
+ F p0 = m_topic_pyp.prob(new_topic, m_topic_p0);
+ int table_delta = m_document_pyps[document_id].increment(new_topic, p0);
+ if (table_delta)
+ m_topic_pyp.increment(new_topic, m_topic_p0);
+ }
+ else m_document_pyps[document_id].increment(new_topic, m_topic_p0);
}
}
std::cerr << " Initialized in " << timer.Elapsed() << " seconds\n";
@@ -99,7 +106,10 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
// remove the previous topic from the PYPs
int current_topic = m_corpus_topics[document_id][term_index];
decrement(term, current_topic);
- m_document_pyps[document_id].decrement(current_topic);
+
+ int table_delta = m_document_pyps[document_id].decrement(current_topic);
+ if (m_use_topic_pyp && table_delta < 0)
+ m_topic_pyp.decrement(current_topic);
// sample a new_topic
int new_topic = sample(document_id, term);
@@ -107,7 +117,14 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
// add the new topic to the PYPs
m_corpus_topics[document_id][term_index] = new_topic;
increment(term, new_topic);
- m_document_pyps[document_id].increment(new_topic, m_topic_p0);
+
+ if (m_use_topic_pyp) {
+ F p0 = m_topic_pyp.prob(new_topic, m_topic_p0);
+ int table_delta = m_document_pyps[document_id].increment(new_topic, p0);
+ if (table_delta)
+ m_topic_pyp.increment(new_topic, m_topic_p0);
+ }
+ else m_document_pyps[document_id].increment(new_topic, m_topic_p0);
}
if (document_id && document_id % 10000 == 0) {
std::cerr << "."; std::cerr.flush();
@@ -126,19 +143,35 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {
pypIt != levelIt->end(); ++pypIt) {
pypIt->resample_prior();
log_p += pypIt->log_restaurant_prob();
- if (resample_counter++ % 100 == 0) {
- std::cerr << "."; std::cerr.flush();
- }
}
}
+ resample_counter=0;
for (PYPs::iterator pypIt=m_document_pyps.begin();
- pypIt != m_document_pyps.end(); ++pypIt) {
+ pypIt != m_document_pyps.end(); ++pypIt, ++resample_counter) {
pypIt->resample_prior();
log_p += pypIt->log_restaurant_prob();
+ if (resample_counter++ % 10000 == 0) {
+ std::cerr << "."; std::cerr.flush();
+ }
+ }
+ if (m_use_topic_pyp) {
+ m_topic_pyp.resample_prior();
+ log_p += m_topic_pyp.log_restaurant_prob();
}
+
std::cerr << " ||| LLH=" << log_p << " ||| resampling time=" << timer.Elapsed() << " sec" << std::endl;
timer.Reset();
+
+ int k=0;
+ std::cerr << "Topics distribution: ";
+ for (PYPs::iterator pypIt=m_word_pyps.front().begin();
+ pypIt != m_word_pyps.front().end(); ++pypIt, ++k) {
+ std::cerr << "<" << k << ":" << pypIt->num_customers() << ","
+ << pypIt->num_types() << "," << m_topic_pyp.count(k) << "> ";
+ if (k % 5 == 0) std::cerr << std::endl << '\t';
+ }
+ std::cerr << std::endl;
}
}
delete [] randomDocIndices;
@@ -171,7 +204,11 @@ int PYPTopics::sample(const DocumentId& doc, const Term& term) {
std::vector<F> sums;
for (int k=0; k<m_num_topics; ++k) {
F p_w_k = prob(term, k);
- F p_k_d = m_document_pyps[doc].prob(k, m_topic_p0);
+
+ F topic_prob = m_topic_p0;
+ if (m_use_topic_pyp) topic_prob = m_topic_pyp.prob(k, m_topic_p0);
+ F p_k_d = m_document_pyps[doc].prob(k, topic_prob);
+
sum += (p_w_k*p_k_d);
sums.push_back(sum);
}
@@ -225,7 +262,11 @@ int PYPTopics::max(const DocumentId& doc, const Term& term) {
int current_topic=-1;
for (int k=0; k<m_num_topics; ++k) {
F p_w_k = prob(term, k);
- F p_k_d = m_document_pyps[doc].prob(k, m_topic_p0);
+
+ F topic_prob = m_topic_p0;
+ if (m_use_topic_pyp) topic_prob = m_topic_pyp.prob(k, m_topic_p0);
+ F p_k_d = m_document_pyps[doc].prob(k, topic_prob);
+
F prob = (p_w_k*p_k_d);
if (prob > current_max) {
current_max = prob;
diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh
index 47207d65..db0f7468 100644
--- a/gi/pyp-topics/src/pyp-topics.hh
+++ b/gi/pyp-topics/src/pyp-topics.hh
@@ -15,7 +15,9 @@ public:
typedef double F;
public:
- PYPTopics(int num_topics) : m_num_topics(num_topics), m_word_pyps(1) {}
+ PYPTopics(int num_topics, bool use_topic_pyp=false)
+ : m_num_topics(num_topics), m_word_pyps(1),
+ m_topic_pyp(0.5,1.0), m_use_topic_pyp(use_topic_pyp) {}
void sample(const Corpus& corpus, int samples);
int sample(const DocumentId& doc, const Term& term);
@@ -50,6 +52,8 @@ private:
typedef std::vector< PYP<int> > PYPs;
PYPs m_document_pyps;
std::vector<PYPs> m_word_pyps;
+ PYP<int> m_topic_pyp;
+ bool m_use_topic_pyp;
TermBackoffPtr m_backoff;
};
diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc
index 833565cd..02bb7b76 100644
--- a/gi/pyp-topics/src/train-contexts.cc
+++ b/gi/pyp-topics/src/train-contexts.cc
@@ -44,6 +44,7 @@ int main(int argc, char **argv)
("samples,s", value<int>()->default_value(10), "number of sampling passes through the data")
("backoff-type", value<string>(), "backoff type: none|simple")
("filter-singleton-contexts", "filter singleton contexts")
+ ("hierarchical-topics", "Use a backoff hierarchical PYP as the P0 for the document topics distribution.")
;
store(parse_command_line(argc, argv, cmdline_options), vm);
notify(vm);
@@ -63,7 +64,7 @@ int main(int argc, char **argv)
// seed the random number generator
//mt_init_genrand(time(0));
- PYPTopics model(vm["topics"].as<int>());
+ PYPTopics model(vm["topics"].as<int>(), vm.count("hierarchical-topics"));
// read the data
BackoffGenerator* backoff_gen=0;