summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-25 02:14:41 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-25 02:14:41 +0000
commit743ce375bd9a3c5f6cba191bcfa9b50be17d9760 (patch)
tree18a9cbab91daf00332299fe42be90084854343b9
parentd195f37d6de3742e732c746d12a84c7a0746b11f (diff)
crp with explicit table tracking
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@618 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--decoder/cdec.cc8
-rw-r--r--gi/clda/src/ccrp.h139
-rw-r--r--gi/clda/src/clda.cc87
-rw-r--r--utils/sampler.h34
4 files changed, 215 insertions, 53 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index ca6284f6..f7b06aa4 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -266,8 +266,8 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c
}
// TODO move out of cdec into some sampling decoder file
-void SampleRecurse(const Hypergraph& hg, const vector<SampleSet>& ss, int n, vector<WordID>* out) {
- const SampleSet& s = ss[n];
+void SampleRecurse(const Hypergraph& hg, const vector<SampleSet<prob_t> >& ss, int n, vector<WordID>* out) {
+ const SampleSet<prob_t>& s = ss[n];
int i = rng->SelectSample(s);
const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]];
vector<vector<WordID> > ants(edge.tail_nodes_.size());
@@ -290,9 +290,9 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
unordered_map<string, int, boost::hash<string> > m;
hg->PushWeightsToGoal();
const int num_nodes = hg->nodes_.size();
- vector<SampleSet> ss(num_nodes);
+ vector<SampleSet<prob_t> > ss(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
- SampleSet& s = ss[i];
+ SampleSet<prob_t>& s = ss[i];
const vector<int>& in_edges = hg->nodes_[i].in_edges_;
for (int j = 0; j < in_edges.size(); ++j) {
s.add(hg->edges_[in_edges[j]].edge_prob_);
diff --git a/gi/clda/src/ccrp.h b/gi/clda/src/ccrp.h
new file mode 100644
index 00000000..e978225d
--- /dev/null
+++ b/gi/clda/src/ccrp.h
@@ -0,0 +1,139 @@
+#ifndef _CCRP_H_
+#define _CCRP_H_
+
+#include <cassert>
+#include <cmath>
+#include <list>
+#include <iostream>
+#include <vector>
+#include <tr1/unordered_map>
+#include <boost/functional/hash.hpp>
+#include "sampler.h"
+
+// Chinese restaurant process (Pitman-Yor parameters) with explicit table
+// tracking.
+
+template <typename Dish, typename DishHash = boost::hash<Dish> >
+class CCRP {
+ public:
+ CCRP(double disc, double conc) : num_tables_(), num_customers_(), discount_(disc), concentration_(conc) {}
+
+ int increment(const Dish& dish, const double& p0, MT19937* rng) {
+ DishLocations& loc = dish_locs_[dish];
+ bool share_table = false;
+ if (loc.dish_count_) {
+ const double p_empty = (concentration_ + num_tables_ * discount_) * p0;
+ const double p_share = (loc.dish_count_ - loc.table_counts_.size() * discount_);
+ share_table = rng->SelectSample(p_empty, p_share);
+ }
+ if (share_table) {
+ double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_);
+ for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
+ ti != loc.table_counts_.end(); ++ti) {
+ r -= (*ti - discount_);
+ if (r <= 0.0) {
+ ++(*ti);
+ break;
+ }
+ }
+ if (r > 0.0) {
+ std::cerr << "Serious error: r=" << r << std::endl;
+ Print(&std::cerr);
+ assert(r <= 0.0);
+ }
+ } else {
+ loc.table_counts_.push_back(1u);
+ ++num_tables_;
+ }
+ ++loc.dish_count_;
+ ++num_customers_;
+ return (share_table ? 0 : 1);
+ }
+
+ int decrement(const Dish& dish, MT19937* rng) {
+ DishLocations& loc = dish_locs_[dish];
+ assert(loc.dish_count_);
+ if (loc.dish_count_ == 1) {
+ dish_locs_.erase(dish);
+ --num_tables_;
+ --num_customers_;
+ return -1;
+ } else {
+ int delta = 0;
+ double r = rng->next() * (loc.dish_count_ - loc.table_counts_.size() * discount_);
+ --loc.dish_count_;
+ for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
+ ti != loc.table_counts_.end(); ++ti) {
+ r -= (*ti - discount_);
+ if (r <= 0.0) {
+ if ((--(*ti)) == 0) {
+ --num_tables_;
+ delta = -1;
+ loc.table_counts_.erase(ti);
+ }
+ break;
+ }
+ }
+ if (r > 0.0) {
+ std::cerr << "Serious error: r=" << r << std::endl;
+ Print(&std::cerr);
+ assert(r <= 0.0);
+ }
+ --num_customers_;
+ return delta;
+ }
+ }
+
+ double prob(const Dish& dish, const double& p0) const {
+ const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
+ const double r = num_tables_ * discount_ + concentration_;
+ if (it == dish_locs_.end()) {
+ return r * p0 / (num_customers_ + concentration_);
+ } else {
+ return (it->second.dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) /
+ (num_customers_ + concentration_);
+ }
+ }
+
+ double llh() const {
+ if (num_customers_) {
+ std::cerr << "not implemented\n";
+ return 0.0;
+ } else {
+ return 0.0;
+ }
+ }
+
+ struct DishLocations {
+ DishLocations() : dish_count_() {}
+ unsigned dish_count_;
+ std::list<unsigned> table_counts_; // list<> gives O(1) deletion and insertion, which we want
+ };
+
+ void Print(std::ostream* out) const {
+ for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
+ it != dish_locs_.end(); ++it) {
+ (*out) << it->first << " (" << it->second.dish_count_ << " on " << it->second.table_counts_.size() << " tables): ";
+ for (typename std::list<unsigned>::const_iterator i = it->second.table_counts_.begin();
+ i != it->second.table_counts_.end(); ++i) {
+ (*out) << " " << *i;
+ }
+ (*out) << std::endl;
+ }
+ }
+
+ unsigned num_tables_;
+ unsigned num_customers_;
+ std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_;
+
+ double discount_;
+ double concentration_;
+};
+
+template <typename T,typename H>
+std::ostream& operator<<(std::ostream& o, const CCRP<T,H>& c) {
+ c.Print(&o);
+ return o;
+}
+
+#endif
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
index 757a4691..ea224a27 100644
--- a/gi/clda/src/clda.cc
+++ b/gi/clda/src/clda.cc
@@ -1,9 +1,11 @@
#include <iostream>
#include <vector>
#include <map>
+#include <string>
#include "timer.h"
#include "crp.h"
+#include "ccrp.h"
#include "sampler.h"
#include "tdict.h"
const size_t MAX_DOC_LEN_CHARS = 10000000;
@@ -18,12 +20,44 @@ void ShowTopWordsForTopic(const map<WordID, int>& counts) {
for (multimap<int, WordID>::reverse_iterator it = ms.rbegin(); it != ms.rend(); ++it) {
cerr << it->first << ':' << TD::Convert(it->second) << " ";
++cc;
- if (cc==12) break;
+ if (cc==20) break;
}
cerr << endl;
}
+void tc() {
+ MT19937 rng;
+ CCRP<string> crp(0.1, 5);
+ double un = 0.25;
+ int tt = 0;
+ tt += crp.increment("hi", un, &rng);
+ tt += crp.increment("foo", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ tt += crp.increment("bar", un, &rng);
+ cout << "tt=" << tt << endl;
+ cout << crp << endl;
+ cout << " P(bar)=" << crp.prob("bar", un) << endl;
+ cout << " P(hi)=" << crp.prob("hi", un) << endl;
+ cout << " P(baz)=" << crp.prob("baz", un) << endl;
+ cout << " P(foo)=" << crp.prob("foo", un) << endl;
+ double x = crp.prob("bar", un) + crp.prob("hi", un) + crp.prob("baz", un) + crp.prob("foo", un);
+ cout << " tot=" << x << endl;
+ tt += crp.decrement("hi", &rng);
+ tt += crp.decrement("bar", &rng);
+ cout << crp << endl;
+ tt += crp.decrement("bar", &rng);
+ cout << crp << endl;
+ cout << "tt=" << tt << endl;
+ cout << crp.llh() << endl;
+}
+
int main(int argc, char** argv) {
+ tc();
if (argc != 3) {
cerr << "Usage: " << argv[0] << " num-classes num-samples\n";
return 1;
@@ -54,10 +88,13 @@ int main(int argc, char** argv) {
MT19937 rng;
cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n";
zji.resize(wji.size());
- double beta = 0.1;
- double alpha = 50.0 / num_classes;
- vector<CRP<int> > dr(zji.size(), CRP<int>(beta)); // dr[i] describes the probability of using a topic in document i
- vector<CRP<int> > wr(num_classes, CRP<int>(alpha)); // wr[k] describes the probability of generating a word in topic k
+ double disc = 0.05;
+ double beta = 10.0;
+ double alpha = 50.0;
+ double uniform_topic = 1.0 / num_classes;
+ double uniform_word = 1.0 / TD::NumWords();
+ vector<CCRP<int> > dr(zji.size(), CCRP<int>(disc, beta)); // dr[i] describes the probability of using a topic in document i
+ vector<CCRP<int> > wr(num_classes, CCRP<int>(disc, alpha)); // wr[k] describes the probability of generating a word in topic k
for (int j = 0; j < zji.size(); ++j) {
const size_t num_words = wji[j].size();
vector<int>& zj = zji[j];
@@ -68,19 +105,15 @@ int main(int argc, char** argv) {
if (random_topic == num_classes) { --random_topic; }
zj[i] = random_topic;
const int word = wj[i];
- dr[j].increment(random_topic);
- wr[random_topic].increment(word);
+ dr[j].increment(random_topic, uniform_topic, &rng);
+ wr[random_topic].increment(word, uniform_word, &rng);
}
}
cerr << "SAMPLING\n";
vector<map<WordID, int> > t2w(num_classes);
Timer timer;
- SampleSet ss;
+ SampleSet<double> ss;
const int num_types = TD::NumWords();
- const prob_t class_p0(1.0 / num_classes);
- const prob_t word_p0(1.0 / num_types);
- cerr << "CLASS PRIOR PROB: " << class_p0 << endl;
- cerr << " WORD PRIOR LOGPROB: " << log(word_p0) << endl;
ss.resize(num_classes);
double total_time = 0;
for (int iter = 0; iter < num_iterations; ++iter) {
@@ -88,23 +121,6 @@ int main(int argc, char** argv) {
if (iter && iter % 10 == 0) {
total_time += timer.Elapsed();
timer.Reset();
- prob_t lh = prob_t::One();
- for (int j = 0; j < zji.size(); ++j) {
- const size_t num_words = wji[j].size();
- vector<int>& zj = zji[j];
- const vector<int>& wj = wji[j];
- for (int i = 0; i < num_words; ++i) {
- const int word = wj[i];
- const int cur_topic = zj[i];
- lh *= dr[j].prob(cur_topic, class_p0);
- lh *= wr[cur_topic].prob(word, word_p0);
- if (iter > burnin_size) {
- ++t2w[cur_topic][word];
- }
- }
- }
- if (iter && iter % 40 == 0) { cerr << " [ITER=" << iter << " SEC/SAMPLE=" << (total_time / 40) << " LLH=" << log(lh) << "]\n"; total_time=0; }
- //cerr << "ITERATION " << iter << " LOG LIKELIHOOD: " << log(lh) << endl;
}
for (int j = 0; j < zji.size(); ++j) {
const size_t num_words = wji[j].size();
@@ -113,16 +129,19 @@ int main(int argc, char** argv) {
for (int i = 0; i < num_words; ++i) {
const int word = wj[i];
const int cur_topic = zj[i];
- dr[j].decrement(cur_topic);
- wr[cur_topic].decrement(word);
+ dr[j].decrement(cur_topic, &rng);
+ wr[cur_topic].decrement(word, &rng);
for (int k = 0; k < num_classes; ++k) {
- ss[k]= dr[j].prob(k, class_p0) * wr[k].prob(word, word_p0);
+ ss[k]= dr[j].prob(k, uniform_topic) * wr[k].prob(word, uniform_word);
}
const int new_topic = rng.SelectSample(ss);
- dr[j].increment(new_topic);
- wr[new_topic].increment(word);
+ dr[j].increment(new_topic, uniform_topic, &rng);
+ wr[new_topic].increment(word, uniform_word, &rng);
zj[i] = new_topic;
+ if (iter > burnin_size) {
+ ++t2w[cur_topic][word];
+ }
}
}
}
diff --git a/utils/sampler.h b/utils/sampler.h
index 5fef45d0..f75d96b6 100644
--- a/utils/sampler.h
+++ b/utils/sampler.h
@@ -18,7 +18,7 @@
#include "prob.h"
-struct SampleSet;
+template <typename F> struct SampleSet;
template <typename RNG>
struct RandomNumberGenerator {
@@ -45,7 +45,8 @@ struct RandomNumberGenerator {
m_generator.seed(seed);
}
- size_t SelectSample(const prob_t& a, const prob_t& b, double T = 1.0) {
+ template <typename F>
+ size_t SelectSample(const F& a, const F& b, double T = 1.0) {
if (T == 1.0) {
if (this->next() > (a / (a + b))) return 1; else return 0;
} else {
@@ -54,7 +55,8 @@ struct RandomNumberGenerator {
}
// T is the annealing temperature, if desired
- size_t SelectSample(const SampleSet& ss, double T = 1.0);
+ template <typename F>
+ size_t SelectSample(const SampleSet<F>& ss, double T = 1.0);
// draw a value from U(0,1)
double next() {return m_random();}
@@ -94,36 +96,38 @@ struct RandomNumberGenerator {
typedef RandomNumberGenerator<boost::mt19937> MT19937;
+template <typename F = double>
class SampleSet {
public:
- const prob_t& operator[](int i) const { return m_scores[i]; }
- prob_t& operator[](int i) { return m_scores[i]; }
+ const F& operator[](int i) const { return m_scores[i]; }
+ F& operator[](int i) { return m_scores[i]; }
bool empty() const { return m_scores.empty(); }
void add(const prob_t& s) { m_scores.push_back(s); }
void clear() { m_scores.clear(); }
size_t size() const { return m_scores.size(); }
void resize(int size) { m_scores.resize(size); }
- std::vector<prob_t> m_scores;
+ std::vector<F> m_scores;
};
template <typename RNG>
-size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T) {
+template <typename F>
+size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet<F>& ss, double T) {
assert(T > 0.0);
assert(ss.m_scores.size() > 0);
if (ss.m_scores.size() == 1) return 0;
- const prob_t annealing_factor(1.0 / T);
- const bool anneal = (annealing_factor != prob_t::One());
- prob_t sum = prob_t::Zero();
+ const double annealing_factor = 1.0 / T;
+ const bool anneal = (T != 1.0);
+ F sum = F(0);
if (anneal) {
for (int i = 0; i < ss.m_scores.size(); ++i)
- sum += ss.m_scores[i].pow(annealing_factor); // p^(1/T)
+ sum += pow(ss.m_scores[i], annealing_factor); // p^(1/T)
} else {
- sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), prob_t::Zero());
+ sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), F(0));
}
//for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ",";
//std::cerr << std::endl;
- prob_t random(this->next()); // random number between 0 and 1
+ F random(this->next()); // random number between 0 and 1
random *= sum; // scale with normalization factor
//std::cerr << "Random number " << random << std::endl;
@@ -131,9 +135,9 @@ size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T) {
size_t position = 1;
sum = ss.m_scores[0];
if (anneal) {
- sum.poweq(annealing_factor);
+ sum = pow(sum, annealing_factor);
for (; position < ss.m_scores.size() && sum < random; ++position)
- sum += ss.m_scores[position].pow(annealing_factor);
+ sum += pow(ss.m_scores[position], annealing_factor);
} else {
for (; position < ss.m_scores.size() && sum < random; ++position)
sum += ss.m_scores[position];