diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 22:31:28 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 22:31:28 +0000 |
commit | 9be9a5dde934577de314ce8ac6fb3eb0ba787503 (patch) | |
tree | 557cc31667174994d39e741203dc7b155622b9a9 /gi/clda/src | |
parent | 2f2ba42a1453f4a3a08f9c1ecfc53c1b1c83d550 (diff) |
chris's crappy lda
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@6 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/clda/src')
-rw-r--r-- | gi/clda/src/Makefile.am | 6 | ||||
-rw-r--r-- | gi/clda/src/clda.cc | 140 | ||||
-rw-r--r-- | gi/clda/src/crp.h | 216 | ||||
-rw-r--r-- | gi/clda/src/dict.h | 43 | ||||
-rw-r--r-- | gi/clda/src/logval.h | 157 | ||||
-rw-r--r-- | gi/clda/src/prob.h | 8 | ||||
-rw-r--r-- | gi/clda/src/sampler.h | 138 | ||||
-rw-r--r-- | gi/clda/src/tdict.h | 49 | ||||
-rw-r--r-- | gi/clda/src/timer.h | 18 | ||||
-rw-r--r-- | gi/clda/src/wordid.h | 6 |
10 files changed, 781 insertions, 0 deletions
diff --git a/gi/clda/src/Makefile.am b/gi/clda/src/Makefile.am new file mode 100644 index 00000000..ebb016db --- /dev/null +++ b/gi/clda/src/Makefile.am @@ -0,0 +1,6 @@ +bin_PROGRAMS = clda + +clda_SOURCES = clda.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) +AM_LDFLAGS = -lz diff --git a/gi/clda/src/clda.cc b/gi/clda/src/clda.cc new file mode 100644 index 00000000..49702df3 --- /dev/null +++ b/gi/clda/src/clda.cc @@ -0,0 +1,140 @@ +#include <iostream> +#include <vector> +#include <map> + +#include "timer.h" +#include "crp.h" +#include "sampler.h" +#include "tdict.h" +Dict TD::dict_; +std::string TD::empty = ""; +std::string TD::space = " "; + +using namespace std; + +void ShowTopWords(const map<WordID, int>& counts) { + multimap<int, WordID> ms; + for (map<WordID,int>::const_iterator it = counts.begin(); it != counts.end(); ++it) + ms.insert(make_pair(it->second, it->first)); +} + +int main(int argc, char** argv) { + if (argc != 2) { + cerr << "Usage: " << argv[0] << " num-classes\n"; + return 1; + } + const int num_classes = atoi(argv[1]); + if (num_classes < 2) { + cerr << "Must request more than 1 class\n"; + return 1; + } + cerr << "CLASSES: " << num_classes << endl; + char* buf = new char[800000]; + vector<vector<int> > wji; // w[j][i] - observed word i of doc j + vector<vector<int> > zji; // z[j][i] - topic assignment for word i of doc j + cerr << "READING DOCUMENTS\n"; + while(cin) { + cin.getline(buf, 800000); + if (buf[0] == 0) continue; + wji.push_back(vector<WordID>()); + TD::ConvertSentence(buf, &wji.back()); + } + cerr << "READ " << wji.size() << " DOCUMENTS\n"; + MT19937 rng; + cerr << "INITIALIZING RANDOM TOPIC ASSIGNMENTS\n"; + zji.resize(wji.size()); + double beta = 0.01; + double alpha = 0.001; + vector<CRP<int> > dr(zji.size(), CRP<int>(beta)); // dr[i] describes the probability of using a topic in document i + vector<CRP<int> > wr(num_classes, CRP<int>(alpha)); // wr[k] describes the probability of generating a word in topic k + int random_topic = rng.next() * num_classes; + for (int j = 0; j < zji.size(); ++j) { + const size_t num_words = wji[j].size(); + vector<int>& zj = zji[j]; + const vector<int>& wj = wji[j]; + zj.resize(num_words); + for (int i = 0; i < num_words; ++i) { + if (random_topic == num_classes) { --random_topic; } + zj[i] = random_topic; + const int word = wj[i]; + dr[j].increment(random_topic); + wr[random_topic].increment(word); + } + } + cerr << "SAMPLING\n"; + vector<map<WordID, int> > t2w(num_classes); + const int num_iterations = 1000; + const int burnin_size = 800; + bool needline = false; + Timer timer; + SampleSet ss; + ss.resize(num_classes); + double total_time = 0; + for (int iter = 0; iter < num_iterations; ++iter) { + if (iter && iter % 10 == 0) { + total_time += timer.Elapsed(); + timer.Reset(); + cerr << '.'; needline=true; + prob_t lh = prob_t::One(); + for (int j = 0; j < zji.size(); ++j) { + const size_t num_words = wji[j].size(); + vector<int>& zj = zji[j]; + const vector<int>& wj = wji[j]; + for (int i = 0; i < num_words; ++i) { + const int word = wj[i]; + const int cur_topic = zj[i]; + lh *= dr[j].prob(cur_topic); + lh *= wr[cur_topic].prob(word); + if (iter > burnin_size) { + ++t2w[cur_topic][word]; + } + } + } + if (iter && iter % 200 == 0) { cerr << " [ITER=" << iter << " SEC/SAMPLE=" << (total_time / 200) << " LLH=" << log(lh) << "]\n"; needline=false; total_time=0; } + //cerr << "ITERATION " << iter << " LOG LIKELIHOOD: " << log(lh) << endl; + } + for (int j = 0; j < zji.size(); ++j) { + const size_t num_words = wji[j].size(); + vector<int>& zj = zji[j]; + const vector<int>& wj = wji[j]; + for (int i = 0; i < num_words; ++i) { + const int word = wj[i]; + const int cur_topic = zj[i]; + dr[j].decrement(cur_topic); + wr[cur_topic].decrement(word); + + for (int k = 0; k < num_classes; ++k) { + ss[k]= dr[j].prob(k) * wr[k].prob(word); + } + const int new_topic = rng.SelectSample(ss); + dr[j].increment(new_topic); + wr[new_topic].increment(word); + zj[i] = new_topic; + } + } + } + if (needline) cerr << endl; + for (int j = 0; j < zji.size(); ++j) { + const size_t num_words = wji[j].size(); + vector<int>& zj = zji[j]; + const vector<int>& wj = wji[j]; + zj.resize(num_words); + for (int i = 0; i < num_words; ++i) { + cout << TD::Convert(wj[i]) << '(' << zj[i] << ") "; + } + cout << endl; + } + for (int i = 0; i < num_classes; ++i) { + ShowTopWords(t2w[i]); + } + for (map<int,int>::iterator it = t2w[0].begin(); it != t2w[0].end(); ++it) + cerr << TD::Convert(it->first) << " " << it->second << endl; + cerr << "---------------------------------\n"; + for (map<int,int>::iterator it = t2w[1].begin(); it != t2w[1].end(); ++it) + cerr << TD::Convert(it->first) << " " << it->second << endl; + cerr << "---------------------------------\n"; + for (map<int,int>::iterator it = t2w[2].begin(); it != t2w[2].end(); ++it) + cerr << TD::Convert(it->first) << " " << it->second << endl; + return 0; +} + diff --git a/gi/clda/src/crp.h b/gi/clda/src/crp.h new file mode 100644 index 00000000..13596cbf --- /dev/null +++ b/gi/clda/src/crp.h @@ -0,0 +1,216 @@ +#ifndef _CRP_H_ +#define _CRP_H_ + +// shamelessly adapted from code by Phil Blunsom and Trevor Cohn +// There are TWO CRP classes here: CRPWithTableTracking tracks the +// (expected) number of customers per table, and CRP just tracks +// the number of customers / dish. +// If you are implementing a HDP model, you should use CRP for the +// base distribution and CRPWithTableTracking for the dependent +// distribution. + +#include <iostream> +#include <map> +#include <boost/functional/hash.hpp> +#include <tr1/unordered_map> + +#include "prob.h" +#include "sampler.h" // RNG + +template <typename DishType, typename Hash = boost::hash<DishType> > +class CRP { + public: + CRP(double alpha) : alpha_(alpha), palpha_(alpha), total_customers_() {} + void increment(const DishType& dish); + void decrement(const DishType& dish); + void erase(const DishType& dish) { + counts_.erase(dish); + } + inline int count(const DishType& dish) const { + const typename MapType::const_iterator i = counts_.find(dish); + if (i == counts_.end()) return 0; else return i->second; + } + inline prob_t prob(const DishType& dish) const { + return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_); + } + inline prob_t prob(const DishType& dish, const prob_t& p0) const { + return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_); + } + private: + typedef std::tr1::unordered_map<DishType, int, Hash> MapType; + MapType counts_; + const double alpha_; + const prob_t palpha_; + int total_customers_; +}; + +template <typename Dish, typename Hash> +void CRP<Dish,Hash>::increment(const Dish& dish) { + ++counts_[dish]; + ++total_customers_; +} + +template <typename Dish, typename Hash> +void CRP<Dish,Hash>::decrement(const Dish& dish) { + typename MapType::iterator i = counts_.find(dish); + assert(i != counts_.end()); + if (--i->second == 0) + counts_.erase(i); + --total_customers_; +} + +template <typename DishType, typename Hash = boost::hash<DishType>, typename RNG = MT19937> +class CRPWithTableTracking { + public: + CRPWithTableTracking(double alpha, RNG* rng) : + alpha_(alpha), palpha_(alpha), total_customers_(), + total_tables_(), rng_(rng) {} + + // seat a customer for dish d, returns the delta in tables + // with customers + int increment(const DishType& d, const prob_t& p0 = prob_t::One()); + int decrement(const DishType& d); + void erase(const DishType& dish); + + inline int count(const DishType& dish) const { + const typename MapType::const_iterator i = counts_.find(dish); + if (i == counts_.end()) return 0; else return i->second.count_; + } + inline prob_t prob(const DishType& dish) const { + return (prob_t(count(dish) + alpha_)) / prob_t(total_customers_ + alpha_); + } + inline prob_t prob(const DishType& dish, const prob_t& p0) const { + return (prob_t(count(dish)) + palpha_ * p0) / prob_t(total_customers_ + alpha_); + } + private: + struct TableInfo { + TableInfo() : count_(), tables_() {} + int count_; // total customers eating dish + int tables_; // total tables labeled with dish + std::map<int, int> table_histogram_; // num customers at table -> number tables + }; + typedef std::tr1::unordered_map<DishType, TableInfo, Hash> MapType; + + inline prob_t prob_share_table(const double& customer_count) const { + if (customer_count) + return prob_t(customer_count) / prob_t(customer_count + alpha_); + else + return prob_t::Zero(); + } + inline prob_t prob_new_table(const double& customer_count, const prob_t& p0) const { + if (customer_count) + return palpha_ * p0 / prob_t(customer_count + alpha_); + else + return p0; + } + + MapType counts_; + const double alpha_; + const prob_t palpha_; + int total_customers_; + int total_tables_; + RNG* rng_; +}; + +template <typename Dish, typename Hash, typename RNG> +int CRPWithTableTracking<Dish,Hash,RNG>::increment(const Dish& dish, const prob_t& p0) { + TableInfo& tc = counts_[dish]; + + //std::cerr << "\nincrement for " << dish << " with p0 " << p0 << "\n"; + //std::cerr << "\tBEFORE histogram: " << tc.table_histogram_ << " "; + //std::cerr << "count: " << tc.count_ << " "; + //std::cerr << "tables: " << tc.tables_ << "\n"; + + // seated at a new or existing table? + prob_t pshare = prob_share_table(tc.count_); + prob_t pnew = prob_new_table(tc.count_, p0); + + //std::cerr << "\t\tP0 " << p0 << " count(dish) " << count(dish) + // << " tables " << tc + // << " p(share) " << pshare << " p(new) " << pnew << "\n"; + + int delta = 0; + if (tc.count_ == 0 || rng_->SelectSample(pshare, pnew) == 1) { + // assign to a new table + ++tc.tables_; + ++tc.table_histogram_[1]; + ++total_tables_; + delta = 1; + } else { + // can't share a table if there are no other customers + assert(tc.count_ > 0); + + // randomly assign to an existing table + // remove constant denominator from inner loop + int r = static_cast<int>(rng_->next() * tc.count_); + for (std::map<int,int>::iterator hit = tc.table_histogram_.begin(); + hit != tc.table_histogram_.end(); ++hit) { + r -= hit->first * hit->second; + if (r <= 0) { + ++tc.table_histogram_[hit->first+1]; + --hit->second; + if (hit->second == 0) + tc.table_histogram_.erase(hit); + break; + } + } + if (r > 0) { + std::cerr << "CONSISTENCY ERROR: " << tc.count_ << std::endl; + std::cerr << pshare << std::endl; + std::cerr << pnew << std::endl; + std::cerr << r << std::endl; + abort(); + } + } + ++tc.count_; + ++total_customers_; + return delta; +} + +template <typename Dish, typename Hash, typename RNG> +int CRPWithTableTracking<Dish,Hash,RNG>::decrement(const Dish& dish) { + typename MapType::iterator i = counts_.find(dish); + if(i == counts_.end()) { + std::cerr << "MISSING DISH: " << dish << std::endl; + abort(); + } + + int delta = 0; + TableInfo &tc = i->second; + + //std::cout << "\ndecrement for " << dish << " with p0 " << p0 << "\n"; + //std::cout << "\tBEFORE histogram: " << tc.table_histogram << " "; + //std::cout << "count: " << count(dish) << " "; + //std::cout << "tables: " << tc.tables << "\n"; + + int r = static_cast<int>(rng_->next() * tc.count_); + //std::cerr << "FOO: " << r << std::endl; + for (std::map<int,int>::iterator hit = tc.table_histogram_.begin(); + hit != tc.table_histogram_.end(); ++hit) { + r -= (hit->first * hit->second); + if (r <= 0) { + if (hit->first > 1) + tc.table_histogram_[hit->first-1] += 1; + else { + --delta; + --tc.tables_; + --total_tables_; + } + + --hit->second; + if (hit->second == 0) tc.table_histogram_.erase(hit); + break; + } + } + + assert(r <= 0); + + // remove the customer + --tc.count_; + --total_customers_; + assert(tc.count_ >= 0); + if (tc.count_ == 0) counts_.erase(i); + return delta; +} + +#endif diff --git a/gi/clda/src/dict.h b/gi/clda/src/dict.h new file mode 100644 index 00000000..72e82e6d --- /dev/null +++ b/gi/clda/src/dict.h @@ -0,0 +1,43 @@ +#ifndef DICT_H_ +#define DICT_H_ + +#include <cassert> +#include <cstring> +#include <tr1/unordered_map> +#include <string> +#include <vector> + +#include <boost/functional/hash.hpp> + +#include "wordid.h" + +class Dict { + typedef std::tr1::unordered_map<std::string, WordID, boost::hash<std::string> > Map; + public: + Dict() : b0_("<bad0>") { words_.reserve(1000); } + inline int max() const { return words_.size(); } + inline WordID Convert(const std::string& word, bool frozen = false) { + Map::iterator i = d_.find(word); + if (i == d_.end()) { + if (frozen) + return 0; + words_.push_back(word); + d_[word] = words_.size(); + return words_.size(); + } else { + return i->second; + } + } + inline const std::string& Convert(const WordID& id) const { + if (id == 0) return b0_; + assert(id <= words_.size()); + return words_[id-1]; + } + void clear() { words_.clear(); d_.clear(); } + private: + const std::string b0_; + std::vector<std::string> words_; + Map d_; +}; + +#endif diff --git a/gi/clda/src/logval.h b/gi/clda/src/logval.h new file mode 100644 index 00000000..7099b9be --- /dev/null +++ b/gi/clda/src/logval.h @@ -0,0 +1,157 @@ +#ifndef LOGVAL_H_ +#define LOGVAL_H_ + +#include <iostream> +#include <cstdlib> +#include <cmath> +#include <limits> + +template <typename T> +class LogVal { + public: + LogVal() : s_(), v_(-std::numeric_limits<T>::infinity()) {} + explicit LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {} + static LogVal<T> One() { return LogVal(1); } + static LogVal<T> Zero() { return LogVal(); } + + void logeq(const T& v) { s_ = false; v_ = v; } + + LogVal& operator+=(const LogVal& a) { + if (a.v_ == -std::numeric_limits<T>::infinity()) return *this; + if (a.s_ == s_) { + if (a.v_ < v_) { + v_ = v_ + log1p(std::exp(a.v_ - v_)); + } else { + v_ = a.v_ + log1p(std::exp(v_ - a.v_)); + } + } else { + if (a.v_ < v_) { + v_ = v_ + log1p(-std::exp(a.v_ - v_)); + } else { + v_ = a.v_ + log1p(-std::exp(v_ - a.v_)); + s_ = !s_; + } + } + return *this; + } + + LogVal& operator*=(const LogVal& a) { + s_ = (s_ != a.s_); + v_ += a.v_; + return *this; + } + + LogVal& operator/=(const LogVal& a) { + s_ = (s_ != a.s_); + v_ -= a.v_; + return *this; + } + + LogVal& operator-=(const LogVal& a) { + LogVal b = a; + b.invert(); + return *this += b; + } + + LogVal& poweq(const T& power) { + if (s_) { + std::cerr << "poweq(T) not implemented when s_ is true\n"; + std::abort(); + } else { + v_ *= power; + } + return *this; + } + + void invert() { s_ = !s_; } + + LogVal pow(const T& power) const { + LogVal res = *this; + res.poweq(power); + return res; + } + + operator T() const { + if (s_) return -std::exp(v_); else return std::exp(v_); + } + + bool s_; + T v_; +}; + +template<typename T> +LogVal<T> operator+(const LogVal<T>& o1, const LogVal<T>& o2) { + LogVal<T> res(o1); + res += o2; + return res; +} + +template<typename T> +LogVal<T> operator*(const LogVal<T>& o1, const LogVal<T>& o2) { + LogVal<T> res(o1); + res *= o2; + return res; +} + +template<typename T> +LogVal<T> operator/(const LogVal<T>& o1, const LogVal<T>& o2) { + LogVal<T> res(o1); + res /= o2; + return res; +} + +template<typename T> +LogVal<T> operator-(const LogVal<T>& o1, const LogVal<T>& o2) { + LogVal<T> res(o1); + res -= o2; + return res; +} + +template<typename T> +T log(const LogVal<T>& o) { + if (o.s_) return log(-1.0); + return o.v_; +} + +template <typename T> +LogVal<T> pow(const LogVal<T>& b, const T& e) { + return b.pow(e); +} + +template <typename T> +bool operator<(const LogVal<T>& lhs, const LogVal<T>& rhs) { + if (lhs.s_ == rhs.s_) { + return (lhs.v_ < rhs.v_); + } else { + return lhs.s_ > rhs.s_; + } +} + +#if 0 +template <typename T> +bool operator<=(const LogVal<T>& lhs, const LogVal<T>& rhs) { + return (lhs.v_ <= rhs.v_); +} + +template <typename T> +bool operator>(const LogVal<T>& lhs, const LogVal<T>& rhs) { + return (lhs.v_ > rhs.v_); +} + +template <typename T> +bool operator>=(const LogVal<T>& lhs, const LogVal<T>& rhs) { + return (lhs.v_ >= rhs.v_); +} +#endif + +template <typename T> +bool operator==(const LogVal<T>& lhs, const LogVal<T>& rhs) { + return (lhs.v_ == rhs.v_) && (lhs.s_ == rhs.s_); +} + +template <typename T> +bool operator!=(const LogVal<T>& lhs, const LogVal<T>& rhs) { + return !(lhs == rhs); +} + +#endif diff --git a/gi/clda/src/prob.h b/gi/clda/src/prob.h new file mode 100644 index 00000000..bc297870 --- /dev/null +++ b/gi/clda/src/prob.h @@ -0,0 +1,8 @@ +#ifndef _PROB_H_ +#define _PROB_H_ + +#include "logval.h" + +typedef LogVal<double> prob_t; + +#endif diff --git a/gi/clda/src/sampler.h b/gi/clda/src/sampler.h new file mode 100644 index 00000000..4d0b2e64 --- /dev/null +++ b/gi/clda/src/sampler.h @@ -0,0 +1,138 @@ +#ifndef SAMPLER_H_ +#define SAMPLER_H_ + +#include <algorithm> +#include <functional> +#include <numeric> +#include <iostream> +#include <fstream> +#include <vector> + +#include <boost/random/mersenne_twister.hpp> +#include <boost/random/uniform_real.hpp> +#include <boost/random/variate_generator.hpp> +#include <boost/random/normal_distribution.hpp> +#include <boost/random/poisson_distribution.hpp> + +#include "prob.h" + +struct SampleSet; + +template <typename RNG> +struct RandomNumberGenerator { + static uint32_t GetTrulyRandomSeed() { + uint32_t seed; + std::ifstream r("/dev/urandom"); + if (r) { + r.read((char*)&seed,sizeof(uint32_t)); + } + if (r.fail() || !r) { + std::cerr << "Warning: could not read from /dev/urandom. Seeding from clock" << std::endl; + seed = time(NULL); + } + std::cerr << "Seeding random number sequence to " << seed << std::endl; + return seed; + } + + RandomNumberGenerator() : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) { + uint32_t seed = GetTrulyRandomSeed(); + m_generator.seed(seed); + } + explicit RandomNumberGenerator(uint32_t seed) : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) { + if (!seed) seed = GetTrulyRandomSeed(); + m_generator.seed(seed); + } + + size_t SelectSample(const prob_t& a, const prob_t& b, double T = 1.0) { + if (T == 1.0) { + if (this->next() > (a / (a + b))) return 1; else return 0; + } else { + assert(!"not implemented"); + } + } + + // T is the annealing temperature, if desired + size_t SelectSample(const SampleSet& ss, double T = 1.0); + + // draw a value from U(0,1) + double next() {return m_random();} + + // draw a value from N(mean,var) + double NextNormal(double mean, double var) { + return boost::normal_distribution<double>(mean, var)(m_random); + } + + // draw a value from a Poisson distribution + // lambda must be greater than 0 + int NextPoisson(int lambda) { + return boost::poisson_distribution<int>(lambda)(m_random); + } + + bool AcceptMetropolisHastings(const prob_t& p_cur, + const prob_t& p_prev, + const prob_t& q_cur, + const prob_t& q_prev) { + const prob_t a = (p_cur / p_prev) * (q_prev / q_cur); + if (log(a) >= 0.0) return true; + return (prob_t(this->next()) < a); + } + + private: + boost::uniform_real<> m_dist; + RNG m_generator; + boost::variate_generator<RNG&, boost::uniform_real<> > m_random; +}; + +typedef RandomNumberGenerator<boost::mt19937> MT19937; + +class SampleSet { + public: + const prob_t& operator[](int i) const { return m_scores[i]; } + prob_t& operator[](int i) { return m_scores[i]; } + bool empty() const { return m_scores.empty(); } + void add(const prob_t& s) { m_scores.push_back(s); } + void clear() { m_scores.clear(); } + size_t size() const { return m_scores.size(); } + void resize(int size) { m_scores.resize(size); } + std::vector<prob_t> m_scores; +}; + +template <typename RNG> +size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T) { + assert(T > 0.0); + assert(ss.m_scores.size() > 0); + if (ss.m_scores.size() == 1) return 0; + const prob_t annealing_factor(1.0 / T); + const bool anneal = (annealing_factor != prob_t::One()); + prob_t sum = prob_t::Zero(); + if (anneal) { + for (int i = 0; i < ss.m_scores.size(); ++i) + sum += ss.m_scores[i].pow(annealing_factor); // p^(1/T) + } else { + sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), prob_t::Zero()); + } + //for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ","; + //std::cerr << std::endl; + + prob_t random(this->next()); // random number between 0 and 1 + random *= sum; // scale with normalization factor + //std::cerr << "Random number " << random << std::endl; + + //now figure out which sample + size_t position = 1; + sum = ss.m_scores[0]; + if (anneal) { + sum.poweq(annealing_factor); + for (; position < ss.m_scores.size() && sum < random; ++position) + sum += ss.m_scores[position].pow(annealing_factor); + } else { + for (; position < ss.m_scores.size() && sum < random; ++position) + sum += ss.m_scores[position]; + } + //std::cout << "random: " << random << " sample: " << position << std::endl; + //std::cerr << "Sample: " << position-1 << std::endl; + //exit(1); + return position-1; +} + +#endif diff --git a/gi/clda/src/tdict.h b/gi/clda/src/tdict.h new file mode 100644 index 00000000..97f145a1 --- /dev/null +++ b/gi/clda/src/tdict.h @@ -0,0 +1,49 @@ +#ifndef _TDICT_H_ +#define _TDICT_H_ + +#include <string> +#include <vector> +#include "wordid.h" +#include "dict.h" + +class Vocab; + +struct TD { + + static Dict dict_; + static std::string empty; + static std::string space; + + static std::string GetString(const std::vector<WordID>& str) { + std::string res; + for (std::vector<WordID>::const_iterator i = str.begin(); i != str.end(); ++i) + res += (i == str.begin() ? empty : space) + TD::Convert(*i); + return res; + } + + static void ConvertSentence(const std::string& sent, std::vector<WordID>* ids) { + std::string s = sent; + int last = 0; + ids->clear(); + for (int i=0; i < s.size(); ++i) + if (s[i] == 32 || s[i] == '\t') { + s[i]=0; + if (last != i) { + ids->push_back(Convert(&s[last])); + } + last = i + 1; + } + if (last != s.size()) + ids->push_back(Convert(&s[last])); + } + + static WordID Convert(const std::string& s) { + return dict_.Convert(s); + } + + static const std::string& Convert(const WordID& w) { + return dict_.Convert(w); + } +}; + +#endif diff --git a/gi/clda/src/timer.h b/gi/clda/src/timer.h new file mode 100644 index 00000000..ca26b304 --- /dev/null +++ b/gi/clda/src/timer.h @@ -0,0 +1,18 @@ +#ifndef _TIMER_STATS_H_ +#define _TIMER_STATS_H_ + +struct Timer { + Timer() { Reset(); } + void Reset() { + start_t = clock(); + } + double Elapsed() const { + const clock_t end_t = clock(); + const double elapsed = (end_t - start_t) / 1000000.0; + return elapsed; + } + private: + clock_t start_t; +}; + +#endif diff --git a/gi/clda/src/wordid.h b/gi/clda/src/wordid.h new file mode 100644 index 00000000..fb50bcc1 --- /dev/null +++ b/gi/clda/src/wordid.h @@ -0,0 +1,6 @@ +#ifndef _WORD_ID_H_ +#define _WORD_ID_H_ + +typedef int WordID; + +#endif |