summaryrefslogtreecommitdiff
path: root/gi/clda/src
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 22:31:28 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 22:31:28 +0000
commit9be9a5dde934577de314ce8ac6fb3eb0ba787503 (patch)
tree557cc31667174994d39e741203dc7b155622b9a9 /gi/clda/src
parent2f2ba42a1453f4a3a08f9c1ecfc53c1b1c83d550 (diff)
chris's crappy lda
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@6 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/clda/src')
-rw-r--r--gi/clda/src/Makefile.am6
-rw-r--r--gi/clda/src/clda.cc140
-rw-r--r--gi/clda/src/crp.h216
-rw-r--r--gi/clda/src/dict.h43
-rw-r--r--gi/clda/src/logval.h157
-rw-r--r--gi/clda/src/prob.h8
-rw-r--r--gi/clda/src/sampler.h138
-rw-r--r--gi/clda/src/tdict.h49
-rw-r--r--gi/clda/src/timer.h18
-rw-r--r--gi/clda/src/wordid.h6
10 files changed, 781 insertions, 0 deletions
diff --git a/gi/clda/src/Makefile.am b/gi/clda/src/Makefile.am
new file mode 100644
index 00000000..ebb016db
--- /dev/null
+++ b/gi/clda/src/Makefile.am
@@ -0,0 +1,6 @@
+bin_PROGRAMS = clda
+
+clda_SOURCES = clda.cc
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS)
+AM_LDFLAGS = -lz
diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc
new file mode 100644
index 00000000..49702df3
--- /dev/null
+++ b/gi/clda/src/clda.cc
@@ -0,0 +1,140 @@
+#include <iostream>
+#include <vector>
+#include <map>
+
+#include "timer.h"
+#include "crp.h"
+#include "sampler.h"
+#include "tdict.h"
+Dict TD::dict_;
+std::string TD::empty = "";
+std::string TD::space = " ";
+
+using namespace std;
+
+void ShowTopWords(const map<WordID, int>& counts) {
+ multimap<int, WordID> ms;
+ for (map<WordID,int>::const_iterator it = counts.begin(); it != counts.end(); ++it)
+ ms.insert(make_pair(it->second, it->first));
+}
+
+int main(int argc, char** argv) {
+ if (argc != 2) {
+ cerr << "Usage: " << argv[0] << " num-classes\n";
+ return 1;
+ }
+ const int num_classes = atoi(argv[1]);
+ if (num_classes < 2) {
+ cerr << "Must request more than 1 class\n";
+ return 1;
+ }
+ cerr << "CLASSES: " << num_classes << endl;
+ char* buf = new char[800000];
+ vector<vector<int> > wji; // w[j][i] - observed word i of doc j
+ vector<vector<int> > zji; // z[j][i] - topic assignment for word i of doc j
+ cerr << "READING DOCUMENTS\n";
+ while(cin) {
+ cin.getline(buf, 800000);
+ if (buf[0] == 0) continue;
+ wji.push_back(vector<WordID>());
+ TD::ConvertSentence(buf, &wji.back());
+ }
+ cerr << "READ " << wji.size() << " DOCUMENTS\n";
+ MT19937 rng;
+ cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n";
+ zji.resize(wji.size());
+ double beta = 0.01;
+ double alpha = 0.001;
+ vector<CRP<int> > dr(zji.size(), CRP<int>(beta)); // dr[i] describes the probability of using a topic in document i
+ vector<CRP<int> > wr(num_classes, CRP<int>(alpha)); // wr[k] describes the probability of generating a word in topic k
+ int random_topic = rng.next() * num_classes;
+ for (int j = 0; j < zji.size(); ++j) {
+ const size_t num_words = wji[j].size();
+ vector<int>& zj = zji[j];
+ const vector<int>& wj = wji[j];
+ zj.resize(num_words);
+ for (int i = 0; i < num_words; ++i) {
+ if (random_topic == num_classes) { --random_topic; }
+ zj[i] = random_topic;
+ const int word = wj[i];
+ dr[j].increment(random_topic);
+ wr[random_topic].increment(word);
+ }
+ }
+ cerr << "SAMPLING\n";
+ vector<map<WordID, int> > t2w(num_classes);
+ const int num_iterations = 1000;
+ const int burnin_size = 800;
+ bool needline = false;
+ Timer timer;
+ SampleSet ss;
+ ss.resize(num_classes);
+ double total_time = 0;
+ for (int iter = 0; iter < num_iterations; ++iter) {
+ if (iter && iter % 10 == 0) {
+ total_time += timer.Elapsed();
+ timer.Reset();
+ cerr << '.'; needline=true;
+ prob_t lh = prob_t::One();
+ for (int j = 0; j < zji.size(); ++j) {
+ const size_t num_words = wji[j].size();
+ vector<int>& zj = zji[j];
+ const vector<int>& wj = wji[j];
+ for (int i = 0; i < num_words; ++i) {
+ const int word = wj[i];
+ const int cur_topic = zj[i];
+ lh *= dr[j].prob(cur_topic);
+ lh *= wr[cur_topic].prob(word);
+ if (iter > burnin_size) {
+ ++t2w[cur_topic][word];
+ }
+ }
+ }
+ if (iter && iter % 200 == 0) { cerr << " [ITER=" << iter << " SEC/SAMPLE=" << (total_time / 200) << " LLH=" << log(lh) << "]\n"; needline=false; total_time=0; }
+ //cerr << "ITERATION " << iter << " LOG LIKELIHOOD: " << log(lh) << endl;
+ }
+ for (int j = 0; j < zji.size(); ++j) {
+ const size_t num_words = wji[j].size();
+ vector<int>& zj = zji[j];
+ const vector<int>& wj = wji[j];
+ for (int i = 0; i < num_words; ++i) {
+ const int word = wj[i];
+ const int cur_topic = zj[i];
+ dr[j].decrement(cur_topic);
+ wr[cur_topic].decrement(word);
+
+ for (int k = 0; k < num_classes; ++k) {
+ ss[k]= dr[j].prob(k) * wr[k].prob(word);
+ }
+ const int new_topic = rng.SelectSample(ss);
+ dr[j].increment(new_topic);
+ wr[new_topic].increment(word);
+ zj[i] = new_topic;
+ }
+ }
+ }
+ if (needline) cerr << endl;
+ for (int j = 0; j < zji.size(); ++j) {
+ const size_t num_words = wji[j].size();
+ vector<int>& zj = zji[j];
+ const vector<int>& wj = wji[j];
+ zj.resize(num_words);
+ for (int i = 0; i < num_words; ++i) {
+ cout << TD::Convert(wj[i]) << '(' << zj[i] << ") ";
+ }
+ cout << endl;
+ }
+ for (int i = 0; i < num_classes; ++i) {
+ ShowTopWords(t2w[i]);
+ }
+ for (map<int,int>::iterator it = t2w[0].begin(); it != t2w[0].end(); ++it)
+ cerr << TD::Convert(it->first) << " " << it->second << endl;
+ cerr << "---------------------------------\n";
+ for (map<int,int>::iterator it = t2w[1].begin(); it != t2w[1].end(); ++it)
+ cerr << TD::Convert(it->first) << " " << it->second << endl;
+ cerr << "---------------------------------\n";
+ for (map<int,int>::iterator it = t2w[2].begin(); it != t2w[2].end(); ++it)
+ cerr << TD::Convert(it->first) << " " << it->second << endl;
+ return 0;
+}
+
diff --git a/gi/clda/src/crp.h b/gi/clda/src/crp.h
new file mode 100644
index 00000000..13596cbf
--- /dev/null
+++ b/gi/clda/src/crp.h
@@ -0,0 +1,216 @@
+#ifndef _CRP_H_
+#define _CRP_H_
+
+// shamelessly adapted from code by Phil Blunsom and Trevor Cohn
+// There are TWO CRP classes here: CRPWithTableTracking tracks the
+// (expected) number of customers per table, and CRP just tracks
+// the number of customers / dish.
+// If you are implementing a HDP model, you should use CRP for the
+// base distribution and CRPWithTableTracking for the dependent
+// distribution.
+
+#include <iostream>
+#include <map>
+#include <boost/functional/hash.hpp>
+#include <tr1/unordered_map>
+
+#include "prob.h"
+#include "sampler.h" // RNG
+
+template <typename DishType, typename Hash = boost::hash<DishType> >
+class CRP {
+ public:
+ CRP(double alpha) : alpha_(alpha), palpha_(alpha), total_customers_() {}
+ void increment(const DishType& dish);
+ void decrement(const DishType& dish);
+ void erase(const DishType& dish) {
+ counts_.erase(dish);
+ }
+ inline int count(const DishType& dish) const {
+ const typename MapType::const_iterator i = counts_.find(dish);
+ if (i == counts_.end()) return 0; else return i->second;
+ }
+ inline prob_t prob(const DishType& dish) const {
+ return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_);
+ }
+ inline prob_t prob(const DishType& dish, const prob_t& p0) const {
+ return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_);
+ }
+ private:
+ typedef std::tr1::unordered_map<DishType, int, Hash> MapType;
+ MapType counts_;
+ const double alpha_;
+ const prob_t palpha_;
+ int total_customers_;
+};
+
+template <typename Dish, typename Hash>
+void CRP<Dish,Hash>::increment(const Dish& dish) {
+ ++counts_[dish];
+ ++total_customers_;
+}
+
+template <typename Dish, typename Hash>
+void CRP<Dish,Hash>::decrement(const Dish& dish) {
+ typename MapType::iterator i = counts_.find(dish);
+ assert(i != counts_.end());
+ if (--i->second == 0)
+ counts_.erase(i);
+ --total_customers_;
+}
+
+template <typename DishType, typename Hash = boost::hash<DishType>, typename RNG = MT19937>
+class CRPWithTableTracking {
+ public:
+ CRPWithTableTracking(double alpha, RNG* rng) :
+ alpha_(alpha), palpha_(alpha), total_customers_(),
+ total_tables_(), rng_(rng) {}
+
+ // seat a customer for dish d, returns the delta in tables
+ // with customers
+ int increment(const DishType& d, const prob_t& p0 = prob_t::One());
+ int decrement(const DishType& d);
+ void erase(const DishType& dish);
+
+ inline int count(const DishType& dish) const {
+ const typename MapType::const_iterator i = counts_.find(dish);
+ if (i == counts_.end()) return 0; else return i->second.count_;
+ }
+ inline prob_t prob(const DishType& dish) const {
+ return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_);
+ }
+ inline prob_t prob(const DishType& dish, const prob_t& p0) const {
+ return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_);
+ }
+ private:
+ struct TableInfo {
+ TableInfo() : count_(), tables_() {}
+ int count_; // total customers eating dish
+ int tables_; // total tables labeled with dish
+ std::map<int, int> table_histogram_; // num customers at table -> number tables
+ };
+ typedef std::tr1::unordered_map<DishType, TableInfo, Hash> MapType;
+
+ inline prob_t prob_share_table(const double& customer_count) const {
+ if (customer_count)
+ return prob_t(customer_count) / prob_t(customer_count + alpha_);
+ else
+ return prob_t::Zero();
+ }
+ inline prob_t prob_new_table(const double& customer_count, const prob_t& p0) const {
+ if (customer_count)
+ return palpha_ * p0 / prob_t(customer_count + alpha_);
+ else
+ return p0;
+ }
+
+ MapType counts_;
+ const double alpha_;
+ const prob_t palpha_;
+ int total_customers_;
+ int total_tables_;
+ RNG* rng_;
+};
+
+template <typename Dish, typename Hash, typename RNG>
+int CRPWithTableTracking<Dish,Hash,RNG>::increment(const Dish& dish, const prob_t& p0) {
+ TableInfo& tc = counts_[dish];
+
+ //std::cerr << "\nincrement for " << dish << " with p0 " << p0 << "\n";
+ //std::cerr << "\tBEFORE histogram: " << tc.table_histogram_ << " ";
+ //std::cerr << "count: " << tc.count_ << " ";
+ //std::cerr << "tables: " << tc.tables_ << "\n";
+
+ // seated at a new or existing table?
+ prob_t pshare = prob_share_table(tc.count_);
+ prob_t pnew = prob_new_table(tc.count_, p0);
+
+ //std::cerr << "\t\tP0 " << p0 << " count(dish) " << count(dish)
+ // << " tables " << tc
+ // << " p(share) " << pshare << " p(new) " << pnew << "\n";
+
+ int delta = 0;
+ if (tc.count_ == 0 || rng_->SelectSample(pshare, pnew) == 1) {
+ // assign to a new table
+ ++tc.tables_;
+ ++tc.table_histogram_[1];
+ ++total_tables_;
+ delta = 1;
+ } else {
+ // can't share a table if there are no other customers
+ assert(tc.count_ > 0);
+
+ // randomly assign to an existing table
+ // remove constant denominator from inner loop
+ int r = static_cast<int>(rng_->next() * tc.count_);
+ for (std::map<int,int>::iterator hit = tc.table_histogram_.begin();
+ hit != tc.table_histogram_.end(); ++hit) {
+ r -= hit->first * hit->second;
+ if (r <= 0) {
+ ++tc.table_histogram_[hit->first+1];
+ --hit->second;
+ if (hit->second == 0)
+ tc.table_histogram_.erase(hit);
+ break;
+ }
+ }
+ if (r > 0) {
+ std::cerr << "CONSISTENCY ERROR: " << tc.count_ << std::endl;
+ std::cerr << pshare << std::endl;
+ std::cerr << pnew << std::endl;
+ std::cerr << r << std::endl;
+ abort();
+ }
+ }
+ ++tc.count_;
+ ++total_customers_;
+ return delta;
+}
+
+template <typename Dish, typename Hash, typename RNG>
+int CRPWithTableTracking<Dish,Hash,RNG>::decrement(const Dish& dish) {
+ typename MapType::iterator i = counts_.find(dish);
+ if(i == counts_.end()) {
+ std::cerr << "MISSING DISH: " << dish << std::endl;
+ abort();
+ }
+
+ int delta = 0;
+ TableInfo &tc = i->second;
+
+ //std::cout << "\ndecrement for " << dish << " with p0 " << p0 << "\n";
+ //std::cout << "\tBEFORE histogram: " << tc.table_histogram << " ";
+ //std::cout << "count: " << count(dish) << " ";
+ //std::cout << "tables: " << tc.tables << "\n";
+
+ int r = static_cast<int>(rng_->next() * tc.count_);
+ //std::cerr << "FOO: " << r << std::endl;
+ for (std::map<int,int>::iterator hit = tc.table_histogram_.begin();
+ hit != tc.table_histogram_.end(); ++hit) {
+ r -= (hit->first * hit->second);
+ if (r <= 0) {
+ if (hit->first > 1)
+ tc.table_histogram_[hit->first-1] += 1;
+ else {
+ --delta;
+ --tc.tables_;
+ --total_tables_;
+ }
+
+ --hit->second;
+ if (hit->second == 0) tc.table_histogram_.erase(hit);
+ break;
+ }
+ }
+
+ assert(r <= 0);
+
+ // remove the customer
+ --tc.count_;
+ --total_customers_;
+ assert(tc.count_ >= 0);
+ if (tc.count_ == 0) counts_.erase(i);
+ return delta;
+}
+
+#endif
diff --git a/gi/clda/src/dict.h b/gi/clda/src/dict.h
new file mode 100644
index 00000000..72e82e6d
--- /dev/null
+++ b/gi/clda/src/dict.h
@@ -0,0 +1,43 @@
+#ifndef DICT_H_
+#define DICT_H_
+
+#include <cassert>
+#include <cstring>
+#include <tr1/unordered_map>
+#include <string>
+#include <vector>
+
+#include <boost/functional/hash.hpp>
+
+#include "wordid.h"
+
+class Dict {
+ typedef std::tr1::unordered_map<std::string, WordID, boost::hash<std::string> > Map;
+ public:
+ Dict() : b0_("<bad0>") { words_.reserve(1000); }
+ inline int max() const { return words_.size(); }
+ inline WordID Convert(const std::string& word, bool frozen = false) {
+ Map::iterator i = d_.find(word);
+ if (i == d_.end()) {
+ if (frozen)
+ return 0;
+ words_.push_back(word);
+ d_[word] = words_.size();
+ return words_.size();
+ } else {
+ return i->second;
+ }
+ }
+ inline const std::string& Convert(const WordID& id) const {
+ if (id == 0) return b0_;
+ assert(id <= words_.size());
+ return words_[id-1];
+ }
+ void clear() { words_.clear(); d_.clear(); }
+ private:
+ const std::string b0_;
+ std::vector<std::string> words_;
+ Map d_;
+};
+
+#endif
diff --git a/gi/clda/src/logval.h b/gi/clda/src/logval.h
new file mode 100644
index 00000000..7099b9be
--- /dev/null
+++ b/gi/clda/src/logval.h
@@ -0,0 +1,157 @@
+#ifndef LOGVAL_H_
+#define LOGVAL_H_
+
+#include <iostream>
+#include <cstdlib>
+#include <cmath>
+#include <limits>
+
+template <typename T>
+class LogVal {
+ public:
+ LogVal() : s_(), v_(-std::numeric_limits<T>::infinity()) {}
+ explicit LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {}
+ static LogVal<T> One() { return LogVal(1); }
+ static LogVal<T> Zero() { return LogVal(); }
+
+ void logeq(const T& v) { s_ = false; v_ = v; }
+
+ LogVal& operator+=(const LogVal& a) {
+ if (a.v_ == -std::numeric_limits<T>::infinity()) return *this;
+ if (a.s_ == s_) {
+ if (a.v_ < v_) {
+ v_ = v_ + log1p(std::exp(a.v_ - v_));
+ } else {
+ v_ = a.v_ + log1p(std::exp(v_ - a.v_));
+ }
+ } else {
+ if (a.v_ < v_) {
+ v_ = v_ + log1p(-std::exp(a.v_ - v_));
+ } else {
+ v_ = a.v_ + log1p(-std::exp(v_ - a.v_));
+ s_ = !s_;
+ }
+ }
+ return *this;
+ }
+
+ LogVal& operator*=(const LogVal& a) {
+ s_ = (s_ != a.s_);
+ v_ += a.v_;
+ return *this;
+ }
+
+ LogVal& operator/=(const LogVal& a) {
+ s_ = (s_ != a.s_);
+ v_ -= a.v_;
+ return *this;
+ }
+
+ LogVal& operator-=(const LogVal& a) {
+ LogVal b = a;
+ b.invert();
+ return *this += b;
+ }
+
+ LogVal& poweq(const T& power) {
+ if (s_) {
+ std::cerr << "poweq(T) not implemented when s_ is true\n";
+ std::abort();
+ } else {
+ v_ *= power;
+ }
+ return *this;
+ }
+
+ void invert() { s_ = !s_; }
+
+ LogVal pow(const T& power) const {
+ LogVal res = *this;
+ res.poweq(power);
+ return res;
+ }
+
+ operator T() const {
+ if (s_) return -std::exp(v_); else return std::exp(v_);
+ }
+
+ bool s_;
+ T v_;
+};
+
+template<typename T>
+LogVal<T> operator+(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res += o2;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator*(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res *= o2;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator/(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res /= o2;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator-(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res -= o2;
+ return res;
+}
+
+template<typename T>
+T log(const LogVal<T>& o) {
+ if (o.s_) return log(-1.0);
+ return o.v_;
+}
+
+template <typename T>
+LogVal<T> pow(const LogVal<T>& b, const T& e) {
+ return b.pow(e);
+}
+
+template <typename T>
+bool operator<(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ if (lhs.s_ == rhs.s_) {
+ return (lhs.v_ < rhs.v_);
+ } else {
+ return lhs.s_ > rhs.s_;
+ }
+}
+
+#if 0
+template <typename T>
+bool operator<=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ <= rhs.v_);
+}
+
+template <typename T>
+bool operator>(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ > rhs.v_);
+}
+
+template <typename T>
+bool operator>=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ >= rhs.v_);
+}
+#endif
+
+template <typename T>
+bool operator==(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ == rhs.v_) && (lhs.s_ == rhs.s_);
+}
+
+template <typename T>
+bool operator!=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return !(lhs == rhs);
+}
+
+#endif
diff --git a/gi/clda/src/prob.h b/gi/clda/src/prob.h
new file mode 100644
index 00000000..bc297870
--- /dev/null
+++ b/gi/clda/src/prob.h
@@ -0,0 +1,8 @@
+#ifndef _PROB_H_
+#define _PROB_H_
+
+#include "logval.h"
+
+typedef LogVal<double> prob_t;
+
+#endif
diff --git a/gi/clda/src/sampler.h b/gi/clda/src/sampler.h
new file mode 100644
index 00000000..4d0b2e64
--- /dev/null
+++ b/gi/clda/src/sampler.h
@@ -0,0 +1,138 @@
+#ifndef SAMPLER_H_
+#define SAMPLER_H_
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <iostream>
+#include <fstream>
+#include <vector>
+
+#include <boost/random/mersenne_twister.hpp>
+#include <boost/random/uniform_real.hpp>
+#include <boost/random/variate_generator.hpp>
+#include <boost/random/normal_distribution.hpp>
+#include <boost/random/poisson_distribution.hpp>
+
+#include "prob.h"
+
+struct SampleSet;
+
+template <typename RNG>
+struct RandomNumberGenerator {
+ static uint32_t GetTrulyRandomSeed() {
+ uint32_t seed;
+ std::ifstream r("/dev/urandom");
+ if (r) {
+ r.read((char*)&seed,sizeof(uint32_t));
+ }
+ if (r.fail() || !r) {
+ std::cerr << "Warning: could not read from /dev/urandom. Seeding from clock" << std::endl;
+ seed = time(NULL);
+ }
+ std::cerr << "Seeding random number sequence to " << seed << std::endl;
+ return seed;
+ }
+
+ RandomNumberGenerator() : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) {
+ uint32_t seed = GetTrulyRandomSeed();
+ m_generator.seed(seed);
+ }
+ explicit RandomNumberGenerator(uint32_t seed) : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) {
+ if (!seed) seed = GetTrulyRandomSeed();
+ m_generator.seed(seed);
+ }
+
+ size_t SelectSample(const prob_t& a, const prob_t& b, double T = 1.0) {
+ if (T == 1.0) {
+ if (this->next() > (a / (a + b))) return 1; else return 0;
+ } else {
+ assert(!"not implemented");
+ }
+ }
+
+ // T is the annealing temperature, if desired
+ size_t SelectSample(const SampleSet& ss, double T = 1.0);
+
+ // draw a value from U(0,1)
+ double next() {return m_random();}
+
+ // draw a value from N(mean,var)
+ double NextNormal(double mean, double var) {
+ return boost::normal_distribution<double>(mean, var)(m_random);
+ }
+
+ // draw a value from a Poisson distribution
+ // lambda must be greater than 0
+ int NextPoisson(int lambda) {
+ return boost::poisson_distribution<int>(lambda)(m_random);
+ }
+
+ bool AcceptMetropolisHastings(const prob_t& p_cur,
+ const prob_t& p_prev,
+ const prob_t& q_cur,
+ const prob_t& q_prev) {
+ const prob_t a = (p_cur / p_prev) * (q_prev / q_cur);
+ if (log(a) >= 0.0) return true;
+ return (prob_t(this->next()) < a);
+ }
+
+ private:
+ boost::uniform_real<> m_dist;
+ RNG m_generator;
+ boost::variate_generator<RNG&, boost::uniform_real<> > m_random;
+};
+
+typedef RandomNumberGenerator<boost::mt19937> MT19937;
+
+class SampleSet {
+ public:
+ const prob_t& operator[](int i) const { return m_scores[i]; }
+ prob_t& operator[](int i) { return m_scores[i]; }
+ bool empty() const { return m_scores.empty(); }
+ void add(const prob_t& s) { m_scores.push_back(s); }
+ void clear() { m_scores.clear(); }
+ size_t size() const { return m_scores.size(); }
+ void resize(int size) { m_scores.resize(size); }
+ std::vector<prob_t> m_scores;
+};
+
+template <typename RNG>
+size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T) {
+ assert(T > 0.0);
+ assert(ss.m_scores.size() > 0);
+ if (ss.m_scores.size() == 1) return 0;
+ const prob_t annealing_factor(1.0 / T);
+ const bool anneal = (annealing_factor != prob_t::One());
+ prob_t sum = prob_t::Zero();
+ if (anneal) {
+ for (int i = 0; i < ss.m_scores.size(); ++i)
+ sum += ss.m_scores[i].pow(annealing_factor); // p^(1/T)
+ } else {
+ sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), prob_t::Zero());
+ }
+ //for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ",";
+ //std::cerr << std::endl;
+
+ prob_t random(this->next()); // random number between 0 and 1
+ random *= sum; // scale with normalization factor
+ //std::cerr << "Random number " << random << std::endl;
+
+ //now figure out which sample
+ size_t position = 1;
+ sum = ss.m_scores[0];
+ if (anneal) {
+ sum.poweq(annealing_factor);
+ for (; position < ss.m_scores.size() && sum < random; ++position)
+ sum += ss.m_scores[position].pow(annealing_factor);
+ } else {
+ for (; position < ss.m_scores.size() && sum < random; ++position)
+ sum += ss.m_scores[position];
+ }
+ //std::cout << "random: " << random << " sample: " << position << std::endl;
+ //std::cerr << "Sample: " << position-1 << std::endl;
+ //exit(1);
+ return position-1;
+}
+
+#endif
diff --git a/gi/clda/src/tdict.h b/gi/clda/src/tdict.h
new file mode 100644
index 00000000..97f145a1
--- /dev/null
+++ b/gi/clda/src/tdict.h
@@ -0,0 +1,49 @@
+#ifndef _TDICT_H_
+#define _TDICT_H_
+
+#include <string>
+#include <vector>
+#include "wordid.h"
+#include "dict.h"
+
+class Vocab;
+
+struct TD {
+
+ static Dict dict_;
+ static std::string empty;
+ static std::string space;
+
+ static std::string GetString(const std::vector<WordID>& str) {
+ std::string res;
+ for (std::vector<WordID>::const_iterator i = str.begin(); i != str.end(); ++i)
+ res += (i == str.begin() ? empty : space) + TD::Convert(*i);
+ return res;
+ }
+
+ static void ConvertSentence(const std::string& sent, std::vector<WordID>* ids) {
+ std::string s = sent;
+ int last = 0;
+ ids->clear();
+ for (int i=0; i < s.size(); ++i)
+ if (s[i] == 32 || s[i] == '\t') {
+ s[i]=0;
+ if (last != i) {
+ ids->push_back(Convert(&s[last]));
+ }
+ last = i + 1;
+ }
+ if (last != s.size())
+ ids->push_back(Convert(&s[last]));
+ }
+
+ static WordID Convert(const std::string& s) {
+ return dict_.Convert(s);
+ }
+
+ static const std::string& Convert(const WordID& w) {
+ return dict_.Convert(w);
+ }
+};
+
+#endif
diff --git a/gi/clda/src/timer.h b/gi/clda/src/timer.h
new file mode 100644
index 00000000..ca26b304
--- /dev/null
+++ b/gi/clda/src/timer.h
@@ -0,0 +1,18 @@
+#ifndef _TIMER_STATS_H_
+#define _TIMER_STATS_H_
+
+struct Timer {
+ Timer() { Reset(); }
+ void Reset() {
+ start_t = clock();
+ }
+ double Elapsed() const {
+ const clock_t end_t = clock();
+ const double elapsed = (end_t - start_t) / 1000000.0;
+ return elapsed;
+ }
+ private:
+ clock_t start_t;
+};
+
+#endif
diff --git a/gi/clda/src/wordid.h b/gi/clda/src/wordid.h
new file mode 100644
index 00000000..fb50bcc1
--- /dev/null
+++ b/gi/clda/src/wordid.h
@@ -0,0 +1,6 @@
+#ifndef _WORD_ID_H_
+#define _WORD_ID_H_
+
+typedef int WordID;
+
+#endif