#include "pyp_tm.h"
#include <tr1/unordered_map>
#include <iostream>
#include <queue>
#include "tdict.h"
#include "ccrp.h"
#include "pyp_word_model.h"
#include "tied_resampler.h"
using namespace std;
using namespace std::tr1;
// Maps a source word to a frequency bin; bin indices come from a file of
// precomputed word frequencies loaded into a FreqDict.
struct FreqBinner {
  FreqBinner(const std::string& fname) { fd_.Load(fname); }
  unsigned NumberOfBins() const { return fd_.Max() + 1; }
  unsigned Bin(const WordID& w) const { return fd_.LookUp(w); }
  FreqDict<unsigned> fd_;
};
// p(target letter sequence | source word): one Pitman-Yor process (CCRP) per
// conditioning source word, all backed by a shared character-level base
// distribution. Each CRP is registered with a BinTiedResampler keyed by the
// source word's frequency bin (bin 0 = NULL, bin 1 = unbinned default,
// bins 2+ = Binner bins), so hyperparameters can be resampled per bin.
template <typename Base, class Binner = FreqBinner>
struct ConditionalPYPWordModel {
  ConditionalPYPWordModel(Base* b, const Binner* bnr = NULL) :
      base(*b),
      binner(bnr),
      btr(binner ? binner->NumberOfBins() + 1u : 2u) {}

  void Summary() const {
    cerr << "Number of conditioning contexts: " << r.size() << endl;
    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
      cerr << TD::Convert(it->first) << " \tPYP(d=" << it->second.discount() << ",s=" << it->second.strength() << ") --------------------------" << endl;
      for (CCRP<vector<WordID> >::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
        cerr << " " << i2->second << '\t' << TD::GetString(i2->first) << endl;
    }
  }

  void ResampleHyperparameters(MT19937* rng) {
    btr.ResampleHyperparameters(rng);
  }

  prob_t Prob(const WordID src, const vector<WordID>& trglets) const {
    RuleModelHash::const_iterator it = r.find(src);
    if (it == r.end()) {
      return base(trglets);  // unseen source word: fall back to the base distribution
    } else {
      return it->second.prob(trglets, base(trglets));
    }
  }

  void Increment(const WordID src, const vector<WordID>& trglets, MT19937* rng) {
    RuleModelHash::iterator it = r.find(src);
    if (it == r.end()) {
      // First observation of this source word: create its CRP and register it
      // with the tied resampler under the appropriate frequency bin.
      it = r.insert(make_pair(src, CCRP<vector<WordID> >(0.5, 1.0))).first;
      static const WordID kNULL = TD::Convert("NULL");
      unsigned bin = (src == kNULL ? 0 : 1);
      if (binner && bin) { bin = binner->Bin(src) + 1; }
      btr.Add(bin, &it->second);
    }
    if (it->second.increment(trglets, base(trglets), rng))
      base.Increment(trglets, rng);  // a new table was created: add a count to the base
  }

  void Decrement(const WordID src, const vector<WordID>& trglets, MT19937* rng) {
    RuleModelHash::iterator it = r.find(src);
    assert(it != r.end());
    if (it->second.decrement(trglets, rng)) {
      base.Decrement(trglets, rng);  // a table was removed: remove its count from the base
    }
  }

  prob_t Likelihood() const {
    prob_t p = prob_t::One();
    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
      prob_t q; q.logeq(it->second.log_crp_prob());
      p *= q;
    }
    return p;
  }

  unsigned UniqueConditioningContexts() const {
    return r.size();
  }

  // TODO tie PYP hyperparameters based on source word frequency bins
  Base& base;
  const Binner* binner;
  BinTiedResampler<CCRP<vector<WordID> > > btr;
  typedef unordered_map<WordID, CCRP<vector<WordID> > > RuleModelHash;
  RuleModelHash r;
};
PYPLexicalTranslation::PYPLexicalTranslation(const vector<vector<WordID> >& lets,
                                             const unsigned vocab_size,
                                             const unsigned num_letters) :
    letters(lets),
    base(vocab_size, num_letters, 5),
    tmodel(new ConditionalPYPWordModel<PoissonUniformWordModel>(&base, new FreqBinner("10k.freq"))),
    kX(-TD::Convert("X")) {}

void PYPLexicalTranslation::Summary() const {
  tmodel->Summary();
}

prob_t PYPLexicalTranslation::Likelihood() const {
  return tmodel->Likelihood() * base.Likelihood();
}

void PYPLexicalTranslation::ResampleHyperparameters(MT19937* rng) {
  tmodel->ResampleHyperparameters(rng);
}

unsigned PYPLexicalTranslation::UniqueConditioningContexts() const {
  return tmodel->UniqueConditioningContexts();
}

prob_t PYPLexicalTranslation::Prob(WordID src, WordID trg) const {
  return tmodel->Prob(src, letters[trg]);
}

void PYPLexicalTranslation::Increment(WordID src, WordID trg, MT19937* rng) {
  tmodel->Increment(src, letters[trg], rng);
}

void PYPLexicalTranslation::Decrement(WordID src, WordID trg, MT19937* rng) {
  tmodel->Decrement(src, letters[trg], rng);
}