summaryrefslogtreecommitdiff
path: root/klm/lm/quantize.cc
blob: a8e0cb21cf565e40230d76b4f15cf96e9d49d0e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/* Quantize into bins of equal size as described in
 * M. Federico and N. Bertoldi. 2006. How many bits are needed
 * to store probabilities for phrase-based translation? In Proc.
 * of the Workshop on Statistical Machine Translation, pages
 * 94–101, New York City, June. Association for
 * Computational Linguistics.
 */

#include "lm/quantize.hh"

#include "lm/binary_format.hh"
#include "lm/lm_exception.hh"
#include "util/file.hh"

#include <algorithm>
#include <numeric>

namespace lm {
namespace ngram {

namespace {

// Sorts `values` in place, cuts the sorted sequence into `bins` consecutive
// slices whose boundaries are floor(size * k / bins), and stores the mean of
// slice k in centers[k].
//
// values:  samples to quantize; reordered (sorted ascending) by this call.
// centers: output array with room for at least `bins` floats.
// bins:    number of centers to produce.
//
// A slice with no elements repeats the preceding center; when the very first
// slice is empty its center is negative infinity.
void MakeBins(std::vector<float> &values, float *centers, uint32_t bins) {
  std::sort(values.begin(), values.end());
  const uint64_t total = values.size();
  std::vector<float>::const_iterator slice_begin = values.begin();
  for (uint32_t bin = 0; bin < bins; ++bin) {
    // Exclusive upper boundary of this slice within the sorted values.
    const std::vector<float>::const_iterator slice_end =
        values.begin() + (total * static_cast<uint64_t>(bin + 1)) / bins;
    if (slice_begin == slice_end) {
      // Nothing landed in this bucket.
      if (bin == 0) {
        centers[bin] = -std::numeric_limits<float>::infinity();
      } else {
        centers[bin] = centers[bin - 1];
      }
    } else {
      // Sum in double precision, then narrow the mean back to float.
      const double sum = std::accumulate(slice_begin, slice_end, 0.0);
      centers[bin] = sum / static_cast<float>(slice_end - slice_begin);
    }
    slice_begin = slice_end;
  }
}

// Version tag of the quantization header.  FinishedLoading stores it in the
// first header byte, and UpdateConfigFromBinary rejects a file whose first
// byte differs from it.
const char kSeparatelyQuantizeVersion = 2;

} // namespace

// Reads the three header bytes (version, probability bits, backoff bits) from
// fd at its current position and copies the two bit widths into config.
//
// fd:     descriptor positioned at the start of the quantization header.
// counts: unused.
// config: receives prob_bits and backoff_bits, but only once the version byte
//         has been accepted.
//
// Throws FormatLoadException when the version byte differs from
// kSeparatelyQuantizeVersion; in that case config is left unmodified.  The
// util:: helpers presumably throw on I/O failure (their definitions are not
// visible here).
void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) {
  char version;
  util::ReadOrThrow(fd, &version, 1);
  // Validate the layout before trusting any further bytes, so a file with an
  // unknown version cannot overwrite the caller's bit widths.
  if (version != kSeparatelyQuantizeVersion) {
    // Convert through unsigned char: where plain char is signed, a version
    // byte >= 0x80 would otherwise be sign-extended and reported as a number
    // near 4 billion.
    UTIL_THROW(FormatLoadException, "This file has quantization version " << static_cast<unsigned>(static_cast<unsigned char>(version)) << " but the code expects version " << static_cast<unsigned>(static_cast<unsigned char>(kSeparatelyQuantizeVersion)));
  }
  util::ReadOrThrow(fd, &config.prob_bits, 1);
  util::ReadOrThrow(fd, &config.backoff_bits, 1);
  // Three bytes were consumed above; move the descriptor back by the same
  // amount (negative advance) so the header can be read again from its start.
  util::AdvanceOrThrow(fd, -3);
}

// Points this quantizer at its memory region and records the configured bit
// widths.
//
// start:  beginning of the region; its first 8 bytes are kept as a header for
//         the bit counts and the float tables begin immediately after.
// config: supplies prob_bits and backoff_bits.
//
// Throws ConfigException when either width is 0 or greater than 25.  The
// members are assigned before the checks run.
void SeparatelyQuantize::SetupMemory(void *start, const Config &config) {
  // Skip over the 8 byte header.
  uint8_t *const past_header = static_cast<uint8_t*>(start) + 8;
  start_ = reinterpret_cast<float*>(past_header);
  prob_bits_ = config.prob_bits;
  backoff_bits_ = config.backoff_bits;

  // Zero bits would leave no room for the reserved values.
  if (config.prob_bits == 0) {
    UTIL_THROW(ConfigException, "You can't quantize probability to zero");
  }
  if (config.backoff_bits == 0) {
    UTIL_THROW(ConfigException, "You can't quantize backoff to zero");
  }
  // Upper limit on the width of either table index.
  if (config.prob_bits > 25) {
    UTIL_THROW(ConfigException, "For efficiency reasons, quantizing probability supports at most 25 bits.  Currently you have requested " << static_cast<unsigned>(config.prob_bits) << " bits.");
  }
  if (config.backoff_bits > 25) {
    UTIL_THROW(ConfigException, "For efficiency reasons, quantizing backoff supports at most 25 bits.  Currently you have requested " << static_cast<unsigned>(config.backoff_bits) << " bits.");
  }
}

// Builds both quantization tables for n-grams of the given order.
//
// order:   n-gram order whose tables are filled.
// prob:    probability samples; forwarded to TrainProb.
// backoff: backoff samples; sorted in place by MakeBins.
//
// The backoff table sits directly after the probability table.  Its first two
// entries are the reserved values kNoExtensionBackoff and kExtensionBackoff;
// the remaining 2^backoff_bits_ - 2 entries are bin centers from MakeBins.
void SeparatelyQuantize::Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff) {
  TrainProb(order, prob);

  float *const backoff_table = start_ + TableStart(order) + ProbTableLength();
  // Reserved slots come first.
  backoff_table[0] = kNoExtensionBackoff;
  backoff_table[1] = kExtensionBackoff;
  // Everything after the reserved slots is learned from the data.
  MakeBins(backoff, backoff_table + 2, (1ULL << backoff_bits_) - 2);
}

// Fills the probability table for the given order with 2^prob_bits_ bin
// centers computed from `prob` (which MakeBins sorts in place).
void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) {
  const unsigned long long bin_count = 1ULL << prob_bits_;
  MakeBins(prob, start_ + TableStart(order), bin_count);
}

// Writes the header located 8 bytes before start_: byte 0 is the format
// version, byte 1 the probability bit count, byte 2 the backoff bit count.
// The remaining header bytes are not touched here.
void SeparatelyQuantize::FinishedLoading(const Config &config) {
  uint8_t *const header = reinterpret_cast<uint8_t*>(start_) - 8;
  header[0] = kSeparatelyQuantizeVersion; // version
  header[1] = config.prob_bits;
  header[2] = config.backoff_bits;
}

} // namespace ngram
} // namespace lm