From 41446328cf06a64e729835719d99fef33ec59941 Mon Sep 17 00:00:00 2001 From: bothameister Date: Mon, 5 Jul 2010 23:31:35 +0000 Subject: migrating away from mt19937ar to Boost.Random - separate RNG instances used in various places git-svn-id: https://ws10smt.googlecode.com/svn/trunk@146 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/pyp-topics/src/pyp-topics.cc | 13 +++++++------ gi/pyp-topics/src/pyp-topics.hh | 23 +++++++++++++++++++++-- gi/pyp-topics/src/pyp.hh | 32 +++++++++++++++++++++++--------- gi/pyp-topics/src/train-contexts.cc | 8 +++----- gi/pyp-topics/src/train.cc | 8 +++----- 5 files changed, 57 insertions(+), 27 deletions(-) (limited to 'gi/pyp-topics/src') diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc index 2ad9d080..2b96816e 100644 --- a/gi/pyp-topics/src/pyp-topics.cc +++ b/gi/pyp-topics/src/pyp-topics.cc @@ -5,7 +5,6 @@ #endif #include "pyp-topics.hh" -//#include "mt19937ar.h" #include #include @@ -46,13 +45,13 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, { m_word_pyps.at(i).reserve(m_num_topics); for (int j=0; j(0.5, 1.0)); + m_word_pyps.at(i).push_back(new PYP(0.5, 1.0, m_seed)); } std::cerr << std::endl; m_document_pyps.reserve(corpus.num_documents()); for (int j=0; j(0.5, 1.0)); + m_document_pyps.push_back(new PYP(0.5, 1.0, m_seed)); m_topic_p0 = 1.0/m_num_topics; m_term_p0 = 1.0/corpus.num_types(); @@ -118,8 +117,10 @@ void PYPTopics::sample_corpus(const Corpus& corpus, int samples, int tmp; for (int i = corpus.num_documents()-1; i > 0; --i) { - int j = (int)(mt_genrand_real1() * i); - tmp = randomDocIndices[i]; + //i+1 since j \in [0,i] but rnd() \in [0,1) + int j = (int)(rnd() * (i+1)); + assert(j >= 0 && j <= i); + tmp = randomDocIndices[i]; randomDocIndices[i] = randomDocIndices[j]; randomDocIndices[j] = tmp; } @@ -258,7 +259,7 @@ int PYPTopics::sample(const DocumentId& doc, const Term& term) { sums.push_back(sum); } // Second pass: sample a topic - F cutoff = mt_genrand_res53() * sum; + F cutoff = rnd() * sum; for (int k=0; k #include #include + +#include +#include +#include + #include "pyp.hh" #include "corpus.hh" @@ -15,9 +20,12 @@ public: typedef double F; public: - PYPTopics(int num_topics, bool use_topic_pyp=false) + PYPTopics(int num_topics, bool use_topic_pyp=false, unsigned long seed = 0) : m_num_topics(num_topics), m_word_pyps(1), - m_topic_pyp(0.5,1.0), m_use_topic_pyp(use_topic_pyp) {} + m_topic_pyp(0.5,1.0,seed), m_use_topic_pyp(use_topic_pyp), + m_seed(seed), + uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed), + rnd(rng, uni_dist) {} void sample_corpus(const Corpus& corpus, int samples, int freq_cutoff_start=0, int freq_cutoff_end=0, @@ -60,6 +68,17 @@ private: PYP m_topic_pyp; bool m_use_topic_pyp; + unsigned long m_seed; + + typedef boost::mt19937 base_generator_type; + typedef boost::uniform_real<> uni_dist_type; + typedef boost::variate_generator gen_type; + + uni_dist_type uni_dist; + base_generator_type rng; //this gets the seed + gen_type rnd; //instantiate: rnd(rng, uni_dist) + //call: rnd() generates uniform on [0,1) + TermBackoffPtr m_backoff; }; diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh index 80c79fe1..64fb5b58 100644 --- a/gi/pyp-topics/src/pyp.hh +++ b/gi/pyp-topics/src/pyp.hh @@ -5,10 +5,13 @@ #include #include +#include +#include +#include + #include "log_add.h" #include "gammadist.h" #include "slice-sampler.h" -#include "mt19937ar.h" // // Pitman-Yor process with customer and table tracking @@ -23,7 +26,7 @@ public: using std::tr1::unordered_map::begin; using std::tr1::unordered_map::end; - PYP(double a, double b, Hash hash=Hash()); + PYP(double a, double b, unsigned long seed = 0, Hash hash=Hash()); int increment(Dish d, double p0); int decrement(Dish d); @@ -80,6 +83,16 @@ private: DishTableType _dish_tables; int _total_customers, _total_tables; + typedef boost::mt19937 base_generator_type; + typedef boost::uniform_real<> uni_dist_type; + typedef boost::variate_generator gen_type; + + uni_dist_type uni_dist; + base_generator_type rng; //this gets the seed + gen_type rnd; //instantiate: rnd(rng, uni_dist) + //call: rnd() generates uniform on [0,1) + + // Function objects for calculating the parts of the log_prob for // the parameters a and b struct resample_a_type { @@ -122,11 +135,12 @@ private: }; template -PYP::PYP(double a, double b, Hash) +PYP::PYP(double a, double b, unsigned long seed, Hash) : std::tr1::unordered_map(), _a(a), _b(b), _a_beta_a(1), _a_beta_b(1), _b_gamma_s(1), _b_gamma_c(1), //_a_beta_a(1), _a_beta_b(1), _b_gamma_s(10), _b_gamma_c(0.1), - _total_customers(0), _total_tables(0) + _total_customers(0), _total_tables(0), + uni_dist(0,1), rng(seed == 0 ? (unsigned long)this : seed), rnd(rng, uni_dist) { // std::cerr << "\t##PYP::PYP(a=" << _a << ",b=" << _b << ")" << std::endl; } @@ -211,7 +225,7 @@ PYP::increment(Dish dish, double p0) { assert (pshare >= 0.0); //assert (pnew > 0.0); - if (mt_genrand_res53() < pnew / (pshare + pnew)) { + if (rnd() < pnew / (pshare + pnew)) { // assign to a new table tc.tables += 1; tc.table_histogram[1] += 1; @@ -221,7 +235,7 @@ PYP::increment(Dish dish, double p0) { else { // randomly assign to an existing table // remove constant denominator from inner loop - double r = mt_genrand_res53() * (c - _a*t); + double r = rnd() * (c - _a*t); for (std::map::iterator hit = tc.table_histogram.begin(); hit != tc.table_histogram.end(); ++hit) { @@ -283,7 +297,7 @@ PYP::decrement(Dish dish) //std::cerr << "count: " << count(dish) << " "; //std::cerr << "tables: " << tc.tables << "\n"; - double r = mt_genrand_res53() * count(dish); + double r = rnd() * count(dish); for (std::map::iterator hit = tc.table_histogram.begin(); hit != tc.table_histogram.end(); ++hit) { @@ -467,7 +481,7 @@ PYP::resample_prior_b() { int niterations = 10; // number of resampling iterations //std::cerr << "\n## resample_prior_b(), initial a = " << _a << ", b = " << _b << std::endl; resample_b_type b_log_prob(_total_customers, _total_tables, _a, _b_gamma_c, _b_gamma_s); - _b = slice_sampler1d(b_log_prob, _b, mt_genrand_res53, (double) 0.0, std::numeric_limits::infinity(), + _b = slice_sampler1d(b_log_prob, _b, rnd, (double) 0.0, std::numeric_limits::infinity(), (double) 0.0, niterations, 100*niterations); //std::cerr << "\n## resample_prior_b(), final a = " << _a << ", b = " << _b << std::endl; } @@ -481,7 +495,7 @@ PYP::resample_prior_a() { int niterations = 10; //std::cerr << "\n## Initial a = " << _a << ", b = " << _b << std::endl; resample_a_type a_log_prob(_total_customers, _total_tables, _b, _a_beta_a, _a_beta_b, _dish_tables); - _a = slice_sampler1d(a_log_prob, _a, mt_genrand_res53, std::numeric_limits::min(), + _a = slice_sampler1d(a_log_prob, _a, rnd, std::numeric_limits::min(), (double) 1.0, (double) 0.0, niterations, 100*niterations); } diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc index 481f8926..8a0c8949 100644 --- a/gi/pyp-topics/src/train-contexts.cc +++ b/gi/pyp-topics/src/train-contexts.cc @@ -14,7 +14,6 @@ #include "corpus.hh" #include "contexts_corpus.hh" #include "gzstream.hh" -#include "mt19937ar.h" static const char *REVISION = "$Rev$"; @@ -78,10 +77,9 @@ int main(int argc, char **argv) return 1; } - // seed the random number generator - //mt_init_genrand(time(0)); - - PYPTopics model(vm["topics"].as(), vm.count("hierarchical-topics")); + // seed the random number generator: 0 = automatic, specify value otherwise + unsigned long seed = 0; + PYPTopics model(vm["topics"].as(), vm.count("hierarchical-topics"), seed); // read the data BackoffGenerator* backoff_gen=0; diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc index c94010f2..3462f26c 100644 --- a/gi/pyp-topics/src/train.cc +++ b/gi/pyp-topics/src/train.cc @@ -12,7 +12,6 @@ #include "corpus.hh" #include "contexts_corpus.hh" #include "gzstream.hh" -#include "mt19937ar.h" static const char *REVISION = "$Rev$"; @@ -69,10 +68,9 @@ int main(int argc, char **argv) return 1; } - // seed the random number generator - //mt_init_genrand(time(0)); - - PYPTopics model(vm["topics"].as()); + // seed the random number generator: 0 = automatic, specify value otherwise + unsigned long seed = 0; + PYPTopics model(vm["topics"].as(), false, seed); // read the data Corpus corpus; -- cgit v1.2.3