summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-12-29 21:10:36 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-12-29 21:10:36 -0500
commit665badbdcc755183aa83414f6e86987f4d017393 (patch)
tree244cddf2e3863211fb0cc81922b7cdfa18d3a3e5
parentc430c3b3f7b22319001b2ddb1343d8a7506e9f81 (diff)
foo
-rw-r--r--gi/pf/conditional_pseg.h155
1 files changed, 155 insertions, 0 deletions
diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h
new file mode 100644
index 00000000..edcdc813
--- /dev/null
+++ b/gi/pf/conditional_pseg.h
@@ -0,0 +1,155 @@
+#ifndef _CONDITIONAL_PSEG_H_
+#define _CONDITIONAL_PSEG_H_
+
+#include <vector>
+#include <tr1/unordered_map>
+#include <boost/functional/hash.hpp>
+#include <iostream>
+
+#include "prob.h"
+#include "ccrp_nt.h"
+#include "trule.h"
+#include "base_measures.h"
+#include "tdict.h"
+
+template <typename ConditionalBaseMeasure>
+struct ConditionalTranslationModel {
+ explicit ConditionalTranslationModel(ConditionalBaseMeasure& rcp0) :
+ rp0(rcp0) {}
+
+ void Summary() const {
+ std::cerr << "Number of conditioning contexts: " << r.size() << std::endl;
+ for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
+ std::cerr << TD::GetString(it->first) << " \t(\\alpha = " << it->second.concentration() << ") --------------------------" << std::endl;
+ for (CCRP_NoTable<TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
+ std::cerr << " " << i2->second << '\t' << i2->first << std::endl;
+ }
+ }
+
+ void ResampleHyperparameters(MT19937* rng) {
+ for (RuleModelHash::iterator it = r.begin(); it != r.end(); ++it)
+ it->second.resample_hyperparameters(rng);
+ }
+
+ int DecrementRule(const TRule& rule) {
+ RuleModelHash::iterator it = r.find(rule.f_);
+ assert(it != r.end());
+ int count = it->second.decrement(rule);
+ if (count) {
+ if (it->second.num_customers() == 0) r.erase(it);
+ }
+ return count;
+ }
+
+ int IncrementRule(const TRule& rule) {
+ RuleModelHash::iterator it = r.find(rule.f_);
+ if (it == r.end()) {
+ it = r.insert(make_pair(rule.f_, CCRP_NoTable<TRule>(1.0, 1.0, 8.0))).first;
+ }
+ int count = it->second.increment(rule);
+ return count;
+ }
+
+ void IncrementRules(const std::vector<TRulePtr>& rules) {
+ for (int i = 0; i < rules.size(); ++i)
+ IncrementRule(*rules[i]);
+ }
+
+ void DecrementRules(const std::vector<TRulePtr>& rules) {
+ for (int i = 0; i < rules.size(); ++i)
+ DecrementRule(*rules[i]);
+ }
+
+ prob_t RuleProbability(const TRule& rule) const {
+ prob_t p;
+ RuleModelHash::const_iterator it = r.find(rule.f_);
+ if (it == r.end()) {
+ p.logeq(log(rp0(rule)));
+ } else {
+ p.logeq(it->second.logprob(rule, log(rp0(rule))));
+ }
+ return p;
+ }
+
+ prob_t Likelihood() const {
+ prob_t p = prob_t::One();
+ for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
+ prob_t q; q.logeq(it->second.log_crp_prob());
+ p *= q;
+ for (CCRP_NoTable<TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
+ p *= rp0(i2->first);
+ }
+ return p;
+ }
+
+ const ConditionalBaseMeasure& rp0;
+ typedef std::tr1::unordered_map<std::vector<WordID>,
+ CCRP_NoTable<TRule>,
+ boost::hash<std::vector<WordID> > > RuleModelHash;
+ RuleModelHash r;
+};
+
+template <typename ConditionalBaseMeasure>
+struct ConditionalParallelSegementationModel {
+ explicit ConditionalParallelSegementationModel(ConditionalBaseMeasure& rcp0) :
+ tmodel(rcp0), base(prob_t::One()), aligns(1,1) {}
+
+ ConditionalTranslationModel<ConditionalBaseMeasure> tmodel;
+
+ void DecrementRule(const TRule& rule) {
+ tmodel.DecrementRule(rule);
+ }
+
+ void IncrementRule(const TRule& rule) {
+ tmodel.IncrementRule(rule);
+ }
+
+ void IncrementRulesAndAlignments(const std::vector<TRulePtr>& rules) {
+ tmodel.IncrementRules(rules);
+ for (int i = 0; i < rules.size(); ++i) {
+ IncrementAlign(rules[i]->f_.size());
+ }
+ }
+
+ void DecrementRulesAndAlignments(const std::vector<TRulePtr>& rules) {
+ tmodel.DecrementRules(rules);
+ for (int i = 0; i < rules.size(); ++i) {
+ DecrementAlign(rules[i]->f_.size());
+ }
+ }
+
+ prob_t RuleProbability(const TRule& rule) const {
+ return tmodel.RuleProbability(rule);
+ }
+
+ void IncrementAlign(unsigned span) {
+ if (aligns.increment(span)) {
+ // TODO
+ }
+ }
+
+ void DecrementAlign(unsigned span) {
+ if (aligns.decrement(span)) {
+ // TODO
+ }
+ }
+
+ prob_t AlignProbability(unsigned span) const {
+ prob_t p;
+ p.logeq(aligns.logprob(span, log_poisson(span, 1.0)));
+ return p;
+ }
+
+ prob_t Likelihood() const {
+ prob_t p; p.logeq(aligns.log_crp_prob());
+ p *= base;
+ p *= tmodel.Likelihood();
+ return p;
+ }
+
+ prob_t base;
+ CCRP_NoTable<unsigned> aligns;
+};
+
+#endif
+