summaryrefslogtreecommitdiff
path: root/gi/pf/conditional_pseg.h
diff options
context:
space:
mode:
author: Chris Dyer <cdyer@cs.cmu.edu>, 2012-01-23 15:47:29 -0500
committer: Chris Dyer <cdyer@cs.cmu.edu>, 2012-01-23 15:47:29 -0500
commit: 4ebb11b25cf87dc5938b5eb65e884d0e3f4ee146 (patch)
tree: 69966f7d05dff15742e43698d004c183646b2d98 /gi/pf/conditional_pseg.h
parent: 5f998b1d600a34f95a5293522167394d3dd37bf6 (diff)
more alignment stuff
Diffstat (limited to 'gi/pf/conditional_pseg.h')
-rw-r--r-- gi/pf/conditional_pseg.h 74
1 files changed, 74 insertions, 0 deletions
diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h
index edcdc813..db951d15 100644
--- a/gi/pf/conditional_pseg.h
+++ b/gi/pf/conditional_pseg.h
@@ -8,11 +8,85 @@
#include "prob.h"
#include "ccrp_nt.h"
+#include "mfcr.h"
#include "trule.h"
#include "base_measures.h"
#include "tdict.h"
template <typename ConditionalBaseMeasure>
+struct MConditionalTranslationModel {  // Rule model that keeps one MFCR<TRule> per conditioning context, keyed by rule.f_.
+ explicit MConditionalTranslationModel(ConditionalBaseMeasure& rcp0) :  // rcp0 is bound to the reference member rp0, not copied.
+ rp0(rcp0), lambdas(1, 1.0), p0s(1) {}  // One lambda set to 1.0 and one scratch slot for the base probability.
+
+ void Summary() const {  // Writes the number of contexts, then each context and the rules stored under it, to std::cerr.
+ std::cerr << "Number of conditioning contexts: " << r.size() << std::endl;
+ for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
+ std::cerr << TD::GetString(it->first) << " \t(d=" << it->second.d() << ",\\alpha = " << it->second.alpha() << ") --------------------------" << std::endl;
+ for (MFCR<TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
+ std::cerr << " " << -1 << '\t' << i2->first << std::endl;  // -1 is a constant placeholder; no per-rule count is printed.
+ }
+ }
+
+ void ResampleHyperparameters(MT19937* rng) {  // Calls resample_hyperparameters(rng) on every per-context MFCR.
+ for (RuleModelHash::iterator it = r.begin(); it != r.end(); ++it)
+ it->second.resample_hyperparameters(rng);
+ }
+
+ int DecrementRule(const TRule& rule, MT19937* rng) {  // Removes one observation of rule; returns delta.count from the MFCR.
+ RuleModelHash::iterator it = r.find(rule.f_);
+ assert(it != r.end());  // The context must already exist; this is only checked when assertions are enabled.
+ const TableCount delta = it->second.decrement(rule, rng);
+ if (delta.count) {
+ if (it->second.num_customers() == 0) r.erase(it);  // Drop the context once its MFCR reports zero customers.
+ }
+ return delta.count;
+ }
+
+ int IncrementRule(const TRule& rule, MT19937* rng) {  // Adds one observation of rule; returns delta.count from the MFCR.
+ RuleModelHash::iterator it = r.find(rule.f_);
+ if (it == r.end()) {
+ it = r.insert(make_pair(rule.f_, MFCR<TRule>(1, 1.0, 1.0, 1.0, 1.0, 1e-9, 4.0))).first;  // First rule for this context: create its MFCR with hard-coded constructor arguments.
+ }
+ p0s[0] = rp0(rule).as_float();  // The base-measure probability is handed to the MFCR through p0s.
+ TableCount delta = it->second.increment(rule, p0s, lambdas, rng);
+ return delta.count;
+ }
+
+ prob_t RuleProbability(const TRule& rule) const {  // Probability of rule given its context rule.f_.
+ prob_t p;
+ RuleModelHash::const_iterator it = r.find(rule.f_);
+ if (it == r.end()) {
+ p.logeq(log(rp0(rule)));  // Unseen context: use the base measure alone.
+ } else {
+ p0s[0] = rp0(rule).as_float();  // Overwrites the mutable scratch vector even though this method is const.
+ p = prob_t(it->second.prob(rule, p0s, lambdas));
+ }
+ return p;
+ }
+
+ prob_t Likelihood() const {  // Currently always returns prob_t::One(): the computation below is compiled out by #if 0.
+ prob_t p = prob_t::One();
+#if 0
+ for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
+ prob_t q; q.logeq(it->second.log_crp_prob());
+ p *= q;
+ for (CCRP_NoTable<TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)  // NOTE(review): uses a CCRP_NoTable iterator although r stores MFCR<TRule>; revisit before re-enabling.
+ p *= rp0(i2->first);
+ }
+#endif
+ return p;
+ }
+
+ const ConditionalBaseMeasure& rp0;  // Base measure; held by reference, not owned by this struct.
+ typedef std::tr1::unordered_map<std::vector<WordID>,
+ MFCR<TRule>,
+ boost::hash<std::vector<WordID> > > RuleModelHash;
+ RuleModelHash r;  // Maps each context (rule.f_) to the MFCR over rules observed with it.
+ std::vector<double> lambdas;  // Passed to MFCR increment()/prob() together with p0s.
+ mutable std::vector<double> p0s;  // Scratch buffer; mutable so RuleProbability() const can fill it.
+};
+
+template <typename ConditionalBaseMeasure>
struct ConditionalTranslationModel {
explicit ConditionalTranslationModel(ConditionalBaseMeasure& rcp0) :
rp0(rcp0) {}