diff options
Diffstat (limited to 'gi/pyp-topics/src')
| -rw-r--r-- | gi/pyp-topics/src/pyp-topics.cc | 4 | ||||
| -rw-r--r-- | gi/pyp-topics/src/pyp.hh | 11 | 
2 files changed, 14 insertions, 1 deletions
diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc index d8a8d815..b727458b 100644 --- a/gi/pyp-topics/src/pyp-topics.cc +++ b/gi/pyp-topics/src/pyp-topics.cc @@ -215,7 +215,9 @@ int PYPTopics::sample(const DocumentId& doc, const Term& term) {      F topic_prob = m_topic_p0;      if (m_use_topic_pyp) topic_prob = m_topic_pyp.prob(k, m_topic_p0); -    F p_k_d = m_document_pyps[doc].prob(k, topic_prob); + +    //F p_k_d = m_document_pyps[doc].prob(k, topic_prob); +    F p_k_d = m_document_pyps[doc].unnormalised_prob(k, topic_prob);      sum += (p_w_k*p_k_d);      sums.push_back(sum); diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh index 85076c98..80c79fe1 100644 --- a/gi/pyp-topics/src/pyp.hh +++ b/gi/pyp-topics/src/pyp.hh @@ -33,6 +33,7 @@ public:    double prob(Dish dish, double p0) const;    double prob(Dish dish, double dcd, double dca,                      double dtd, double dta, double p0) const; +  double unnormalised_prob(Dish dish, double p0) const;    int num_customers() const { return _total_customers; }    int num_types() const { return std::tr1::unordered_map<Dish,int>::size(); } @@ -145,6 +146,16 @@ PYP<Dish,Hash>::prob(Dish dish, double p0) const  template <typename Dish, typename Hash>  double  +PYP<Dish,Hash>::unnormalised_prob(Dish dish, double p0) const +{ +  int c = count(dish), t = num_tables(dish); +  double r = num_tables() * _a + _b; +  if (c > 0) return (c - _a * t + r * p0); +  else       return r * p0; +} + +template <typename Dish, typename Hash> +double   PYP<Dish,Hash>::prob(Dish dish, double dcd, double dca,                        double dtd, double dta, double p0)  const  | 
