summaryrefslogtreecommitdiff
path: root/gi/pf/pyp_tm.cc
blob: 6bc8a5bf1db01fdcdf864e9159f8d3d1775d7ba1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include "pyp_tm.h"

#include <tr1/unordered_map>
#include <iostream>
#include <queue>

#include "tdict.h"
#include "ccrp.h"
#include "pyp_word_model.h"
#include "tied_resampler.h"

using namespace std;
using namespace std::tr1;

// Assigns each source word to a frequency bin using a FreqDict table loaded
// from a file. Used by ConditionalPYPWordModel to tie CCRP hyperparameters
// across words that fall in the same bin.
struct FreqBinner {
  // Loads the word -> bin table from the file `fname` via FreqDict::Load.
  FreqBinner(const std::string& fname) {
    fd_.Load(fname);
  }

  // Count of bins, assuming bin ids run from 0 up to fd_.Max() inclusive.
  unsigned NumberOfBins() const {
    return 1 + fd_.Max();
  }

  // Bin id the loaded table assigns to word `w`.
  unsigned Bin(const WordID& w) const {
    return fd_.LookUp(w);
  }

  FreqDict<unsigned> fd_;  // word -> bin table
};

// Conditional model of a target letter sequence given a source word.
// Each distinct conditioning source word owns one CCRP (reported as
// "PYP(d=...,s=...)" by Summary) over target letter sequences, and every CCRP
// backs off to the shared base distribution `Base`.
//
// CCRP hyperparameters are tied across bins of source words by a
// BinTiedResampler. Bin 0 holds the "NULL" source word. Every other word goes
// to bin 1 when no binner is given, or to binner->Bin(src) + 1 otherwise.
//
// Ownership: `base` and `binner` are borrowed, not owned. Both must outlive
// this object, and `binner` may be NULL.
template <typename Base, class Binner = FreqBinner>
struct ConditionalPYPWordModel {
  // b:   base distribution shared by all conditioning contexts (dereferenced,
  //      so it must be non-NULL).
  // bnr: optional source-word binner used to tie hyperparameters. NULL puts
  //      every non-NULL source word into a single bin.
  ConditionalPYPWordModel(Base* b, const Binner* bnr = NULL) :
      base(*b),
      binner(bnr),
      // One bin for NULL plus either binner->NumberOfBins() bins or a single
      // catch-all bin. This reads `binner`, so `binner` must stay declared
      // before `btr` in the member list below.
      btr(binner ? binner->NumberOfBins() + 1u : 2u) {}

