diff options
Diffstat (limited to 'gi/clda')
| -rw-r--r-- | gi/clda/src/ccrp.h | 55 | ||||
| -rw-r--r-- | gi/clda/src/clda.cc | 11 | 
2 files changed, 46 insertions, 20 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h index e978225d..47f364f2 100644 --- a/gi/clda/src/ccrp.h +++ b/gi/clda/src/ccrp.h @@ -18,16 +18,22 @@ class CCRP {   public:    CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {} +  unsigned num_tables(const Dish& dish) const { +    const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); +    if (it == dish_locs_.end()) return 0; +    return it->second.table_counts_.size(); +  } +    int increment(const Dish& dish, const double& p0, MT19937* rng) {      DishLocations& loc = dish_locs_[dish];      bool share_table = false; -    if (loc.dish_count_) { +    if (loc.total_dish_count_) {        const double p_empty = (concentration_ + num_tables_ * discount_) * p0; -      const double p_share = (loc.dish_count_ - loc.table_counts_.size() * discount_); +      const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_);        share_table = rng->SelectSample(p_empty, p_share);      }      if (share_table) { -      double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_); +      double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);        for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();             ti != loc.table_counts_.end(); ++ti) {          r -= (*ti - discount_); @@ -45,23 +51,23 @@ class CCRP {        loc.table_counts_.push_back(1u);        ++num_tables_;      } -    ++loc.dish_count_; +    ++loc.total_dish_count_;      ++num_customers_;      return (share_table ? 0 : 1);    }    int decrement(const Dish& dish, MT19937* rng) {      DishLocations& loc = dish_locs_[dish]; -    assert(loc.dish_count_); -    if (loc.dish_count_ == 1) { +    assert(loc.total_dish_count_); +    if (loc.total_dish_count_ == 1) {        dish_locs_.erase(dish);        --num_tables_;        --num_customers_;        return -1;      } else {        int delta = 0; -      double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_); -      --loc.dish_count_; +      double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); +      --loc.total_dish_count_;        for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();             ti != loc.table_counts_.end(); ++ti) {          r -= (*ti - discount_); @@ -90,30 +96,37 @@ class CCRP {      if (it == dish_locs_.end()) {        return r * p0 / (num_customers_ + concentration_);      } else { -      return (it->second.dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) / +      return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) /                 (num_customers_ + concentration_);      }    } -  double llh() const { +  // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process +  // does not include P_0's +  double log_crp_prob() const { +    double lp = 0.0;      if (num_customers_) { -      std::cerr << "not implemented\n"; -      return 0.0; -    } else { -      return 0.0; +      const double r = lgamma(1.0 - discount_); +      lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) - lgamma(concentration_ / discount_); +      for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); +           it != dish_locs_.end(); ++it) { +        const DishLocations& cur = it->second; +        lp += lgamma(cur.total_dish_count_ - discount_) - r; +      }      } +    return lp;    }    struct DishLocations { -    DishLocations() : dish_count_() {} -    unsigned dish_count_; +    DishLocations() : total_dish_count_() {} +    unsigned total_dish_count_;      std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want    };    void Print(std::ostream* out) const {      for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();           it != dish_locs_.end(); ++it) { -      (*out) << it->first << " (" << it->second.dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; +      (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): ";        for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin();             i != it->second.table_counts_.end(); ++i) {          (*out) << " " << *i; @@ -122,6 +135,14 @@ class CCRP {      }    } +  typedef typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator const_iterator; +  const_iterator begin() const { +    return dish_locs_.begin(); +  } +  const_iterator end() const { +    return dish_locs_.end(); +  } +    unsigned num_tables_;    unsigned num_customers_;    std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_; diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc index ea224a27..0232331b 100644 --- a/gi/clda/src/clda.cc +++ b/gi/clda/src/clda.cc @@ -53,7 +53,7 @@ void tc() {    tt += crp.decrement("bar", &rng);    cout << crp << endl;    cout << "tt=" << tt << endl; -  cout << crp.llh() << endl; +  cout << crp.log_crp_prob() << endl;  }  int main(int argc, char** argv) { @@ -64,7 +64,7 @@ int main(int argc, char** argv) {    }    const int num_classes = atoi(argv[1]);    const int num_iterations = atoi(argv[2]); -  const int burnin_size = num_iterations * 0.666; +  const int burnin_size = num_iterations * 0.9;    if (num_classes < 2) {      cerr << "Must request more than 1 class\n";      return 1; @@ -113,7 +113,6 @@ int main(int argc, char** argv) {    vector<map<WordID, int> > t2w(num_classes);    Timer timer;    SampleSet<double> ss; -  const int num_types = TD::NumWords();    ss.resize(num_classes);    double total_time = 0;    for (int iter = 0; iter < num_iterations; ++iter) { @@ -121,6 +120,12 @@ int main(int argc, char** argv) {      if (iter && iter % 10 == 0) {        total_time += timer.Elapsed();        timer.Reset(); +      double llh = 0; +      for (int j = 0; j < dr.size(); ++j) +        llh += dr[j].log_crp_prob(); +      for (int j = 0; j < wr.size(); ++j) +        llh += wr[j].log_crp_prob(); +      cerr << " [LLH=" << llh << " I=" << iter << "]\n";      }      for (int j = 0; j < zji.size(); ++j) {        const size_t num_words = wji[j].size();  | 
