summaryrefslogtreecommitdiff
path: root/gi/pf/conditional_pseg.h
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-03-05 16:06:45 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-03-05 16:06:45 -0500
commit4c007d48d5829233d0ae3c3c8b48f8c25631bf81 (patch)
treede540fa94cd96ac3721f52e3c9095bd2036b19b3 /gi/pf/conditional_pseg.h
parent1d5a0055a948663d799b4c5b1380ce1d9742bf6b (diff)
use template parameter inference to figure out what type to use for probability computations, templatatize number of floors in MFCR rather than compile-time set
Diffstat (limited to 'gi/pf/conditional_pseg.h')
-rw-r--r--gi/pf/conditional_pseg.h22
1 files changed, 11 insertions, 11 deletions
diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h
index 86403d8d..ef73e332 100644
--- a/gi/pf/conditional_pseg.h
+++ b/gi/pf/conditional_pseg.h
@@ -17,13 +17,13 @@
template <typename ConditionalBaseMeasure>
struct MConditionalTranslationModel {
explicit MConditionalTranslationModel(ConditionalBaseMeasure& rcp0) :
- rp0(rcp0), lambdas(1, 1.0), p0s(1) {}
+ rp0(rcp0), lambdas(1, prob_t::One()), p0s(1) {}
void Summary() const {
std::cerr << "Number of conditioning contexts: " << r.size() << std::endl;
for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
std::cerr << TD::GetString(it->first) << " \t(d=" << it->second.discount() << ",s=" << it->second.strength() << ") --------------------------" << std::endl;
- for (MFCR<TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
+ for (MFCR<1,TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
std::cerr << " " << -1 << '\t' << i2->first << std::endl;
}
}
@@ -46,10 +46,10 @@ struct MConditionalTranslationModel {
int IncrementRule(const TRule& rule, MT19937* rng) {
RuleModelHash::iterator it = r.find(rule.f_);
if (it == r.end()) {
- it = r.insert(make_pair(rule.f_, MFCR<TRule>(1, 1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first;
+ it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first;
}
- p0s[0] = rp0(rule).as_float();
- TableCount delta = it->second.increment(rule, p0s, lambdas, rng);
+ p0s[0] = rp0(rule);
+ TableCount delta = it->second.increment(rule, p0s.begin(), lambdas.begin(), rng);
return delta.count;
}
@@ -57,10 +57,10 @@ struct MConditionalTranslationModel {
prob_t p;
RuleModelHash::const_iterator it = r.find(rule.f_);
if (it == r.end()) {
- p.logeq(log(rp0(rule)));
+ p = rp0(rule);
} else {
- p0s[0] = rp0(rule).as_float();
- p = prob_t(it->second.prob(rule, p0s, lambdas));
+ p0s[0] = rp0(rule);
+ p = it->second.prob(rule, p0s.begin(), lambdas.begin());
}
return p;
}
@@ -80,11 +80,11 @@ struct MConditionalTranslationModel {
const ConditionalBaseMeasure& rp0;
typedef std::tr1::unordered_map<std::vector<WordID>,
- MFCR<TRule>,
+ MFCR<1, TRule>,
boost::hash<std::vector<WordID> > > RuleModelHash;
RuleModelHash r;
- std::vector<double> lambdas;
- mutable std::vector<double> p0s;
+ std::vector<prob_t> lambdas;
+ mutable std::vector<prob_t> p0s;
};
template <typename ConditionalBaseMeasure>