diff options
| -rw-r--r-- | gi/pf/conditional_pseg.h | 155 | 
1 files changed, 155 insertions, 0 deletions
| diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h new file mode 100644 index 00000000..edcdc813 --- /dev/null +++ b/gi/pf/conditional_pseg.h @@ -0,0 +1,155 @@ +#ifndef _CONDITIONAL_PSEG_H_ +#define _CONDITIONAL_PSEG_H_ + +#include <vector> +#include <tr1/unordered_map> +#include <boost/functional/hash.hpp> +#include <iostream> + +#include "prob.h" +#include "ccrp_nt.h" +#include "trule.h" +#include "base_measures.h" +#include "tdict.h" + +template <typename ConditionalBaseMeasure> +struct ConditionalTranslationModel { +  explicit ConditionalTranslationModel(ConditionalBaseMeasure& rcp0) : +    rp0(rcp0) {} + +  void Summary() const { +    std::cerr << "Number of conditioning contexts: " << r.size() << std::endl; +    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) { +      std::cerr << TD::GetString(it->first) << "   \t(\\alpha = " << it->second.concentration() << ") --------------------------" << std::endl; +      for (CCRP_NoTable<TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) +        std::cerr << "   " << i2->second << '\t' << i2->first << std::endl; +    } +  } + +  void ResampleHyperparameters(MT19937* rng) { +    for (RuleModelHash::iterator it = r.begin(); it != r.end(); ++it) +      it->second.resample_hyperparameters(rng); +  }  + +  int DecrementRule(const TRule& rule) { +    RuleModelHash::iterator it = r.find(rule.f_); +    assert(it != r.end());     +    int count = it->second.decrement(rule); +    if (count) { +      if (it->second.num_customers() == 0) r.erase(it); +    } +    return count; +  } + +  int IncrementRule(const TRule& rule) { +    RuleModelHash::iterator it = r.find(rule.f_); +    if (it == r.end()) { +      it = r.insert(make_pair(rule.f_, CCRP_NoTable<TRule>(1.0, 1.0, 8.0))).first; +    }  +    int count = it->second.increment(rule); +    return count; +  } + +  void IncrementRules(const std::vector<TRulePtr>& rules) { +    for (int i = 0; i < rules.size(); ++i) +      IncrementRule(*rules[i]); +  } + +  void DecrementRules(const std::vector<TRulePtr>& rules) { +    for (int i = 0; i < rules.size(); ++i) +      DecrementRule(*rules[i]); +  } + +  prob_t RuleProbability(const TRule& rule) const { +    prob_t p; +    RuleModelHash::const_iterator it = r.find(rule.f_); +    if (it == r.end()) { +      p.logeq(log(rp0(rule))); +    } else { +      p.logeq(it->second.logprob(rule, log(rp0(rule)))); +    } +    return p; +  } + +  prob_t Likelihood() const { +    prob_t p = prob_t::One(); +    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) { +      prob_t q; q.logeq(it->second.log_crp_prob()); +      p *= q; +      for (CCRP_NoTable<TRule>::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) +        p *= rp0(i2->first); +    } +    return p; +  } + +  const ConditionalBaseMeasure& rp0; +  typedef std::tr1::unordered_map<std::vector<WordID>, +                                  CCRP_NoTable<TRule>, +                                  boost::hash<std::vector<WordID> > > RuleModelHash; +  RuleModelHash r; +}; + +template <typename ConditionalBaseMeasure> +struct ConditionalParallelSegementationModel { +  explicit ConditionalParallelSegementationModel(ConditionalBaseMeasure& rcp0) : +    tmodel(rcp0), base(prob_t::One()), aligns(1,1) {} + +  ConditionalTranslationModel<ConditionalBaseMeasure> tmodel; + +  void DecrementRule(const TRule& rule) { +    tmodel.DecrementRule(rule); +  } + +  void IncrementRule(const TRule& rule) { +    tmodel.IncrementRule(rule); +  } + +  void IncrementRulesAndAlignments(const std::vector<TRulePtr>& rules) { +    tmodel.IncrementRules(rules); +    for (int i = 0; i < rules.size(); ++i) { +      IncrementAlign(rules[i]->f_.size()); +    } +  } + +  void DecrementRulesAndAlignments(const std::vector<TRulePtr>& rules) { +    tmodel.DecrementRules(rules); +    for (int i = 0; i < rules.size(); ++i) { +      DecrementAlign(rules[i]->f_.size()); +    } +  } + +  prob_t RuleProbability(const TRule& rule) const { +    return tmodel.RuleProbability(rule); +  } + +  void IncrementAlign(unsigned span) { +    if (aligns.increment(span)) { +      // TODO +    } +  } + +  void DecrementAlign(unsigned span) { +    if (aligns.decrement(span)) { +      // TODO +    } +  } + +  prob_t AlignProbability(unsigned span) const { +    prob_t p; +    p.logeq(aligns.logprob(span, log_poisson(span, 1.0))); +    return p; +  } + +  prob_t Likelihood() const { +    prob_t p; p.logeq(aligns.log_crp_prob()); +    p *= base; +    p *= tmodel.Likelihood(); +    return p; +  } + +  prob_t base; +  CCRP_NoTable<unsigned> aligns; +}; + +#endif + | 
