summaryrefslogtreecommitdiff
path: root/gi/pf/hpyp_tm.cc
blob: f362d3f85158352c2f428551ccbf4b3cf82ac721 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include "hpyp_tm.h"

#include <tr1/unordered_map>
#include <iostream>
#include <queue>

#include "tdict.h"
#include "ccrp.h"
#include "pyp_word_model.h"
#include "tied_resampler.h"

using namespace std;
using namespace std::tr1;

// Assigns words to bins by looking them up in a FreqDict<unsigned> that is
// loaded from a file at construction time.
struct FreqBinner {
  // Loads the dictionary from `fname`.
  // Marked explicit so that a bare file name can never be converted silently
  // into a FreqBinner (and trigger a file load) at a call site.
  explicit FreqBinner(const std::string& fname) { fd_.Load(fname); }

  // One more than fd_.Max(); presumably Max() is the largest bin id stored
  // in the dictionary, so that bin ids range over [0, NumberOfBins()) --
  // TODO confirm against FreqDict.
  unsigned NumberOfBins() const { return fd_.Max() + 1; }

  // Bin assigned to `w` by the dictionary (whatever fd_.LookUp returns; the
  // result for a word missing from the file is decided by FreqDict).
  unsigned Bin(const WordID& w) const { return fd_.LookUp(w); }

  FreqDict<unsigned> fd_;
};

template <typename Base, class Binner = FreqBinner>
struct ConditionalPYPWordModel {
  ConditionalPYPWordModel(Base* b, const Binner* bnr = NULL) :
      base(*b),
      binner(bnr),
      btr(binner ? binner->NumberOfBins() + 1u : 2u) {}

  void Summary() const {
    cerr << "Number of conditioning contexts: " << r.size() << endl;
    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
      cerr << TD::Convert(it->first) << "   \tPYP(d=" << it->second.discount() << ",s=" << it->second.strength() << ") --------------------------" << endl;
      for (CCRP<vector<WordID> >::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
        cerr << "   " << i2->second << endl;
    }
  }

  void ResampleHyperparameters(MT19937* rng) {
    btr.ResampleHyperparameters(rng);
  } 

  prob_t Prob(const WordID src, const vector<WordID>& trglets) const {
    RuleModelHash::const_iterator it = r.find(src);
    if (it == r.end()) {
      return base(trglets);
    } else {
      return it->second.prob(trglets, base(trglets));
    }
  }

  void Increment(const WordID src, const vector<WordID>& trglets, MT19937* rng) {
    RuleModelHash::iterator it = r.find(src);
    if (it == r.end()) {
      it = r.insert(make_pair(src, CCRP<vector<WordID> >(0.5,1.0))).first;
      static const WordID kNULL = TD::Convert("NULL");
      unsigned bin = (src == kNULL ? 0 : 1);
      if (binner && bin) { bin = binner->Bin(src) + 1; }
      btr.Add(bin, &it->second);
    }
    if (it->second.increment(trglets, base(trglets), rng))
      base.Increment(trglets, rng);
  }

  void Decrement(const WordID src, const vector<WordID>& trglets, MT19937* rng) {
    RuleModelHash::iterator it = r.find(src);
    assert(it != r.end());
    if (it->second.decrement(trglets, rng)) {
      base.Decrement(trglets, rng);
    }
  }

  prob_t Likelihood() const {
    prob_t p = prob_t::One();
    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
      prob_t q; q.logeq(it->second.log_crp_prob());
      p *= q;
    }
    return p;
  }

  unsigned UniqueConditioningContexts() const {
    return r.size();
  }

  // TODO tie PYP hyperparameters based on source word frequency bins
  Base& base;
  const Binner* binner;
  BinTiedResampler<CCRP<vector<WordID> > > btr;
  typedef unordered_map<WordID, CCRP<vector<WordID> > > RuleModelHash;
  RuleModelHash r;
};

// Builds the model stack used by the translation model:
//   base   - constructed from (vocab_size, num_letters, 5)
//   up0    - PYP word model created over &base
//   tmodel - conditional model created over up0, binned by a FreqBinner
// `lets` is stored in `letters`; Prob/Increment/Decrement index it with the
// target word id.
// NOTE(review): the initializer of tmodel reads up0 and the initializer of
// up0 takes &base, so the header must declare base, up0, tmodel in that
// order -- confirm.
// NOTE(review): the FreqBinner is allocated with new and nothing in this file
// deletes it; its file name "10k.freq" is hard-coded -- confirm ownership and
// where the file is expected to live.
// NOTE(review): the meaning of the literal 5 passed to base is not visible
// from this file.
HPYPLexicalTranslation::HPYPLexicalTranslation(const vector<vector<WordID> >& lets,
                                               const unsigned vocab_size,
                                               const unsigned num_letters) :
    letters(lets),
    base(vocab_size, num_letters, 5),
    up0(new PYPWordModel<PoissonUniformWordModel>(&base)),
    tmodel(new ConditionalPYPWordModel<PYPWordModel<PoissonUniformWordModel> >(up0, new FreqBinner("10k.freq"))),
    // kX holds the negated id that TD::Convert returns for "X".
    kX(-TD::Convert("X")) {}

// Prints a diagnostic summary: the conditional model first (it writes its
// per-context listing to cerr), then the word model up0.
void HPYPLexicalTranslation::Summary() const {
  tmodel->Summary();
  up0->Summary();
}

// Joint likelihood of the model: the likelihood reported by up0 multiplied
// by the likelihood reported by the conditional model.
prob_t HPYPLexicalTranslation::Likelihood() const {
  prob_t joint(up0->Likelihood());
  joint *= tmodel->Likelihood();
  return joint;
}

// Resamples hyperparameters of both levels with the same generator: the
// conditional model first, then up0. The order fixes how `rng` is consumed.
void HPYPLexicalTranslation::ResampleHyperparameters(MT19937* rng) {
  tmodel->ResampleHyperparameters(rng);
  up0->ResampleHyperparameters(rng);
}

// Number of conditioning contexts currently held by the conditional model
// (forwarded unchanged from tmodel).
unsigned HPYPLexicalTranslation::UniqueConditioningContexts() const {
  return tmodel->UniqueConditioningContexts();
}

// Probability of target word `trg` given source word `src`. The target word
// is scored through its letter sequence letters[trg]; `trg` is used as an
// index without a range check here.
prob_t HPYPLexicalTranslation::Prob(WordID src, WordID trg) const {
  const vector<WordID>& trg_letters = letters[trg];
  return tmodel->Prob(src, trg_letters);
}

// Records one observation of target word `trg` under source word `src`,
// passing the letter sequence letters[trg] to the conditional model. `trg`
// is used as an index without a range check here.
void HPYPLexicalTranslation::Increment(WordID src, WordID trg, MT19937* rng) {
  const vector<WordID>& trg_letters = letters[trg];
  tmodel->Increment(src, trg_letters, rng);
}

// Removes one observation of target word `trg` under source word `src`,
// passing the letter sequence letters[trg] to the conditional model. `trg`
// is used as an index without a range check here.
void HPYPLexicalTranslation::Decrement(WordID src, WordID trg, MT19937* rng) {
  const vector<WordID>& trg_letters = letters[trg];
  tmodel->Decrement(src, trg_letters, rng);
}