diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-03-05 16:06:45 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-03-05 16:06:45 -0500 |
commit | 2048ac9943e2695a75b5f0303ca869e66ee32202 (patch) | |
tree | 35de13bff0ce7f1fde619d04d165f7e6ac5b0631 /gi/pf/conditional_pseg.h | |
parent | ce58cb44771a5194b71682d1602abe2fef9e6f13 (diff) |
use template parameter inference to figure out what type to use for probability computations, templatatize number of floors in MFCR rather than compile-time set
Diffstat (limited to 'gi/pf/conditional_pseg.h')
-rw-r--r-- | gi/pf/conditional_pseg.h | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h index 86403d8d..ef73e332 100644 --- a/gi/pf/conditional_pseg.h +++ b/gi/pf/conditional_pseg.h @@ -17,13 +17,13 @@ template <typename ConditionalBaseMeasure> struct MConditionalTranslationModel { explicit MConditionalTranslationModel(ConditionalBaseMeasure& rcp0) : - rp0(rcp0), lambdas(1, 1.0), p0s(1) {} + rp0(rcp0), lambdas(1, prob_t::One()), p0s(1) {} void Summary() const { std::cerr << "Number of conditioning contexts: " << r.size() << std::endl; for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) { std::cerr << TD::GetString(it->first) << " \t(d=" << it->second.discount() << ",s=" << it->second.strength() << ") --------------------------" << std::endl; - for (MFCR<TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) + for (MFCR<1,TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) std::cerr << " " << -1 << '\t' << i2->first << std::endl; } } @@ -46,10 +46,10 @@ struct MConditionalTranslationModel { int IncrementRule(const TRule& rule, MT19937* rng) { RuleModelHash::iterator it = r.find(rule.f_); if (it == r.end()) { - it = r.insert(make_pair(rule.f_, MFCR<TRule>(1, 1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first; + it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first; } - p0s[0] = rp0(rule).as_float(); - TableCount delta = it->second.increment(rule, p0s, lambdas, rng); + p0s[0] = rp0(rule); + TableCount delta = it->second.increment(rule, p0s.begin(), lambdas.begin(), rng); return delta.count; } @@ -57,10 +57,10 @@ struct MConditionalTranslationModel { prob_t p; RuleModelHash::const_iterator it = r.find(rule.f_); if (it == r.end()) { - p.logeq(log(rp0(rule))); + p = rp0(rule); } else { - p0s[0] = rp0(rule).as_float(); - p = prob_t(it->second.prob(rule, p0s, lambdas)); + p0s[0] = rp0(rule); + p = it->second.prob(rule, p0s.begin(), lambdas.begin()); } return p; } @@ -80,11 +80,11 @@ struct MConditionalTranslationModel { const ConditionalBaseMeasure& rp0; typedef std::tr1::unordered_map<std::vector<WordID>, - MFCR<TRule>, + MFCR<1, TRule>, boost::hash<std::vector<WordID> > > RuleModelHash; RuleModelHash r; - std::vector<double> lambdas; - mutable std::vector<double> p0s; + std::vector<prob_t> lambdas; + mutable std::vector<prob_t> p0s; }; template <typename ConditionalBaseMeasure> |