  // Writes every conditioning context to stderr, with its hyperparameters and
  // the total dish count of each target letter sequence it has seen.
  void Summary() const {
    cerr << "Number of conditioning contexts: " << r.size() << '\n';
    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
      cerr << TD::Convert(it->first) << "   \tPYP(d=" << it->second.discount() << ",s=" << it->second.strength() << ") --------------------------" << '\n';
      for (CCRP<vector<WordID> >::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2)
        cerr << "   " << i2->second.total_dish_count_ << '\t' << TD::GetString(i2->first) << '\n';
    }
  }

  // Resamples the tied CCRP hyperparameters of every bin. The base
  // distribution is not touched here.
  void ResampleHyperparameters(MT19937* rng) {
    btr.ResampleHyperparameters(rng);
  }

  // Probability of `trglets` given source word `src`. A source word that has
  // never been incremented falls back to the base distribution alone.
  prob_t Prob(const WordID src, const vector<WordID>& trglets) const {
    RuleModelHash::const_iterator it = r.find(src);
    if (it == r.end()) {
      return base(trglets);
    } else {
      return it->second.prob(trglets, base(trglets));
    }
  }

  // Adds one observation of `trglets` under source word `src`.
  // The first time `src` is seen, a CCRP is created with constructor arguments
  // (0.5, 1.0) and registered with the tied resampler in its bin. The
  // arguments are presumably (discount, strength); confirm in ccrp.h.
  // When the CCRP's increment() returns true, the base distribution is
  // incremented as well. This is presumably the case where a new table was
  // opened; confirm in ccrp.h.
  void Increment(const WordID src, const vector<WordID>& trglets, MT19937* rng) {
    RuleModelHash::iterator it = r.find(src);
    if (it == r.end()) {
      it = r.insert(make_pair(src, CCRP<vector<WordID> >(0.5,1.0))).first;
      static const WordID kNULL = TD::Convert("NULL");
      unsigned bin = (src == kNULL ? 0 : 1);
      if (binner && bin) { bin = binner->Bin(src) + 1; }
      btr.Add(bin, &it->second);
    }
    if (it->second.increment(trglets, base(trglets), rng))
      base.Increment(trglets, rng);
  }

  // Removes one observation of `trglets` under source word `src`, which must
  // have been incremented before. Mirrors Increment: the base distribution is
  // decremented when the CCRP's decrement() returns true.
  void Decrement(const WordID src, const vector<WordID>& trglets, MT19937* rng) {
    RuleModelHash::iterator it = r.find(src);
    assert(it != r.end());
    // In NDEBUG builds the assert is compiled out, and dereferencing end()
    // would be undefined behavior. Ignore the unknown context instead.
    if (it == r.end()) return;
    if (it->second.decrement(trglets, rng)) {
      base.Decrement(trglets, rng);
    }
  }

  // Product over all conditioning contexts of each CCRP's likelihood
  // (exp of log_crp_prob()). It does not include the base distribution.
  prob_t Likelihood() const {
    prob_t p = prob_t::One();
    for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) {
      prob_t q; q.logeq(it->second.log_crp_prob());
      p *= q;
    }
    return p;
  }

  // Number of distinct source words that currently have a CCRP.
  unsigned UniqueConditioningContexts() const {
    return r.size();
  }

  Base& base;            // borrowed shared back-off distribution
  const Binner* binner;  // borrowed, may be NULL; must precede btr (see ctor)
  // Ties CCRP hyperparameters across source-word bins (see Increment).
  BinTiedResampler<CCRP<vector<WordID> > > btr;
  typedef unordered_map<WordID, CCRP<vector<WordID> > > RuleModelHash;
  // Conditioning source word -> its CCRP. btr.Add is handed pointers to these
  // values, so entries are presumably never erased while btr is in use.
  RuleModelHash r;
};

PYPLexicalTranslation::PYPLexicalTranslation(const vector<vector<WordID> >& lets,
                                             const unsigned vocab_size,
                                             const unsigned num_letters) :
    letters(lets),
    base(vocab_size, num_letters, 5),
    tmodel(new ConditionalPYPWordModel<PoissonUniformWordModel>(&base, new FreqBinner("10k.freq"))),
    kX(-TD::Convert("X")) {}

void PYPLexicalTranslation::Summary() const {
  tmodel->Summary();
}

// Likelihood of the current sampler state: the conditional model's CCRP
// likelihood times the base model's likelihood.
prob_t PYPLexicalTranslation::Likelihood() const {
  const prob_t conditional = tmodel->Likelihood();
  return conditional * base.Likelihood();
}

// Resamples the conditional model's tied CCRP hyperparameters using `rng`.
// The base model is not resampled by this method.
void PYPLexicalTranslation::ResampleHyperparameters(MT19937* rng) {
  (*tmodel).ResampleHyperparameters(rng);
}

// Number of distinct source words the conditional model has seen.
unsigned PYPLexicalTranslation::UniqueConditioningContexts() const {
  const unsigned contexts = tmodel->UniqueConditioningContexts();
  return contexts;
}

// Probability of target word `trg` given source word `src`. The target is
// scored through its letter spelling; `trg` must be a valid index into
// `letters` (not bounds-checked).
prob_t PYPLexicalTranslation::Prob(WordID src, WordID trg) const {
  const vector<WordID>& trg_spelling = letters[trg];
  return tmodel->Prob(src, trg_spelling);
}

// Records one src -> trg translation event, using the letter spelling of
// `trg` (`trg` must index into `letters`).
void PYPLexicalTranslation::Increment(WordID src, WordID trg, MT19937* rng) {
  const vector<WordID>& trg_spelling = letters[trg];
  tmodel->Increment(src, trg_spelling, rng);
}

// Removes one previously recorded src -> trg translation event. It is the
// inverse of Increment and uses the same letter spelling of `trg`.
void PYPLexicalTranslation::Decrement(WordID src, WordID trg, MT19937* rng) {
  const vector<WordID>& trg_spelling = letters[trg];
  tmodel->Decrement(src, trg_spelling, rng);
}