From a144fb07effc59a3aa269d7fd5f3d0ab9dfe5e54 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 3 Jan 2012 16:59:11 -0500 Subject: multi-floor chinese restaurant described by wood&teh (2009) --- utils/mfcr_test.cc | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 utils/mfcr_test.cc (limited to 'utils/mfcr_test.cc') diff --git a/utils/mfcr_test.cc b/utils/mfcr_test.cc new file mode 100644 index 00000000..7c45a37c --- /dev/null +++ b/utils/mfcr_test.cc @@ -0,0 +1,72 @@ +#include "mfcr.h" + +#include +#include +#include + +#include "sampler.h" + +using namespace std; + +void test_exch(MT19937* rng) { + MFCR crp(2, 0.5, 3.0); + vector lambdas(2); + vector p0s(2); + lambdas[0] = 0.2; + lambdas[1] = 0.8; + p0s[0] = 1.0; + p0s[1] = 1.0; + + double tot = 0; + double tot2 = 0; + double xt = 0; + int cust = 10; + vector hist(cust + 1, 0), hist2(cust + 1, 0); + for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); } + const int samples = 100000; + const bool simulate = true; + for (int k = 0; k < samples; ++k) { + if (!simulate) { + crp.clear(); + for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); } + } else { + int da = rng->next() * cust; + bool a = rng->next() < 0.45; + if (a) { + for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); } + for (int i = 0; i < da; ++i) { crp.decrement(1, rng); } + xt += 1.0; + } else { + for (int i = 0; i < da; ++i) { crp.decrement(1, rng); } + for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); } + } + } + int c = crp.num_tables(1); + ++hist[c]; + tot += c; + int c2 = crp.num_tables(1,0); // tables on floor 0 with dish 1 + ++hist2[c2]; + tot2 += c2; + } + cerr << cust << " = " << crp.num_customers() << endl; + cerr << "P(a) = " << (xt / samples) << endl; + cerr << "E[num tables] = " << (tot / samples) << endl; + double error = fabs((tot / samples) - 6.894); + cerr << " error = " << error << endl; + for (int i = 1; i <= cust; ++i) + cerr << i << ' ' << (hist[i]) << endl; + cerr << "E[num tables on floor 0] = " << (tot2 / samples) << endl; + double error2 = fabs((tot2 / samples) - 1.379); + cerr << " error2 = " << error2 << endl; + for (int i = 1; i <= cust; ++i) + cerr << i << ' ' << (hist2[i]) << endl; + assert(error < 0.05); // these can fail with very low probability + assert(error2 < 0.05); +}; + +int main(int argc, char** argv) { + MT19937 rng; + test_exch(&rng); + return 0; +} + -- cgit v1.2.3 From 2048ac9943e2695a75b5f0303ca869e66ee32202 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 5 Mar 2012 16:06:45 -0500 Subject: use template parameter inference to figure out what type to use for probability computations, templatatize number of floors in MFCR rather than compile-time set --- gi/pf/align-lexonly-pyp.cc | 20 +++++++------- gi/pf/conditional_pseg.h | 22 +++++++-------- gi/pf/learn_cfg.cc | 8 +++--- utils/ccrp.h | 48 ++------------------------------ utils/mfcr.h | 68 ++++++++++++++++++++++++---------------------- utils/mfcr_test.cc | 10 +++---- 6 files changed, 68 insertions(+), 108 deletions(-) (limited to 'utils/mfcr_test.cc') diff --git a/gi/pf/align-lexonly-pyp.cc b/gi/pf/align-lexonly-pyp.cc index 87f7f6b5..ac0590e0 100644 --- a/gi/pf/align-lexonly-pyp.cc +++ b/gi/pf/align-lexonly-pyp.cc @@ -68,7 +68,7 @@ struct AlignedSentencePair { struct HierarchicalWordBase { explicit HierarchicalWordBase(const unsigned vocab_e_size) : - base(prob_t::One()), r(1,1,1,25,25), u0(-log(vocab_e_size)), l(1,1.0), v(1, 0.0) {} + base(prob_t::One()), r(1,1,1,1), u0(-log(vocab_e_size)), l(1,prob_t::One()), v(1, prob_t::Zero()) {} void ResampleHyperparameters(MT19937* rng) { r.resample_hyperparameters(rng); @@ -80,14 +80,14 @@ struct HierarchicalWordBase { // return p0 of rule.e_ prob_t operator()(const TRule& rule) const { - v[0] = exp(logp0(rule.e_)); - return prob_t(r.prob(rule.e_, v, l)); + v[0].logeq(logp0(rule.e_)); + return r.prob(rule.e_, v.begin(), l.begin()); } void Increment(const TRule& rule) { - v[0] = exp(logp0(rule.e_)); - if (r.increment(rule.e_, v, l, &*prng).count) { - base *= prob_t(v[0] * l[0]); + v[0].logeq(logp0(rule.e_)); + if (r.increment(rule.e_, v.begin(), l.begin(), &*prng).count) { + base *= v[0] * l[0]; } } @@ -105,15 +105,15 @@ struct HierarchicalWordBase { void Summary() const { cerr << "NUMBER OF CUSTOMERS: " << r.num_customers() << " (d=" << r.discount() << ",s=" << r.strength() << ')' << endl; - for (MFCR >::const_iterator it = r.begin(); it != r.end(); ++it) + for (MFCR<1,vector >::const_iterator it = r.begin(); it != r.end(); ++it) cerr << " " << it->second.total_dish_count_ << " (on " << it->second.table_counts_.size() << " tables)" << TD::GetString(it->first) << endl; } prob_t base; - MFCR > r; + MFCR<1,vector > r; const double u0; - const vector l; - mutable vector v; + const vector l; + mutable vector v; }; struct BasicLexicalAlignment { diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h index 86403d8d..ef73e332 100644 --- a/gi/pf/conditional_pseg.h +++ b/gi/pf/conditional_pseg.h @@ -17,13 +17,13 @@ template struct MConditionalTranslationModel { explicit MConditionalTranslationModel(ConditionalBaseMeasure& rcp0) : - rp0(rcp0), lambdas(1, 1.0), p0s(1) {} + rp0(rcp0), lambdas(1, prob_t::One()), p0s(1) {} void Summary() const { std::cerr << "Number of conditioning contexts: " << r.size() << std::endl; for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) { std::cerr << TD::GetString(it->first) << " \t(d=" << it->second.discount() << ",s=" << it->second.strength() << ") --------------------------" << std::endl; - for (MFCR::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) + for (MFCR<1,TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) std::cerr << " " << -1 << '\t' << i2->first << std::endl; } } @@ -46,10 +46,10 @@ struct MConditionalTranslationModel { int IncrementRule(const TRule& rule, MT19937* rng) { RuleModelHash::iterator it = r.find(rule.f_); if (it == r.end()) { - it = r.insert(make_pair(rule.f_, MFCR(1, 1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first; + it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first; } - p0s[0] = rp0(rule).as_float(); - TableCount delta = it->second.increment(rule, p0s, lambdas, rng); + p0s[0] = rp0(rule); + TableCount delta = it->second.increment(rule, p0s.begin(), lambdas.begin(), rng); return delta.count; } @@ -57,10 +57,10 @@ struct MConditionalTranslationModel { prob_t p; RuleModelHash::const_iterator it = r.find(rule.f_); if (it == r.end()) { - p.logeq(log(rp0(rule))); + p = rp0(rule); } else { - p0s[0] = rp0(rule).as_float(); - p = prob_t(it->second.prob(rule, p0s, lambdas)); + p0s[0] = rp0(rule); + p = it->second.prob(rule, p0s.begin(), lambdas.begin()); } return p; } @@ -80,11 +80,11 @@ struct MConditionalTranslationModel { const ConditionalBaseMeasure& rp0; typedef std::tr1::unordered_map, - MFCR, + MFCR<1, TRule>, boost::hash > > RuleModelHash; RuleModelHash r; - std::vector lambdas; - mutable std::vector p0s; + std::vector lambdas; + mutable std::vector p0s; }; template diff --git a/gi/pf/learn_cfg.cc b/gi/pf/learn_cfg.cc index bf157828..ed1772bf 100644 --- a/gi/pf/learn_cfg.cc +++ b/gi/pf/learn_cfg.cc @@ -127,20 +127,20 @@ struct HieroLMModel { nts(num_nts, CCRP(1,1,1,1)) {} prob_t Prob(const TRule& r) const { - return nts[nt_id_to_index[-r.lhs_]].probT(r, p0(r)); + return nts[nt_id_to_index[-r.lhs_]].prob(r, p0(r)); } inline prob_t p0(const TRule& r) const { if (kHIERARCHICAL_PRIOR) - return q0.probT(r, base(r)); + return q0.prob(r, base(r)); else return base(r); } int Increment(const TRule& r, MT19937* rng) { - const int delta = nts[nt_id_to_index[-r.lhs_]].incrementT(r, p0(r), rng); + const int delta = nts[nt_id_to_index[-r.lhs_]].increment(r, p0(r), rng); if (kHIERARCHICAL_PRIOR && delta) - q0.incrementT(r, base(r), rng); + q0.increment(r, base(r), rng); return delta; // return x.increment(r); } diff --git a/utils/ccrp.h b/utils/ccrp.h index 5f9db7a6..e24130ac 100644 --- a/utils/ccrp.h +++ b/utils/ccrp.h @@ -92,42 +92,9 @@ class CCRP { return it->total_dish_count_; } - // returns +1 or 0 indicating whether a new table was opened - int increment(const Dish& dish, const double& p0, MT19937* rng) { - DishLocations& loc = dish_locs_[dish]; - bool share_table = false; - if (loc.total_dish_count_) { - const double p_empty = (strength_ + num_tables_ * discount_) * p0; - const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_); - share_table = rng->SelectSample(p_empty, p_share); - } - if (share_table) { - double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); - for (typename std::list::iterator ti = loc.table_counts_.begin(); - ti != loc.table_counts_.end(); ++ti) { - r -= (*ti - discount_); - if (r <= 0.0) { - ++(*ti); - break; - } - } - if (r > 0.0) { - std::cerr << "Serious error: r=" << r << std::endl; - Print(&std::cerr); - assert(r <= 0.0); - } - } else { - loc.table_counts_.push_back(1u); - ++num_tables_; - } - ++loc.total_dish_count_; - ++num_customers_; - return (share_table ? 0 : 1); - } - // returns +1 or 0 indicating whether a new table was opened template - int incrementT(const Dish& dish, const T& p0, MT19937* rng) { + int increment(const Dish& dish, const T& p0, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; bool share_table = false; if (loc.total_dish_count_) { @@ -196,19 +163,8 @@ class CCRP { } } - double prob(const Dish& dish, const double& p0) const { - const typename std::tr1::unordered_map::const_iterator it = dish_locs_.find(dish); - const double r = num_tables_ * discount_ + strength_; - if (it == dish_locs_.end()) { - return r * p0 / (num_customers_ + strength_); - } else { - return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) / - (num_customers_ + strength_); - } - } - template - T probT(const Dish& dish, const T& p0) const { + T prob(const Dish& dish, const T& p0) const { const typename std::tr1::unordered_map::const_iterator it = dish_locs_.find(dish); const T r = T(num_tables_ * discount_ + strength_); if (it == dish_locs_.end()) { diff --git a/utils/mfcr.h b/utils/mfcr.h index aeaf599d..6cc0ebf1 100644 --- a/utils/mfcr.h +++ b/utils/mfcr.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include "sampler.h" @@ -35,12 +36,11 @@ std::ostream& operator<<(std::ostream& o, const TableCount& tc) { // referenced therein. // http://www.aclweb.org/anthology/P/P09/P09-2085.pdf // -template > +template > class MFCR { public: - MFCR(unsigned num_floors, double d, double strength) : - num_floors_(num_floors), + MFCR(double d, double strength) : num_tables_(), num_customers_(), discount_(d), @@ -50,8 +50,7 @@ class MFCR { strength_prior_shape_(std::numeric_limits::quiet_NaN()), strength_prior_rate_(std::numeric_limits::quiet_NaN()) {} - MFCR(unsigned num_floors, double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) : - num_floors_(num_floors), + MFCR(double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) : num_tables_(), num_customers_(), discount_(d), @@ -111,22 +110,22 @@ class MFCR { } // returns (delta, floor) indicating whether a new table (delta) was opened and on which floor - TableCount increment(const Dish& dish, const std::vector& p0s, const std::vector& lambdas, MT19937* rng) { - assert(p0s.size() == num_floors_); - assert(lambdas.size() == num_floors_); - + template + TableCount increment(const Dish& dish, InputIterator p0s, InputIterator2 lambdas, MT19937* rng) { DishLocations& loc = dish_locs_[dish]; // marg_p0 = marginal probability of opening a new table on any floor with label dish - const double marg_p0 = std::inner_product(p0s.begin(), p0s.end(), lambdas.begin(), 0.0); - assert(marg_p0 <= 1.0); + typedef typename std::iterator_traits::value_type F; + const F marg_p0 = std::inner_product(p0s, p0s + Floors, lambdas, F(0.0)); + assert(marg_p0 <= F(1.0001)); int floor = -1; bool share_table = false; if (loc.total_dish_count_) { - const double p_empty = (strength_ + num_tables_ * discount_) * marg_p0; - const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_); + const F p_empty = F(strength_ + num_tables_ * discount_) * marg_p0; + const F p_share = F(loc.total_dish_count_ - loc.table_counts_.size() * discount_); share_table = rng->SelectSample(p_empty, p_share); } if (share_table) { + // this can be done with doubles since P0 (which may be tiny) is not involved double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_); for (typename std::list::iterator ti = loc.table_counts_.begin(); ti != loc.table_counts_.end(); ++ti) { @@ -143,12 +142,18 @@ class MFCR { assert(r <= 0.0); } } else { // sit at currently empty table -- must sample what floor - double r = rng->next() * marg_p0; - for (unsigned i = 0; i < p0s.size(); ++i) { - r -= p0s[i] * lambdas[i]; - if (r <= 0.0) { - floor = i; - break; + if (Floors == 1) { + floor = 0; + } else { + F r = F(rng->next()) * marg_p0; + for (unsigned i = 0; i < Floors; ++i) { + r -= (*p0s) * (*lambdas); + ++p0s; + ++lambdas; + if (r <= F(0.0)) { + floor = i; + break; + } } } assert(floor >= 0); @@ -200,18 +205,18 @@ class MFCR { return TableCount(delta, floor); } - double prob(const Dish& dish, const std::vector& p0s, const std::vector& lambdas) const { - assert(p0s.size() == num_floors_); - assert(lambdas.size() == num_floors_); - const double marg_p0 = std::inner_product(p0s.begin(), p0s.end(), lambdas.begin(), 0.0); - assert(marg_p0 <= 1.0); + template + typename std::iterator_traits::value_type prob(const Dish& dish, InputIterator p0s, InputIterator2 lambdas) const { + typedef typename std::iterator_traits::value_type F; + const F marg_p0 = std::inner_product(p0s, p0s + Floors, lambdas, F(0.0)); + assert(marg_p0 <= F(1.0001)); const typename std::tr1::unordered_map::const_iterator it = dish_locs_.find(dish); - const double r = num_tables_ * discount_ + strength_; + const F r = F(num_tables_ * discount_ + strength_); if (it == dish_locs_.end()) { - return r * marg_p0 / (num_customers_ + strength_); + return r * marg_p0 / F(num_customers_ + strength_); } else { - return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * marg_p0) / - (num_customers_ + strength_); + return (F(it->second.total_dish_count_ - discount_ * it->second.table_counts_.size()) + F(r * marg_p0)) / + F(num_customers_ + strength_); } } @@ -303,7 +308,7 @@ class MFCR { }; void Print(std::ostream* out) const { - (*out) << "MFCR(d=" << discount_ << ",strength=" << strength_ << ") customers=" << num_customers_ << std::endl; + (*out) << "MFCR<" << Floors << ">(d=" << discount_ << ",strength=" << strength_ << ") customers=" << num_customers_ << std::endl; for (typename std::tr1::unordered_map::const_iterator it = dish_locs_.begin(); it != dish_locs_.end(); ++it) { (*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): "; @@ -323,7 +328,6 @@ class MFCR { return dish_locs_.end(); } - unsigned num_floors_; unsigned num_tables_; unsigned num_customers_; std::tr1::unordered_map dish_locs_; @@ -340,8 +344,8 @@ class MFCR { double strength_prior_rate_; }; -template -std::ostream& operator<<(std::ostream& o, const MFCR& c) { +template +std::ostream& operator<<(std::ostream& o, const MFCR& c) { c.Print(&o); return o; } diff --git a/utils/mfcr_test.cc b/utils/mfcr_test.cc index 7c45a37c..cc886335 100644 --- a/utils/mfcr_test.cc +++ b/utils/mfcr_test.cc @@ -9,7 +9,7 @@ using namespace std; void test_exch(MT19937* rng) { - MFCR crp(2, 0.5, 3.0); + MFCR<2, int> crp(0.5, 3.0); vector lambdas(2); vector p0s(2); lambdas[0] = 0.2; @@ -22,23 +22,23 @@ void test_exch(MT19937* rng) { double xt = 0; int cust = 10; vector hist(cust + 1, 0), hist2(cust + 1, 0); - for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); } + for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); } const int samples = 100000; const bool simulate = true; for (int k = 0; k < samples; ++k) { if (!simulate) { crp.clear(); - for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); } + for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); } } else { int da = rng->next() * cust; bool a = rng->next() < 0.45; if (a) { - for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); } + for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); } for (int i = 0; i < da; ++i) { crp.decrement(1, rng); } xt += 1.0; } else { for (int i = 0; i < da; ++i) { crp.decrement(1, rng); } - for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); } + for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); } } } int c = crp.num_tables(1); -- cgit v1.2.3