From 4ebb11b25cf87dc5938b5eb65e884d0e3f4ee146 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 23 Jan 2012 15:47:29 -0500 Subject: more alignment stuff --- gi/pf/conditional_pseg.h | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) (limited to 'gi/pf/conditional_pseg.h') diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h index edcdc813..db951d15 100644 --- a/gi/pf/conditional_pseg.h +++ b/gi/pf/conditional_pseg.h @@ -8,10 +8,84 @@ #include "prob.h" #include "ccrp_nt.h" +#include "mfcr.h" #include "trule.h" #include "base_measures.h" #include "tdict.h" +template +struct MConditionalTranslationModel { + explicit MConditionalTranslationModel(ConditionalBaseMeasure& rcp0) : + rp0(rcp0), lambdas(1, 1.0), p0s(1) {} + + void Summary() const { + std::cerr << "Number of conditioning contexts: " << r.size() << std::endl; + for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) { + std::cerr << TD::GetString(it->first) << " \t(d=" << it->second.d() << ",\\alpha = " << it->second.alpha() << ") --------------------------" << std::endl; + for (MFCR::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) + std::cerr << " " << -1 << '\t' << i2->first << std::endl; + } + } + + void ResampleHyperparameters(MT19937* rng) { + for (RuleModelHash::iterator it = r.begin(); it != r.end(); ++it) + it->second.resample_hyperparameters(rng); + } + + int DecrementRule(const TRule& rule, MT19937* rng) { + RuleModelHash::iterator it = r.find(rule.f_); + assert(it != r.end()); + const TableCount delta = it->second.decrement(rule, rng); + if (delta.count) { + if (it->second.num_customers() == 0) r.erase(it); + } + return delta.count; + } + + int IncrementRule(const TRule& rule, MT19937* rng) { + RuleModelHash::iterator it = r.find(rule.f_); + if (it == r.end()) { + it = r.insert(make_pair(rule.f_, MFCR(1, 1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first; + } + p0s[0] = rp0(rule).as_float(); + TableCount delta = it->second.increment(rule, p0s, lambdas, rng); + return delta.count; + } + + prob_t RuleProbability(const TRule& rule) const { + prob_t p; + RuleModelHash::const_iterator it = r.find(rule.f_); + if (it == r.end()) { + p.logeq(log(rp0(rule))); + } else { + p0s[0] = rp0(rule).as_float(); + p = prob_t(it->second.prob(rule, p0s, lambdas)); + } + return p; + } + + prob_t Likelihood() const { + prob_t p = prob_t::One(); +#if 0 + for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) { + prob_t q; q.logeq(it->second.log_crp_prob()); + p *= q; + for (CCRP_NoTable::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) + p *= rp0(i2->first); + } +#endif + return p; + } + + const ConditionalBaseMeasure& rp0; + typedef std::tr1::unordered_map, + MFCR, + boost::hash > > RuleModelHash; + RuleModelHash r; + std::vector lambdas; + mutable std::vector p0s; +}; + template struct ConditionalTranslationModel { explicit ConditionalTranslationModel(ConditionalBaseMeasure& rcp0) : -- cgit v1.2.3