summaryrefslogtreecommitdiff
path: root/gi/clda/src
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-25 04:03:35 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-25 04:03:35 +0000
commiteb7dd39a10474a58fda0ab237e92b30f9b352d55 (patch)
treecd53cd287ed21c727a0d22aead46d382d610071d /gi/clda/src
parent0ef9ee63a44bf2659756eb2f1d34f5fddfb458a4 (diff)
compute prob of crp
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@619 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/clda/src')
-rw-r--r--gi/clda/src/ccrp.h55
-rw-r--r--gi/clda/src/clda.cc11
2 files changed, 46 insertions, 20 deletions
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h
index e978225d..47f364f2 100644
--- a/gi/clda/src/ccrp.h
+++ b/gi/clda/src/ccrp.h
@@ -18,16 +18,22 @@ class CCRP {
public:
CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {}
+ unsigned num_tables(const Dish& dish) const {
+ const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+ if (it == dish_locs_.end()) return 0;
+ return it->second.table_counts_.size();
+ }
+
int increment(const Dish& dish, const double& p0, MT19937* rng) {
DishLocations& loc = dish_locs_[dish];
bool share_table = false;
- if (loc.dish_count_) {
+ if (loc.total_dish_count_) {
const double p_empty = (concentration_ + num_tables_ * discount_) * p0;
- const double p_share = (loc.dish_count_ - loc.table_counts_.size() * discount_);
+ const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
share_table = rng->SelectSample(p_empty, p_share);
}
if (share_table) {
- double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_);
+ double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
ti != loc.table_counts_.end(); ++ti) {
r -= (*ti - discount_);
@@ -45,23 +51,23 @@ class CCRP {
loc.table_counts_.push_back(1u);
++num_tables_;
}
- ++loc.dish_count_;
+ ++loc.total_dish_count_;
++num_customers_;
return (share_table ? 0 : 1);
}
int decrement(const Dish& dish, MT19937* rng) {
DishLocations& loc = dish_locs_[dish];
- assert(loc.dish_count_);
- if (loc.dish_count_ == 1) {
+ assert(loc.total_dish_count_);
+ if (loc.total_dish_count_ == 1) {
dish_locs_.erase(dish);
--num_tables_;
--num_customers_;
return -1;
} else {
int delta = 0;
- double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_);
- --loc.dish_count_;
+ double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
+ --loc.total_dish_count_;
for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
ti != loc.table_counts_.end(); ++ti) {
r -= (*ti - discount_);
@@ -90,30 +96,37 @@ class CCRP {
if (it == dish_locs_.end()) {
return r * p0 / (num_customers_ + concentration_);
} else {
- return (it->second.dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) /
+ return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) /
(num_customers_ + concentration_);
}
}
- double llh() const {
+ // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process
+ // does not include P_0's
+ double log_crp_prob() const {
+ double lp = 0.0;
if (num_customers_) {
- std::cerr << "not implemented\n";
- return 0.0;
- } else {
- return 0.0;
+ const double r = lgamma(1.0 - discount_);
+ lp = lgamma(concentration_) - lgamma(concentration_ + num_customers_) + num_tables_ * discount_ + lgamma(concentration_ / discount_ + num_tables_) - lgamma(concentration_ / discount_);
+ for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+ it != dish_locs_.end(); ++it) {
+ const DishLocations& cur = it->second;
+ lp += lgamma(cur.total_dish_count_ - discount_) - r;
+ }
}
+ return lp;
}
struct DishLocations {
- DishLocations() : dish_count_() {}
- unsigned dish_count_;
+ DishLocations() : total_dish_count_() {}
+ unsigned total_dish_count_;
std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want
};
void Print(std::ostream* out) const {
for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
it != dish_locs_.end(); ++it) {
- (*out) << it->first << " (" << it->second.dish_count_ << " on " << it->second.table_counts_.size() << " tables): ";
+ (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): ";
for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin();
i != it->second.table_counts_.end(); ++i) {
(*out) << " " << *i;
@@ -122,6 +135,14 @@ class CCRP {
}
}
+ typedef typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator const_iterator;
+ const_iterator begin() const {
+ return dish_locs_.begin();
+ }
+ const_iterator end() const {
+ return dish_locs_.end();
+ }
+
unsigned num_tables_;
unsigned num_customers_;
std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_;
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index ea224a27..0232331b 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -53,7 +53,7 @@ void tc() {
tt += crp.decrement("bar", &rng);
cout << crp << endl;
cout << "tt=" << tt << endl;
- cout << crp.llh() << endl;
+ cout << crp.log_crp_prob() << endl;
}
int main(int argc, char** argv) {
@@ -64,7 +64,7 @@ int main(int argc, char** argv) {
}
const int num_classes = atoi(argv[1]);
const int num_iterations = atoi(argv[2]);
- const int burnin_size = num_iterations * 0.666;
+ const int burnin_size = num_iterations * 0.9;
if (num_classes < 2) {
cerr << "Must request more than 1 class\n";
return 1;
@@ -113,7 +113,6 @@ int main(int argc, char** argv) {
vector<map<WordID, int> > t2w(num_classes);
Timer timer;
SampleSet<double> ss;
- const int num_types = TD::NumWords();
ss.resize(num_classes);
double total_time = 0;
for (int iter = 0; iter < num_iterations; ++iter) {
@@ -121,6 +120,12 @@ int main(int argc, char** argv) {
if (iter && iter % 10 == 0) {
total_time += timer.Elapsed();
timer.Reset();
+ double llh = 0;
+ for (int j = 0; j < dr.size(); ++j)
+ llh += dr[j].log_crp_prob();
+ for (int j = 0; j < wr.size(); ++j)
+ llh += wr[j].log_crp_prob();
+ cerr << " [LLH=" << llh << " I=" << iter << "]\n";
}
for (int j = 0; j < zji.size(); ++j) {
const size_t num_words = wji[j].size();