diff options
Diffstat (limited to 'gi/pyp-topics/src')
| -rw-r--r-- | gi/pyp-topics/src/pyp-topics.cc | 59 | ||||
| -rw-r--r-- | gi/pyp-topics/src/pyp-topics.hh | 6 | ||||
| -rw-r--r-- | gi/pyp-topics/src/train-contexts.cc | 3 | 
3 files changed, 57 insertions, 11 deletions
| diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc index a4ec2463..51511b3a 100644 --- a/gi/pyp-topics/src/pyp-topics.cc +++ b/gi/pyp-topics/src/pyp-topics.cc @@ -60,7 +60,14 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {        // add the new topic to the PYPs        m_corpus_topics[document_id][term_index] = new_topic;        increment(term, new_topic); -      m_document_pyps[document_id].increment(new_topic, m_topic_p0); + +      if (m_use_topic_pyp) { +        F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); +        int table_delta = m_document_pyps[document_id].increment(new_topic, p0); +        if (table_delta) +          m_topic_pyp.increment(new_topic, m_topic_p0); +      } +      else m_document_pyps[document_id].increment(new_topic, m_topic_p0);      }    }    std::cerr << "  Initialized in " << timer.Elapsed() << " seconds\n"; @@ -99,7 +106,10 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {          // remove the prevous topic from the PYPs          int current_topic = m_corpus_topics[document_id][term_index];          decrement(term, current_topic); -        m_document_pyps[document_id].decrement(current_topic); + +        int table_delta = m_document_pyps[document_id].decrement(current_topic); +        if (m_use_topic_pyp && table_delta < 0)  +          m_topic_pyp.decrement(current_topic);          // sample a new_topic          int new_topic = sample(document_id, term); @@ -107,7 +117,14 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {          // add the new topic to the PYPs          m_corpus_topics[document_id][term_index] = new_topic;          increment(term, new_topic); -        m_document_pyps[document_id].increment(new_topic, m_topic_p0); + +        if (m_use_topic_pyp) { +          F p0 = m_topic_pyp.prob(new_topic, m_topic_p0); +          int table_delta = m_document_pyps[document_id].increment(new_topic, p0); +          if (table_delta) +            m_topic_pyp.increment(new_topic, m_topic_p0); +        } +        else m_document_pyps[document_id].increment(new_topic, m_topic_p0);        }        if (document_id && document_id % 10000 == 0) {          std::cerr << "."; std::cerr.flush(); @@ -126,19 +143,35 @@ void PYPTopics::sample(const Corpus& corpus, int samples) {               pypIt != levelIt->end(); ++pypIt) {            pypIt->resample_prior();            log_p += pypIt->log_restaurant_prob(); -          if (resample_counter++ % 100 == 0) { -            std::cerr << "."; std::cerr.flush(); -          }          }        } +      resample_counter=0;        for (PYPs::iterator pypIt=m_document_pyps.begin(); -           pypIt != m_document_pyps.end(); ++pypIt) { +           pypIt != m_document_pyps.end(); ++pypIt, ++resample_counter) {          pypIt->resample_prior();          log_p += pypIt->log_restaurant_prob(); +        if (resample_counter++ % 10000 == 0) { +          std::cerr << "."; std::cerr.flush(); +        } +      } +      if (m_use_topic_pyp) { +        m_topic_pyp.resample_prior(); +        log_p += m_topic_pyp.log_restaurant_prob();        } +        std::cerr << " ||| LLH=" << log_p << " ||| resampling time=" << timer.Elapsed() << " sec" << std::endl;        timer.Reset(); + +      int k=0; +      std::cerr << "Topics distribution: "; +      for (PYPs::iterator pypIt=m_word_pyps.front().begin(); +           pypIt != m_word_pyps.front().end(); ++pypIt, ++k) { +        std::cerr << "<" << k << ":" << pypIt->num_customers() << ","  +          << pypIt->num_types() << "," << m_topic_pyp.count(k) << "> "; +        if (k % 5 == 0) std::cerr << std::endl << '\t'; +      } +      std::cerr << std::endl;      }    }    delete [] randomDocIndices; @@ -171,7 +204,11 @@ int PYPTopics::sample(const DocumentId& doc, const Term& term) {    std::vector<F> sums;    for (int k=0; k<m_num_topics; ++k) {      F p_w_k = prob(term, k); -    F p_k_d = m_document_pyps[doc].prob(k, m_topic_p0); + +    F topic_prob = m_topic_p0; +    if (m_use_topic_pyp) topic_prob = m_topic_pyp.prob(k, m_topic_p0); +    F p_k_d = m_document_pyps[doc].prob(k, topic_prob); +      sum += (p_w_k*p_k_d);      sums.push_back(sum);    } @@ -225,7 +262,11 @@ int PYPTopics::max(const DocumentId& doc, const Term& term) {    int current_topic=-1;    for (int k=0; k<m_num_topics; ++k) {      F p_w_k = prob(term, k); -    F p_k_d = m_document_pyps[doc].prob(k, m_topic_p0); + +    F topic_prob = m_topic_p0; +    if (m_use_topic_pyp) topic_prob = m_topic_pyp.prob(k, m_topic_p0); +    F p_k_d = m_document_pyps[doc].prob(k, topic_prob); +      F prob = (p_w_k*p_k_d);      if (prob > current_max) {        current_max = prob; diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh index 47207d65..db0f7468 100644 --- a/gi/pyp-topics/src/pyp-topics.hh +++ b/gi/pyp-topics/src/pyp-topics.hh @@ -15,7 +15,9 @@ public:    typedef double F;  public: -  PYPTopics(int num_topics) : m_num_topics(num_topics), m_word_pyps(1) {} +  PYPTopics(int num_topics, bool use_topic_pyp=false)  +    : m_num_topics(num_topics), m_word_pyps(1),  +    m_topic_pyp(0.5,1.0), m_use_topic_pyp(use_topic_pyp) {}    void sample(const Corpus& corpus, int samples);    int sample(const DocumentId& doc, const Term& term); @@ -50,6 +52,8 @@ private:    typedef std::vector< PYP<int> > PYPs;    PYPs m_document_pyps;    std::vector<PYPs> m_word_pyps; +  PYP<int> m_topic_pyp; +  bool m_use_topic_pyp;    TermBackoffPtr m_backoff;  }; diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc index 833565cd..02bb7b76 100644 --- a/gi/pyp-topics/src/train-contexts.cc +++ b/gi/pyp-topics/src/train-contexts.cc @@ -44,6 +44,7 @@ int main(int argc, char **argv)        ("samples,s", value<int>()->default_value(10), "number of sampling passes through the data")        ("backoff-type", value<string>(), "backoff type: none|simple")        ("filter-singleton-contexts", "filter singleton contexts") +      ("hierarchical-topics", "Use a backoff hierarchical PYP as the P0 for the document topics distribution.")        ;      store(parse_command_line(argc, argv, cmdline_options), vm);       notify(vm); @@ -63,7 +64,7 @@ int main(int argc, char **argv)    // seed the random number generator    //mt_init_genrand(time(0)); -  PYPTopics model(vm["topics"].as<int>()); +  PYPTopics model(vm["topics"].as<int>(), vm.count("hierarchical-topics"));    // read the data    BackoffGenerator* backoff_gen=0; | 
