summaryrefslogtreecommitdiff
path: root/gi/pyp-topics
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics')
-rw-r--r--gi/pyp-topics/src/pyp-topics.cc4
-rw-r--r--gi/pyp-topics/src/pyp.hh11
2 files changed, 14 insertions, 1 deletions
diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc
index d8a8d815..b727458b 100644
--- a/gi/pyp-topics/src/pyp-topics.cc
+++ b/gi/pyp-topics/src/pyp-topics.cc
@@ -215,7 +215,9 @@ int PYPTopics::sample(const DocumentId& doc, const Term& term) {
F topic_prob = m_topic_p0;
if (m_use_topic_pyp) topic_prob = m_topic_pyp.prob(k, m_topic_p0);
- F p_k_d = m_document_pyps[doc].prob(k, topic_prob);
+
+ //F p_k_d = m_document_pyps[doc].prob(k, topic_prob);
+ F p_k_d = m_document_pyps[doc].unnormalised_prob(k, topic_prob);
sum += (p_w_k*p_k_d);
sums.push_back(sum);
diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh
index 85076c98..80c79fe1 100644
--- a/gi/pyp-topics/src/pyp.hh
+++ b/gi/pyp-topics/src/pyp.hh
@@ -33,6 +33,7 @@ public:
double prob(Dish dish, double p0) const;
double prob(Dish dish, double dcd, double dca,
double dtd, double dta, double p0) const;
+ double unnormalised_prob(Dish dish, double p0) const;
int num_customers() const { return _total_customers; }
int num_types() const { return std::tr1::unordered_map<Dish,int>::size(); }
@@ -145,6 +146,16 @@ PYP<Dish,Hash>::prob(Dish dish, double p0) const
template <typename Dish, typename Hash>
double
+PYP<Dish,Hash>::unnormalised_prob(Dish dish, double p0) const
+{
+ int c = count(dish), t = num_tables(dish);
+ double r = num_tables() * _a + _b;
+ if (c > 0) return (c - _a * t + r * p0);
+ else return r * p0;
+}
+
+template <typename Dish, typename Hash>
+double
PYP<Dish,Hash>::prob(Dish dish, double dcd, double dca,
double dtd, double dta, double p0)
const