diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-25 04:03:35 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-25 04:03:35 +0000 |
commit | eb7dd39a10474a58fda0ab237e92b30f9b352d55 (patch) | |
tree | cd53cd287ed21c727a0d22aead46d382d610071d | |
parent | 0ef9ee63a44bf2659756eb2f1d34f5fddfb458a4 (diff) |
compute prob of crp
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@619 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | gi/clda/src/ccrp.h | 55 | ||||
-rw-r--r-- | gi/clda/src/clda.cc | 11 |
2 files changed, 46 insertions, 20 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h index e978225d..47f364f2 100644 --- a/gi/clda/src/ccrp.h +++ b/gi/clda/src/ccrp.h @@ -18,16 +18,22 @@ class CCRP { public: CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {} + unsigned num_tables(const Dish& dish) const { + const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish); + if (it == dish_locs_.end()) return 0; + return it->second.table_counts_.size(); + } + int increment(const Dish& dish, const double& p0, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; bool share_table = false; - if (loc.dish_count_) { + if (loc.total_dish_count_) { const double p_empty = (concentration_ + num_tables_ * discount_) * p0; - const double p_share = (loc.dish_count_ - loc.table_counts_.size() * discount_); + const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_); share_table = rng->SelectSample(p_empty, p_share); } if (share_table) { - double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_); + double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); ti != loc.table_counts_.end(); ++ti) { r -= (*ti - discount_); @@ -45,23 +51,23 @@ class CCRP { loc.table_counts_.push_back(1u); ++num_tables_; } - ++loc.dish_count_; + ++loc.total_dish_count_; ++num_customers_; return (share_table ? 0 : 1); } int decrement(const Dish& dish, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; - assert(loc.dish_count_); - if (loc.dish_count_ == 1) { + assert(loc.total_dish_count_); + if (loc.total_dish_count_ == 1) { dish_locs_.erase(dish); --num_tables_; --num_customers_; return -1; } else { int delta = 0; - double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_); - --loc.dish_count_; + double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); + --loc.total_dish_count_; for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin(); ti != loc.table_counts_.end(); ++ti) { r -= (*ti - discount_); @@ -90,30 +96,37 @@ class CCRP { if (it == dish_locs_.end()) { return r * p0 / (num_customers_ + concentration_); } else { - return (it->second.dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) / + return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) / (num_customers_ + concentration_); } } - double llh() const { + // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process + // does not include P_0's + double log_crp_prob() const { + double lp = 0.0; if (num_customers_) { - std::cerr << "not implemented\n"; - return 0.0; - } else { - return 0.0; + const double r = lgamma(1.0 - discount_); + lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) - lgamma(concentration_ / discount_); + for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); + it != dish_locs_.end(); ++it) { + const DishLocations& cur = it->second; + lp += lgamma(cur.total_dish_count_ - discount_) - r; + } } + return lp; } struct DishLocations { - DishLocations() : dish_count_() {} - unsigned dish_count_; + DishLocations() : total_dish_count_() {} + unsigned total_dish_count_; std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want }; void Print(std::ostream* out) const { for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { - (*out) << it->first << " (" << it->second.dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; + (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin(); i != it->second.table_counts_.end(); ++i) { (*out) << " " << *i; @@ -122,6 +135,14 @@ class CCRP { } } + typedef typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator const_iterator; + const_iterator begin() const { + return dish_locs_.begin(); + } + const_iterator end() const { + return dish_locs_.end(); + } + unsigned num_tables_; unsigned num_customers_; std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_; diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc index ea224a27..0232331b 100644 --- a/gi/clda/src/clda.cc +++ b/gi/clda/src/clda.cc @@ -53,7 +53,7 @@ void tc() { tt += crp.decrement("bar", &rng); cout << crp << endl; cout << "tt=" << tt << endl; - cout << crp.llh() << endl; + cout << crp.log_crp_prob() << endl; } int main(int argc, char** argv) { @@ -64,7 +64,7 @@ int main(int argc, char** argv) { } const int num_classes = atoi(argv[1]); const int num_iterations = atoi(argv[2]); - const int burnin_size = num_iterations * 0.666; + const int burnin_size = num_iterations * 0.9; if (num_classes < 2) { cerr << "Must request more than 1 class\n"; return 1; @@ -113,7 +113,6 @@ int main(int argc, char** argv) { vector<map<WordID, int> > t2w(num_classes); Timer timer; SampleSet<double> ss; - const int num_types = TD::NumWords(); ss.resize(num_classes); double total_time = 0; for (int iter = 0; iter < num_iterations; ++iter) { @@ -121,6 +120,12 @@ int main(int argc, char** argv) { if (iter && iter % 10 == 0) { total_time += timer.Elapsed(); timer.Reset(); + double llh = 0; + for (int j = 0; j < dr.size(); ++j) + llh += dr[j].log_crp_prob(); + for (int j = 0; j < wr.size(); ++j) + llh += wr[j].log_crp_prob(); + cerr << " [LLH=" << llh << " I=" << iter << "]\n"; } for (int j = 0; j < zji.size(); ++j) { const size_t num_words = wji[j].size(); |