summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-03-05 16:06:45 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-03-05 16:06:45 -0500
commit4c007d48d5829233d0ae3c3c8b48f8c25631bf81 (patch)
treede540fa94cd96ac3721f52e3c9095bd2036b19b3
parent1d5a0055a948663d799b4c5b1380ce1d9742bf6b (diff)
use template parameter inference to figure out what type to use for probability computations, templatatize number of floors in MFCR rather than compile-time set
-rw-r--r--gi/pf/align-lexonly-pyp.cc20
-rw-r--r--gi/pf/conditional_pseg.h22
-rw-r--r--gi/pf/learn_cfg.cc8
-rw-r--r--utils/ccrp.h48
-rw-r--r--utils/mfcr.h68
-rw-r--r--utils/mfcr_test.cc10
6 files changed, 68 insertions, 108 deletions
diff --git a/gi/pf/align-lexonly-pyp.cc b/gi/pf/align-lexonly-pyp.cc
index 87f7f6b5..ac0590e0 100644
--- a/gi/pf/align-lexonly-pyp.cc
+++ b/gi/pf/align-lexonly-pyp.cc
@@ -68,7 +68,7 @@ struct AlignedSentencePair {
struct HierarchicalWordBase {
explicit HierarchicalWordBase(const unsigned vocab_e_size) :
- base(prob_t::One()), r(1,1,1,25,25), u0(-log(vocab_e_size)), l(1,1.0), v(1, 0.0) {}
+ base(prob_t::One()), r(1,1,1,1), u0(-log(vocab_e_size)), l(1,prob_t::One()), v(1, prob_t::Zero()) {}
void ResampleHyperparameters(MT19937* rng) {
r.resample_hyperparameters(rng);
@@ -80,14 +80,14 @@ struct HierarchicalWordBase {
// return p0 of rule.e_
prob_t operator()(const TRule& rule) const {
- v[0] = exp(logp0(rule.e_));
- return prob_t(r.prob(rule.e_, v, l));
+ v[0].logeq(logp0(rule.e_));
+ return r.prob(rule.e_, v.begin(), l.begin());
}
void Increment(const TRule& rule) {
- v[0] = exp(logp0(rule.e_));
- if (r.increment(rule.e_, v, l, &*prng).count) {
- base *= prob_t(v[0] * l[0]);
+ v[0].logeq(logp0(rule.e_));
+ if (r.increment(rule.e_, v.begin(), l.begin(), &*prng).count) {
+ base *= v[0] * l[0];
}
}
@@ -105,15 +105,15 @@ struct HierarchicalWordBase {
void Summary() const {
cerr << "NUMBER OF CUSTOMERS: " << r.num_customers() << " (d=" << r.discount() << ",s=" << r.strength() << ')' << endl;
- for (MFCR<vector<WordID> >::const_iterator it = r.begin(); it != r.end(); ++it)
+ for (MFCR<1,vector<WordID> >::const_iterator it = r.begin(); it != r.end(); ++it)
cerr << " " << it->second.total_dish_count_ << " (on " << it->second.table_counts_.size() << " tables)" << TD::GetString(it->first) << endl;
}
prob_t base;
- MFCR<vector<WordID> > r;
+ MFCR<1,vector<WordID> > r;
const double u0;
- const vector<double> l;
- mutable vector<double> v;
+ const vector<prob_t> l;
+ mutable vector<prob_t> v;
};
struct BasicLexicalAlignment {
diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h
index 86403d8d..ef73e332 100644
--- a/gi/pf/conditional_pseg.h
+++ b/gi/pf/conditional_pseg.h
@@ -17,13 +17,13 @@
template <typename ConditionalBaseMeasure>
struct MConditionalTranslationModel {
explicit MConditionalTranslationModel(ConditionalBaseMeasure& rcp0) :
- rp0(rcp0), lambdas(1, 1.0), p0s(1) {}
+ rp0(rcp0), lambdas(1, prob_t::One()), p0s(1) {}
void Summary() const {
std::cerr << "Number of conditioning contexts: " << r.size() << std::endl;
for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
std::cerr << TD::GetString(it->first) << " \t(d=" << it->second.discount() << ",s=" << it->second.strength() << ") --------------------------" << std::endl;
- for (MFCR<TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
+ for (MFCR<1,TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
std::cerr << " " << -1 << '\t' << i2->first << std::endl;
}
}
@@ -46,10 +46,10 @@ struct MConditionalTranslationModel {
int IncrementRule(const TRule& rule, MT19937* rng) {
RuleModelHash::iterator it = r.find(rule.f_);
if (it == r.end()) {
- it = r.insert(make_pair(rule.f_, MFCR<TRule>(1, 1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first;
+ it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first;
}
- p0s[0] = rp0(rule).as_float();
- TableCount delta = it->second.increment(rule, p0s, lambdas, rng);
+ p0s[0] = rp0(rule);
+ TableCount delta = it->second.increment(rule, p0s.begin(), lambdas.begin(), rng);
return delta.count;
}
@@ -57,10 +57,10 @@ struct MConditionalTranslationModel {
prob_t p;
RuleModelHash::const_iterator it = r.find(rule.f_);
if (it == r.end()) {
- p.logeq(log(rp0(rule)));
+ p = rp0(rule);
} else {
- p0s[0] = rp0(rule).as_float();
- p = prob_t(it->second.prob(rule, p0s, lambdas));
+ p0s[0] = rp0(rule);
+ p = it->second.prob(rule, p0s.begin(), lambdas.begin());
}
return p;
}
@@ -80,11 +80,11 @@ struct MConditionalTranslationModel {
const ConditionalBaseMeasure& rp0;
typedef std::tr1::unordered_map<std::vector<WordID>,
- MFCR<TRule>,
+ MFCR<1, TRule>,
boost::hash<std::vector<WordID> > > RuleModelHash;
RuleModelHash r;
- std::vector<double> lambdas;
- mutable std::vector<double> p0s;
+ std::vector<prob_t> lambdas;
+ mutable std::vector<prob_t> p0s;
};
template <typename ConditionalBaseMeasure>
diff --git a/gi/pf/learn_cfg.cc b/gi/pf/learn_cfg.cc
index bf157828..ed1772bf 100644
--- a/gi/pf/learn_cfg.cc
+++ b/gi/pf/learn_cfg.cc
@@ -127,20 +127,20 @@ struct HieroLMModel {
nts(num_nts, CCRP<TRule>(1,1,1,1)) {}
prob_t Prob(const TRule& r) const {
- return nts[nt_id_to_index[-r.lhs_]].probT<prob_t>(r, p0(r));
+ return nts[nt_id_to_index[-r.lhs_]].prob(r, p0(r));
}
inline prob_t p0(const TRule& r) const {
if (kHIERARCHICAL_PRIOR)
- return q0.probT<prob_t>(r, base(r));
+ return q0.prob(r, base(r));
else
return base(r);
}
int Increment(const TRule& r, MT19937* rng) {
- const int delta = nts[nt_id_to_index[-r.lhs_]].incrementT<prob_t>(r, p0(r), rng);
+ const int delta = nts[nt_id_to_index[-r.lhs_]].increment(r, p0(r), rng);
if (kHIERARCHICAL_PRIOR && delta)
- q0.incrementT<prob_t>(r, base(r), rng);
+ q0.increment(r, base(r), rng);
return delta;
// return x.increment(r);
}
diff --git a/utils/ccrp.h b/utils/ccrp.h
index 5f9db7a6..e24130ac 100644
--- a/utils/ccrp.h
+++ b/utils/ccrp.h
@@ -93,41 +93,8 @@ class CCRP {
}
// returns +1 or 0 indicating whether a new table was opened
- int increment(const Dish& dish, const double& p0, MT19937* rng) {
- DishLocations& loc = dish_locs_[dish];
- bool share_table = false;
- if (loc.total_dish_count_) {
- const double p_empty = (strength_ + num_tables_ * discount_) * p0;
- const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
- share_table = rng->SelectSample(p_empty, p_share);
- }
- if (share_table) {
- double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
- for (typename std::list<unsigned>::iterator ti = loc.table_counts_.begin();
- ti != loc.table_counts_.end(); ++ti) {
- r -= (*ti - discount_);
- if (r <= 0.0) {
- ++(*ti);
- break;
- }
- }
- if (r > 0.0) {
- std::cerr << "Serious error: r=" << r << std::endl;
- Print(&std::cerr);
- assert(r <= 0.0);
- }
- } else {
- loc.table_counts_.push_back(1u);
- ++num_tables_;
- }
- ++loc.total_dish_count_;
- ++num_customers_;
- return (share_table ? 0 : 1);
- }
-
- // returns +1 or 0 indicating whether a new table was opened
template <typename T>
- int incrementT(const Dish& dish, const T& p0, MT19937* rng) {
+ int increment(const Dish& dish, const T& p0, MT19937* rng) {
DishLocations& loc = dish_locs_[dish];
bool share_table = false;
if (loc.total_dish_count_) {
@@ -196,19 +163,8 @@ class CCRP {
}
}
- double prob(const Dish& dish, const double& p0) const {
- const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
- const double r = num_tables_ * discount_ + strength_;
- if (it == dish_locs_.end()) {
- return r * p0 / (num_customers_ + strength_);
- } else {
- return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * p0) /
- (num_customers_ + strength_);
- }
- }
-
template <typename T>
- T probT(const Dish& dish, const T& p0) const {
+ T prob(const Dish& dish, const T& p0) const {
const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
const T r = T(num_tables_ * discount_ + strength_);
if (it == dish_locs_.end()) {
diff --git a/utils/mfcr.h b/utils/mfcr.h
index aeaf599d..6cc0ebf1 100644
--- a/utils/mfcr.h
+++ b/utils/mfcr.h
@@ -8,6 +8,7 @@
#include <list>
#include <iostream>
#include <vector>
+#include <iterator>
#include <tr1/unordered_map>
#include <boost/functional/hash.hpp>
#include "sampler.h"
@@ -35,12 +36,11 @@ std::ostream& operator<<(std::ostream& o, const TableCount& tc) {
// referenced therein.
// http://www.aclweb.org/anthology/P/P09/P09-2085.pdf
//
-template <typename Dish, typename DishHash = boost::hash<Dish> >
+template <unsigned Floors, typename Dish, typename DishHash = boost::hash<Dish> >
class MFCR {
public:
- MFCR(unsigned num_floors, double d, double strength) :
- num_floors_(num_floors),
+ MFCR(double d, double strength) :
num_tables_(),
num_customers_(),
discount_(d),
@@ -50,8 +50,7 @@ class MFCR {
strength_prior_shape_(std::numeric_limits<double>::quiet_NaN()),
strength_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {}
- MFCR(unsigned num_floors, double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) :
- num_floors_(num_floors),
+ MFCR(double discount_strength, double discount_beta, double strength_shape, double strength_rate, double d = 0.9, double strength = 10.0) :
num_tables_(),
num_customers_(),
discount_(d),
@@ -111,22 +110,22 @@ class MFCR {
}
// returns (delta, floor) indicating whether a new table (delta) was opened and on which floor
- TableCount increment(const Dish& dish, const std::vector<double>& p0s, const std::vector<double>& lambdas, MT19937* rng) {
- assert(p0s.size() == num_floors_);
- assert(lambdas.size() == num_floors_);
-
+ template <class InputIterator, class InputIterator2>
+ TableCount increment(const Dish& dish, InputIterator p0s, InputIterator2 lambdas, MT19937* rng) {
DishLocations& loc = dish_locs_[dish];
// marg_p0 = marginal probability of opening a new table on any floor with label dish
- const double marg_p0 = std::inner_product(p0s.begin(), p0s.end(), lambdas.begin(), 0.0);
- assert(marg_p0 <= 1.0);
+ typedef typename std::iterator_traits<InputIterator>::value_type F;
+ const F marg_p0 = std::inner_product(p0s, p0s + Floors, lambdas, F(0.0));
+ assert(marg_p0 <= F(1.0001));
int floor = -1;
bool share_table = false;
if (loc.total_dish_count_) {
- const double p_empty = (strength_ + num_tables_ * discount_) * marg_p0;
- const double p_share = (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
+ const F p_empty = F(strength_ + num_tables_ * discount_) * marg_p0;
+ const F p_share = F(loc.total_dish_count_ - loc.table_counts_.size() * discount_);
share_table = rng->SelectSample(p_empty, p_share);
}
if (share_table) {
+ // this can be done with doubles since P0 (which may be tiny) is not involved
double r = rng->next() * (loc.total_dish_count_ - loc.table_counts_.size() * discount_);
for (typename std::list<TableCount>::iterator ti = loc.table_counts_.begin();
ti != loc.table_counts_.end(); ++ti) {
@@ -143,12 +142,18 @@ class MFCR {
assert(r <= 0.0);
}
} else { // sit at currently empty table -- must sample what floor
- double r = rng->next() * marg_p0;
- for (unsigned i = 0; i < p0s.size(); ++i) {
- r -= p0s[i] * lambdas[i];
- if (r <= 0.0) {
- floor = i;
- break;
+ if (Floors == 1) {
+ floor = 0;
+ } else {
+ F r = F(rng->next()) * marg_p0;
+ for (unsigned i = 0; i < Floors; ++i) {
+ r -= (*p0s) * (*lambdas);
+ ++p0s;
+ ++lambdas;
+ if (r <= F(0.0)) {
+ floor = i;
+ break;
+ }
}
}
assert(floor >= 0);
@@ -200,18 +205,18 @@ class MFCR {
return TableCount(delta, floor);
}
- double prob(const Dish& dish, const std::vector<double>& p0s, const std::vector<double>& lambdas) const {
- assert(p0s.size() == num_floors_);
- assert(lambdas.size() == num_floors_);
- const double marg_p0 = std::inner_product(p0s.begin(), p0s.end(), lambdas.begin(), 0.0);
- assert(marg_p0 <= 1.0);
+ template <class InputIterator, class InputIterator2>
+ typename std::iterator_traits<InputIterator>::value_type prob(const Dish& dish, InputIterator p0s, InputIterator2 lambdas) const {
+ typedef typename std::iterator_traits<InputIterator>::value_type F;
+ const F marg_p0 = std::inner_product(p0s, p0s + Floors, lambdas, F(0.0));
+ assert(marg_p0 <= F(1.0001));
const typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.find(dish);
- const double r = num_tables_ * discount_ + strength_;
+ const F r = F(num_tables_ * discount_ + strength_);
if (it == dish_locs_.end()) {
- return r * marg_p0 / (num_customers_ + strength_);
+ return r * marg_p0 / F(num_customers_ + strength_);
} else {
- return (it->second.total_dish_count_ - discount_ * it->second.table_counts_.size() + r * marg_p0) /
- (num_customers_ + strength_);
+ return (F(it->second.total_dish_count_ - discount_ * it->second.table_counts_.size()) + F(r * marg_p0)) /
+ F(num_customers_ + strength_);
}
}
@@ -303,7 +308,7 @@ class MFCR {
};
void Print(std::ostream* out) const {
- (*out) << "MFCR(d=" << discount_ << ",strength=" << strength_ << ") customers=" << num_customers_ << std::endl;
+ (*out) << "MFCR<" << Floors << ">(d=" << discount_ << ",strength=" << strength_ << ") customers=" << num_customers_ << std::endl;
for (typename std::tr1::unordered_map<Dish, DishLocations, DishHash>::const_iterator it = dish_locs_.begin();
it != dish_locs_.end(); ++it) {
(*out) << it->first << " (" << it->second.total_dish_count_ << " on " << it->second.table_counts_.size() << " tables): ";
@@ -323,7 +328,6 @@ class MFCR {
return dish_locs_.end();
}
- unsigned num_floors_;
unsigned num_tables_;
unsigned num_customers_;
std::tr1::unordered_map<Dish, DishLocations, DishHash> dish_locs_;
@@ -340,8 +344,8 @@ class MFCR {
double strength_prior_rate_;
};
-template <typename T,typename H>
-std::ostream& operator<<(std::ostream& o, const MFCR<T,H>& c) {
+template <unsigned N,typename T,typename H>
+std::ostream& operator<<(std::ostream& o, const MFCR<N,T,H>& c) {
c.Print(&o);
return o;
}
diff --git a/utils/mfcr_test.cc b/utils/mfcr_test.cc
index 7c45a37c..cc886335 100644
--- a/utils/mfcr_test.cc
+++ b/utils/mfcr_test.cc
@@ -9,7 +9,7 @@
using namespace std;
void test_exch(MT19937* rng) {
- MFCR<int> crp(2, 0.5, 3.0);
+ MFCR<2, int> crp(0.5, 3.0);
vector<double> lambdas(2);
vector<double> p0s(2);
lambdas[0] = 0.2;
@@ -22,23 +22,23 @@ void test_exch(MT19937* rng) {
double xt = 0;
int cust = 10;
vector<int> hist(cust + 1, 0), hist2(cust + 1, 0);
- for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); }
+ for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }
const int samples = 100000;
const bool simulate = true;
for (int k = 0; k < samples; ++k) {
if (!simulate) {
crp.clear();
- for (int i = 0; i < cust; ++i) { crp.increment(1, p0s, lambdas, rng); }
+ for (int i = 0; i < cust; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }
} else {
int da = rng->next() * cust;
bool a = rng->next() < 0.45;
if (a) {
- for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); }
+ for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }
for (int i = 0; i < da; ++i) { crp.decrement(1, rng); }
xt += 1.0;
} else {
for (int i = 0; i < da; ++i) { crp.decrement(1, rng); }
- for (int i = 0; i < da; ++i) { crp.increment(1, p0s, lambdas, rng); }
+ for (int i = 0; i < da; ++i) { crp.increment(1, p0s.begin(), lambdas.begin(), rng); }
}
}
int c = crp.num_tables(1);