author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-25 02:14:41 +0000
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-25 02:14:41 +0000
commit | 743ce375bd9a3c5f6cba191bcfa9b50be17d9760 (patch)
tree | 18a9cbab91daf00332299fe42be90084854343b9
parent | d195f37d6de3742e732c746d12a84c7a0746b11f (diff)
crp with explicit table tracking
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@618 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | decoder/cdec.cc     |   8
-rw-r--r-- | gi/clda/src/ccrp.h  | 139
-rw-r--r-- | gi/clda/src/clda.cc |  87
-rw-r--r-- | utils/sampler.h     |  34

4 files changed, 215 insertions, 53 deletions
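For reference, the CCRP class added in gi/clda/src/ccrp.h below implements a Pitman-Yor Chinese restaurant process with explicit per-table counts. In the notation used here (introduced for this note, not taken from the commit), c_w is the number of customers eating dish w, t_w the number of tables serving w, c and T the corresponding totals, d the discount, theta the concentration, and p_0 the base distribution. The prob() method then computes

\[ P(w \mid \text{seating}) = \frac{c_w - d\,t_w + (\theta + d\,T)\,p_0(w)}{c + \theta}, \]

and increment() seats a new customer at a fresh table with probability proportional to (theta + d*T)*p_0(w), or at an occupied table already serving w with probability proportional to c_w - d*t_w. Because these probabilities sum to one when p_0 is a proper distribution, the four values printed by the tc() test in clda.cc (whose p_0 is uniform over its four test strings) should add up to 1. The Gibbs step in clda.cc likewise resamples each token's topic with weight dr[j].prob(k, uniform_topic) * wr[k].prob(word, uniform_word).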
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index ca6284f6..f7b06aa4 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -266,8 +266,8 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c
 }
 
 // TODO move out of cdec into some sampling decoder file
-void SampleRecurse(const Hypergraph& hg, const vector<SampleSet>& ss, int n, vector<WordID>* out) {
-  const SampleSet& s = ss[n];
+void SampleRecurse(const Hypergraph& hg, const vector<SampleSet<prob_t> >& ss, int n, vector<WordID>* out) {
+  const SampleSet<prob_t>& s = ss[n];
   int i = rng->SelectSample(s);
   const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]];
   vector<vector<WordID> > ants(edge.tail_nodes_.size());
@@ -290,9 +290,9 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
   unordered_map<string, int, boost::hash<string> > m;
   hg->PushWeightsToGoal();
   const int num_nodes = hg->nodes_.size();
-  vector<SampleSet> ss(num_nodes);
+  vector<SampleSet<prob_t> > ss(num_nodes);
   for (int i = 0; i < num_nodes; ++i) {
-    SampleSet& s = ss[i];
+    SampleSet<prob_t>& s = ss[i];
     const vector<int>& in_edges = hg->nodes_[i].in_edges_;
     for (int j = 0; j < in_edges.size(); ++j) {
       s.add(hg->edges_[in_edges[j]].edge_prob_);
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h
new file mode 100644
index 00000000..e978225d
--- /dev/null
+++ b/gi/clda/src/ccrp.h
@@ -0,0 +1,139 @@
+#ifndef _CCRP_H_
+#define _CCRP_H_
+
+#include <cassert>
+#include <cmath>
+#include <list>
+#include <iostream>
+#include <vector>
+#include <tr1/unordered_map>
+#include <boost/functional/hash.hpp>
+#include "sampler.h"
+
+// Chinese restaurant process (Pitman-Yor parameters) with explicit table
+// tracking.
+
+template <typename Dish, typename DishHash = boost::hash<Dish> >
+class CCRP {
+ public:
+  CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {}
+
+  int increment(const Dish& dish, const double& p0, MT19937* rng) {
+    DishLocations& loc = dish_locs_[dish];
+    bool share_table = false;
+    if (loc.dish_count_) {
+      const double p_empty = (concentration_ + num_tables_ * discount_) * p0;
+      const double p_share = (loc.dish_count_ - loc.table_counts_.size() * discount_);
+      share_table = rng->SelectSample(p_empty, p_share);
+    }
+    if (share_table) {
+      double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_);
+      for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
+           ti != loc.table_counts_.end(); ++ti) {
+        r -= (*ti - discount_);
+        if (r <= 0.0) {
+          ++(*ti);
+          break;
+        }
+      }
+      if (r > 0.0) {
+        std::cerr << "Serious error: r=" << r << std::endl;
+        Print(&std::cerr);
+        assert(r <= 0.0);
+      }
+    } else {
+      loc.table_counts_.push_back(1u);
+      ++num_tables_;
+    }
+    ++loc.dish_count_;
+    ++num_customers_;
+    return (share_table ? 0 : 1);
+  }
+
+  int decrement(const Dish& dish, MT19937* rng) {
+    DishLocations& loc = dish_locs_[dish];
+    assert(loc.dish_count_);
+    if (loc.dish_count_ == 1) {
+      dish_locs_.erase(dish);
+      --num_tables_;
+      --num_customers_;
+      return -1;
+    } else {
+      int delta = 0;
+      double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_);
+      --loc.dish_count_;
+      for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
+           ti != loc.table_counts_.end(); ++ti) {
+        r -= (*ti - discount_);
+        if (r <= 0.0) {
+          if ((--(*ti)) == 0) {
+            --num_tables_;
+            delta = -1;
+            loc.table_counts_.erase(ti);
+          }
+          break;
+        }
+      }
+      if (r > 0.0) {
+        std::cerr << "Serious error: r=" << r << std::endl;
+        Print(&std::cerr);
+        assert(r <= 0.0);
+      }
+      --num_customers_;
+      return delta;
+    }
+  }
+
+  double prob(const Dish& dish, const double& p0) const {
+    const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+    const double r = num_tables_ * discount_ + concentration_;
+    if (it == dish_locs_.end()) {
+      return r * p0 / (num_customers_ + concentration_);
+    } else {
+      return (it->second.dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) /
+             (num_customers_ + concentration_);
+    }
+  }
+
+  double llh() const {
+    if (num_customers_) {
+      std::cerr << "not implemented\n";
+      return 0.0;
+    } else {
+      return 0.0;
+    }
+  }
+
+  struct DishLocations {
+    DishLocations() : dish_count_() {}
+    unsigned dish_count_;
+    std::list<unsigned> table_counts_;  // list<> gives O(1) deletion and insertion, which we want
+  };
+
+  void Print(std::ostream* out) const {
+    for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+         it != dish_locs_.end(); ++it) {
+      (*out) << it->first << " (" << it->second.dish_count_ << " on " << it->second.table_counts_.size() << " tables): ";
+      for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin();
+           i != it->second.table_counts_.end(); ++i) {
+        (*out) << " " << *i;
+      }
+      (*out) << std::endl;
+    }
+  }
+
+  unsigned num_tables_;
+  unsigned num_customers_;
+  std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_;
+
+  double discount_;
+  double concentration_;
+};
+
+template <typename T,typename H>
+std::ostream& operator<<(std::ostream& o, const CCRP<T,H>& c) {
+  c.Print(&o);
+  return o;
+}
+
+#endif
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index 757a4691..ea224a27 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -1,9 +1,11 @@
 #include <iostream>
 #include <vector>
 #include <map>
+#include <string>
 #include "timer.h"
 #include "crp.h"
+#include "ccrp.h"
 #include "sampler.h"
 #include "tdict.h"
 
 const size_t MAX_DOC_LEN_CHARS = 10000000;
@@ -18,12 +20,44 @@ void ShowTopWordsForTopic(const map<WordID, int>& counts) {
   for (multimap<int, WordID>::reverse_iterator it = ms.rbegin(); it != ms.rend(); ++it) {
     cerr << it->first << ':' << TD::Convert(it->second) << " ";
     ++cc;
-    if (cc==12) break;
+    if (cc==20) break;
   }
   cerr << endl;
 }
 
+void tc() {
+  MT19937 rng;
+  CCRP<string> crp(0.1, 5);
+  double un = 0.25;
+  int tt = 0;
+  tt += crp.increment("hi", un, &rng);
+  tt += crp.increment("foo", un, &rng);
+  tt += crp.increment("bar", un, &rng);
+  tt += crp.increment("bar", un, &rng);
+  tt += crp.increment("bar", un, &rng);
+  tt += crp.increment("bar", un, &rng);
+  tt += crp.increment("bar", un, &rng);
+  tt += crp.increment("bar", un, &rng);
+  tt += crp.increment("bar", un, &rng);
+  cout << "tt=" << tt << endl;
+  cout << crp << endl;
+  cout << " P(bar)=" << crp.prob("bar", un) << endl;
+  cout << " P(hi)=" << crp.prob("hi", un) << endl;
+  cout << " P(baz)=" << crp.prob("baz", un) << endl;
+  cout << " P(foo)=" << crp.prob("foo", un) << endl;
+  double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un);
+  cout << " tot=" << x << endl;
+  tt += crp.decrement("hi", &rng);
+  tt += crp.decrement("bar", &rng);
+  cout << crp << endl;
+  tt += crp.decrement("bar", &rng);
+  cout << crp << endl;
+  cout << "tt=" << tt << endl;
+  cout << crp.llh() << endl;
+}
+
 int main(int argc, char** argv) {
+  tc();
   if (argc != 3) {
     cerr << "Usage: " << argv[0] << " num-classes num-samples\n";
     return 1;
@@ -54,10 +88,13 @@ int main(int argc, char** argv) {
   MT19937 rng;
   cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n";
   zji.resize(wji.size());
-  double beta = 0.1;
-  double alpha = 50.0 / num_classes;
-  vector<CRP<int> > dr(zji.size(), CRP<int>(beta));  // dr[i] describes the probability of using a topic in document i
-  vector<CRP<int> > wr(num_classes, CRP<int>(alpha));  // wr[k] describes the probability of generating a word in topic k
+  double disc = 0.05;
+  double beta = 10.0;
+  double alpha = 50.0;
+  double uniform_topic = 1.0 / num_classes;
+  double uniform_word = 1.0 / TD::NumWords();
+  vector<CCRP<int> > dr(zji.size(), CCRP<int>(disc, beta));  // dr[i] describes the probability of using a topic in document i
+  vector<CCRP<int> > wr(num_classes, CCRP<int>(disc, alpha));  // wr[k] describes the probability of generating a word in topic k
   for (int j = 0; j < zji.size(); ++j) {
     const size_t num_words = wji[j].size();
     vector<int>& zj = zji[j];
@@ -68,19 +105,15 @@ int main(int argc, char** argv) {
       if (random_topic == num_classes) { --random_topic; }
       zj[i] = random_topic;
       const int word = wj[i];
-      dr[j].increment(random_topic);
-      wr[random_topic].increment(word);
+      dr[j].increment(random_topic, uniform_topic, &rng);
+      wr[random_topic].increment(word, uniform_word, &rng);
     }
   }
   cerr << "SAMPLING\n";
   vector<map<WordID, int> > t2w(num_classes);
   Timer timer;
-  SampleSet ss;
+  SampleSet<double> ss;
   const int num_types = TD::NumWords();
-  const prob_t class_p0(1.0 / num_classes);
-  const prob_t word_p0(1.0 / num_types);
-  cerr << "CLASS PRIOR PROB: " << class_p0 << endl;
-  cerr << " WORD PRIOR LOGPROB: " << log(word_p0) << endl;
   ss.resize(num_classes);
   double total_time = 0;
   for (int iter = 0; iter < num_iterations; ++iter) {
@@ -88,23 +121,6 @@ int main(int argc, char** argv) {
     if (iter && iter % 10 == 0) {
       total_time += timer.Elapsed();
       timer.Reset();
-      prob_t lh = prob_t::One();
-      for (int j = 0; j < zji.size(); ++j) {
-        const size_t num_words = wji[j].size();
-        vector<int>& zj = zji[j];
-        const vector<int>& wj = wji[j];
-        for (int i = 0; i < num_words; ++i) {
-          const int word = wj[i];
-          const int cur_topic = zj[i];
-          lh *= dr[j].prob(cur_topic, class_p0);
-          lh *= wr[cur_topic].prob(word, word_p0);
-          if (iter > burnin_size) {
-            ++t2w[cur_topic][word];
-          }
-        }
-      }
-      if (iter && iter % 40 == 0) { cerr << " [ITER=" << iter << " SEC/SAMPLE=" << (total_time / 40) << " LLH=" << log(lh) << "]\n"; total_time=0; }
-      //cerr << "ITERATION " << iter << " LOG LIKELIHOOD: " << log(lh) << endl;
     }
     for (int j = 0; j < zji.size(); ++j) {
       const size_t num_words = wji[j].size();
@@ -113,16 +129,19 @@ int main(int argc, char** argv) {
       for (int i = 0; i < num_words; ++i) {
         const int word = wj[i];
         const int cur_topic = zj[i];
-        dr[j].decrement(cur_topic);
-        wr[cur_topic].decrement(word);
+        dr[j].decrement(cur_topic, &rng);
+        wr[cur_topic].decrement(word, &rng);
         for (int k = 0; k < num_classes; ++k) {
-          ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0);
+          ss[k]= dr[j].prob(k, uniform_topic) * wr[k].prob(word, uniform_word);
         }
         const int new_topic = rng.SelectSample(ss);
-        dr[j].increment(new_topic);
-        wr[new_topic].increment(word);
+        dr[j].increment(new_topic, uniform_topic, &rng);
+        wr[new_topic].increment(word, uniform_word, &rng);
         zj[i] = new_topic;
+        if (iter > burnin_size) {
+          ++t2w[cur_topic][word];
+        }
       }
     }
   }
diff --git a/utils/sampler.h b/utils/sampler.h
index 5fef45d0..f75d96b6 100644
--- a/utils/sampler.h
+++ b/utils/sampler.h
@@ -18,7 +18,7 @@
 
 #include "prob.h"
 
-struct SampleSet;
+template <typename F> struct SampleSet;
 
 template <typename RNG> struct RandomNumberGenerator {
@@ -45,7 +45,8 @@ struct RandomNumberGenerator {
     m_generator.seed(seed);
   }
 
-  size_t SelectSample(const prob_t& a, const prob_t& b, double T = 1.0) {
+  template <typename F>
+  size_t SelectSample(const F& a, const F& b, double T = 1.0) {
    if (T == 1.0) {
      if (this->next() > (a / (a + b))) return 1; else return 0;
    } else {
@@ -54,7 +55,8 @@ struct RandomNumberGenerator {
  }
 
  // T is the annealing temperature, if desired
-  size_t SelectSample(const SampleSet& ss, double T = 1.0);
+  template <typename F>
+  size_t SelectSample(const SampleSet<F>& ss, double T = 1.0);
 
  // draw a value from U(0,1)
  double next() {return m_random();}
@@ -94,36 +96,38 @@ struct RandomNumberGenerator {
 
 typedef RandomNumberGenerator<boost::mt19937> MT19937;
 
+template <typename F = double>
 class SampleSet {
  public:
-  const prob_t& operator[](int i) const { return m_scores[i]; }
-  prob_t& operator[](int i) { return m_scores[i]; }
+  const F& operator[](int i) const { return m_scores[i]; }
+  F& operator[](int i) { return m_scores[i]; }
   bool empty() const { return m_scores.empty(); }
   void add(const prob_t& s) { m_scores.push_back(s); }
   void clear() { m_scores.clear(); }
   size_t size() const { return m_scores.size(); }
   void resize(int size) { m_scores.resize(size); }
 
-  std::vector<prob_t> m_scores;
+  std::vector<F> m_scores;
 };
 
 template <typename RNG>
-size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T) {
+template <typename F>
+size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet<F>& ss, double T) {
   assert(T > 0.0);
   assert(ss.m_scores.size() > 0);
   if (ss.m_scores.size() == 1) return 0;
-  const prob_t annealing_factor(1.0 / T);
-  const bool anneal = (annealing_factor != prob_t::One());
-  prob_t sum = prob_t::Zero();
+  const double annealing_factor = 1.0 / T;
+  const bool anneal = (T != 1.0);
+  F sum = F(0);
   if (anneal) {
     for (int i = 0; i < ss.m_scores.size(); ++i)
-      sum += ss.m_scores[i].pow(annealing_factor);  // p^(1/T)
+      sum += pow(ss.m_scores[i], annealing_factor);  // p^(1/T)
   } else {
-    sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), prob_t::Zero());
+    sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), F(0));
   }
   //for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ",";
   //std::cerr << std::endl;
 
-  prob_t random(this->next());  // random number between 0 and 1
+  F random(this->next());  // random number between 0 and 1
   random *= sum;  // scale with normalization factor
   //std::cerr << "Random number " << random << std::endl;
@@ -131,9 +135,9 @@ size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T)
   size_t position = 1;
   sum = ss.m_scores[0];
   if (anneal) {
-    sum.poweq(annealing_factor);
+    sum = pow(sum, annealing_factor);
     for (; position < ss.m_scores.size() && sum < random; ++position)
-      sum += ss.m_scores[position].pow(annealing_factor);
+      sum += pow(ss.m_scores[position], annealing_factor);
   } else {
     for (; position < ss.m_scores.size() && sum < random; ++position)
       sum += ss.m_scores[position];
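The utils/sampler.h changes make SampleSet and SelectSample templates over the weight type F, which is what lets ccrp.h draw table assignments from plain doubles and clda.cc keep its per-topic scores in a SampleSet<double>. A minimal usage sketch of the templated interface (a hypothetical example, not part of the commit; it assumes cdec's utils/ directory is on the include path and that boost is available):

// sample_sketch.cc -- hypothetical example, not part of this commit.
#include <cstddef>
#include <iostream>
#include "sampler.h"   // cdec's utils/sampler.h (assumed on the include path)

int main() {
  MT19937 rng;              // RandomNumberGenerator<boost::mt19937>
  SampleSet<double> ss;     // plain double weights instead of prob_t
  ss.resize(3);
  ss[0] = 0.2; ss[1] = 0.5; ss[2] = 0.3;
  size_t a = rng.SelectSample(ss);       // index i drawn with probability ss[i] / sum
  size_t b = rng.SelectSample(ss, 0.5);  // annealed draw: each weight raised to 1/T = 2
  std::cout << a << ' ' << b << std::endl;
  return 0;
}

Leaving T at its default of 1.0 skips the annealed branch entirely; with T < 1 the distribution is sharpened before normalization, matching the pow(score, 1/T) path in SelectSample.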