summaryrefslogtreecommitdiff
path: root/gi/pf/pyp_tm.cc
blob: 34ef0ba2798cc34d628d98d1ad9e1137da048e83 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include "pyp_tm.h"

#include <tr1/unordered_map>
#include <iostream>
#include <queue>

#include "base_distributions.h"
#include "monotonic_pseg.h"
#include "conditional_pseg.h"
#include "tdict.h"
#include "ccrp.h"
#include "pyp_word_model.h"
#include "tied_resampler.h"

using namespace std;
using namespace std::tr1;

template <typename Base>
struct ConditionalPYPWordModel {
  ConditionalPYPWordModel(Base* b) : base(*b), btr(2) {}

  void Summary() const {
    cerr << "Number of conditioning contexts: " << r.size() << endl;
    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
      cerr << TD::Convert(it->first) << "   \tPYP(d=" << it->second.discount() << ",s=" << it->second.strength() << ") --------------------------" << endl;
      for (CCRP<vector<WordID> >::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
        cerr << "   " << i2->second.total_dish_count_ << '\t' << TD::GetString(i2->first) << endl;
    }
  }

  void ResampleHyperparameters(MT19937* rng) {
    btr.ResampleHyperparameters(rng);
  } 

  prob_t Prob(const WordID src, const vector<WordID>& trglets) const {
    RuleModelHash::const_iterator it = r.find(src);
    if (it == r.end()) {
      return base(trglets);
    } else {
      return it->second.prob(trglets, base(trglets));
    }
  }

  void Increment(const WordID src, const vector<WordID>& trglets, MT19937* rng) {
    RuleModelHash::iterator it = r.find(src);
    if (it == r.end()) {
      it = r.insert(make_pair(src, CCRP<vector<WordID> >(0.5,1.0))).first;
      static const WordID kNULL = TD::Convert("NULL");
      btr.Add(src == kNULL ? 0 : 1, &it->second);
    }
    if (it->second.increment(trglets, base(trglets), rng))
      base.Increment(trglets, rng);
  }

  void Decrement(const WordID src, const vector<WordID>& trglets, MT19937* rng) {
    RuleModelHash::iterator it = r.find(src);
    assert(it != r.end());
    if (it->second.decrement(trglets, rng)) {
      base.Decrement(trglets, rng);
    }
  }

  prob_t Likelihood() const {
    prob_t p = prob_t::One();
    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
      prob_t q; q.logeq(it->second.log_crp_prob());
      p *= q;
    }
    return p;
  }

  unsigned UniqueConditioningContexts() const {
    return r.size();
  }

  // TODO tie PYP hyperparameters based on source word frequency bins
  Base& base;
  BinTiedResampler<CCRP<vector<WordID> > > btr;
  typedef unordered_map<WordID, CCRP<vector<WordID> > > RuleModelHash;
  RuleModelHash r;
};

// Builds the translation model.
//   lets        - letter sequences, indexed by target word id (see the
//                 letters[trg] lookups in Prob/Increment/Decrement)
//   num_letters - forwarded to the PYPWordModel constructor
// NOTE(review): tmodel is constructed from up0, so up0 must be declared before
// tmodel in the class for this to be well-defined -- confirm in pyp_tm.h.
PYPLexicalTranslation::PYPLexicalTranslation(const vector<vector<WordID> >& lets,
                                             const unsigned num_letters)
    : letters(lets),
      up0(new PYPWordModel(num_letters)),
      tmodel(new ConditionalPYPWordModel<PYPWordModel>(up0)),
      kX(-TD::Convert("X")) {  // negated id of the symbol "X"
}

// Prints a summary of the conditional model, followed by a summary of the
// base word model (up0). The order of the two calls fixes the output order.
void PYPLexicalTranslation::Summary() const {
  tmodel->Summary();
  up0->Summary();
}

// Joint likelihood: the base word model's likelihood multiplied by the
// conditional model's likelihood.
prob_t PYPLexicalTranslation::Likelihood() const {
  prob_t joint = up0->Likelihood();
  const prob_t conditional = tmodel->Likelihood();
  joint *= conditional;
  return joint;
}

// Resamples hyperparameters of the conditional model first, then of the base
// word model (up0). Both calls receive the same rng, so the call order is
// kept as written.
void PYPLexicalTranslation::ResampleHyperparameters(MT19937* rng) {
  tmodel->ResampleHyperparameters(rng);
  up0->ResampleHyperparameters(rng);
}

// Number of distinct source words the conditional model currently tracks.
unsigned PYPLexicalTranslation::UniqueConditioningContexts() const {
  const unsigned num_contexts = tmodel->UniqueConditioningContexts();
  return num_contexts;
}

// Probability of target word `trg` given source word `src`, computed over the
// letter sequence stored for `trg`. `trg` indexes `letters` without a bounds
// check.
prob_t PYPLexicalTranslation::Prob(WordID src, WordID trg) const {
  const vector<WordID>& trg_letters = letters[trg];
  return tmodel->Prob(src, trg_letters);
}

// Records one observation of target word `trg` under source word `src`,
// using the letter sequence stored for `trg`. `trg` indexes `letters` without
// a bounds check.
void PYPLexicalTranslation::Increment(WordID src, WordID trg, MT19937* rng) {
  const vector<WordID>& trg_letters = letters[trg];
  tmodel->Increment(src, trg_letters, rng);
}

// Removes one observation of target word `trg` under source word `src`,
// using the letter sequence stored for `trg`. `trg` indexes `letters` without
// a bounds check.
void PYPLexicalTranslation::Decrement(WordID src, WordID trg, MT19937* rng) {
  const vector<WordID>& trg_letters = letters[trg];
  tmodel->Decrement(src, trg_letters, rng);
}