diff options
| author | Kenneth Heafield <github@kheafield.com> | 2012-09-11 14:23:39 +0100 | 
|---|---|---|
| committer | Kenneth Heafield <github@kheafield.com> | 2012-09-11 14:27:52 +0100 | 
| commit | 8882e9ebe158aef382bb5544559ef7f2a553db62 (patch) | |
| tree | f0ed595a45df16ddd1ca7ba61bc4ac0ee22d2dfb /klm | |
| parent | 104e23dd0b0795abab4565228537438481dc5a5b (diff) | |
Update kenlm and build system
Diffstat (limited to 'klm')
| -rw-r--r-- | klm/lm/Jamfile | 11 | ||||
| -rw-r--r-- | klm/lm/bhiksha.cc | 2 | ||||
| -rw-r--r-- | klm/lm/bhiksha.hh | 4 | ||||
| -rw-r--r-- | klm/lm/binary_format.cc | 4 | ||||
| -rw-r--r-- | klm/lm/binary_format.hh | 4 | ||||
| -rw-r--r-- | klm/lm/build_binary.cc | 9 | ||||
| -rw-r--r-- | klm/lm/max_order.hh | 2 | ||||
| -rw-r--r-- | klm/lm/model.cc | 21 | ||||
| -rw-r--r-- | klm/lm/model.hh | 2 | ||||
| -rw-r--r-- | klm/lm/partial.hh | 167 | ||||
| -rw-r--r-- | klm/lm/partial_test.cc | 199 | ||||
| -rw-r--r-- | klm/lm/quantize.hh | 8 | ||||
| -rw-r--r-- | klm/lm/read_arpa.cc | 23 | ||||
| -rw-r--r-- | klm/lm/search_hashed.hh | 6 | ||||
| -rw-r--r-- | klm/lm/search_trie.hh | 4 | ||||
| -rw-r--r-- | klm/lm/state.hh | 2 | ||||
| -rw-r--r-- | klm/lm/trie.cc | 4 | ||||
| -rw-r--r-- | klm/lm/trie.hh | 8 | ||||
| -rw-r--r-- | klm/lm/vocab.cc | 4 | ||||
| -rw-r--r-- | klm/lm/vocab.hh | 4 | ||||
| -rw-r--r-- | klm/util/Jamfile | 14 | ||||
| -rw-r--r-- | klm/util/ersatz_progress.cc | 10 | ||||
| -rw-r--r-- | klm/util/ersatz_progress.hh | 10 | ||||
| -rw-r--r-- | klm/util/exception.cc | 3 | ||||
| -rw-r--r-- | klm/util/exception.hh | 22 | ||||
| -rw-r--r-- | klm/util/file.cc | 7 | ||||
| -rw-r--r-- | klm/util/file_piece.cc | 1 | ||||
| -rw-r--r-- | klm/util/probing_hash_table.hh | 5 | 
28 files changed, 487 insertions, 73 deletions
| diff --git a/klm/lm/Jamfile b/klm/lm/Jamfile index b1971d88..dd620068 100644 --- a/klm/lm/Jamfile +++ b/klm/lm/Jamfile @@ -2,13 +2,14 @@ lib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quant  import testing ; -run left_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa ; -run model_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa test_nounk.arpa ; +run left_test.cc ../util//kenutil kenlm /top//boost_unit_test_framework : : test.arpa ; +run model_test.cc ../util//kenutil kenlm /top//boost_unit_test_framework : : test.arpa test_nounk.arpa ;  exe query : ngram_query.cc kenlm ../util//kenutil ;  exe build_binary : build_binary.cc kenlm ../util//kenutil ; +exe kenlm_max_order : max_order.cc : <include>.. ; -install legacy : build_binary query  -  : <location>$(TOP)/klm/lm <install-type>EXE <install-dependencies>on <link>shared:<dll-path>$(TOP)/klm/lm <link>shared:<install-type>LIB ; +alias programs : query build_binary kenlm_max_order ; -alias programs : build_binary query ; +install legacy : build_binary query kenlm_max_order +  : <location>$(TOP)/lm <install-type>EXE <install-dependencies>on <link>shared:<dll-path>$(TOP)/lm <link>shared:<install-type>LIB ; diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc index 870a4eee..088ea98d 100644 --- a/klm/lm/bhiksha.cc +++ b/klm/lm/bhiksha.cc @@ -50,7 +50,7 @@ std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &con  }  } // namespace -std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { +uint64_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {    return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */;  } diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 9734f3ab..8ff88654 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -33,7 +33,7 @@ class DontBhiksha {      static 
void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} -    static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } +    static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }      static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) {        return util::RequiredBits(max_next); @@ -67,7 +67,7 @@ class ArrayBhiksha {      static void UpdateConfigFromBinary(int fd, Config &config); -    static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); +    static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);      static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index a56e998e..fd841e59 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -200,10 +200,10 @@ void SeekPastHeader(int fd, const Parameters &params) {    util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));  } -uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) { +uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {    const uint64_t file_size = util::SizeFile(backing.file.get());    // The header is smaller than a page, so we have to map the whole header as well.   
-  std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; +  std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);    if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)      UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index dd795f62..bf699d5f 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -70,7 +70,7 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet  void SeekPastHeader(int fd, const Parameters &params); -uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing); +uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing);  void ComplainAboutARPA(const Config &config, ModelType model_type); @@ -90,7 +90,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)        new_config.probing_multiplier = params.fixed.probing_multiplier;        detail::SeekPastHeader(backing.file.get(), params);        To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); -      std::size_t memory_size = To::Size(params.counts, new_config); +      uint64_t memory_size = To::Size(params.counts, new_config);        uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);        to.InitializeFromBinary(start, params, new_config, backing.file.get());      } else { diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index c2ca1101..efe99899 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -8,7 +8,6 @@  #include <math.h>  #include <stdlib.h> -#include <unistd.h>  #ifdef WIN32  #include "util/getopt.hh" @@ -86,16 +85,16 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {    
std::vector<uint64_t> counts;    util::FilePiece f(file);    lm::ReadARPACounts(f, counts); -  std::size_t sizes[6]; +  uint64_t sizes[6];    sizes[0] = ProbingModel::Size(counts, config);    sizes[1] = RestProbingModel::Size(counts, config);    sizes[2] = TrieModel::Size(counts, config);    sizes[3] = QuantTrieModel::Size(counts, config);    sizes[4] = ArrayTrieModel::Size(counts, config);    sizes[5] = QuantArrayTrieModel::Size(counts, config); -  std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); -  std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); -  std::size_t divide; +  uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); +  uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); +  uint64_t divide;    char prefix;    if (min_length < (1 << 10) * 10) {      prefix = ' '; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index bc8687cd..989f8324 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -8,5 +8,5 @@  #define KENLM_MAX_ORDER 6  #endif  #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "Edit klm/lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile.  In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'.  Otherwise, edit lm/max_order.hh."  
#endif diff --git a/klm/lm/model.cc b/klm/lm/model.cc index b46333a4..40af8a63 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -12,6 +12,7 @@  #include <functional>  #include <numeric>  #include <cmath> +#include <limits>  namespace lm {  namespace ngram { @@ -19,17 +20,18 @@ namespace detail {  template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType; -template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { +template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {    return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);  }  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) { +  size_t goal_size = util::CheckOverflow(Size(counts, config));    uint8_t *start = static_cast<uint8_t*>(base);    size_t allocated = VocabularyT::Size(counts[0], config);    vocab_.SetupMemory(start, allocated, counts[0], config);    start += allocated;    start = search_.SetupMemory(start, counts, config); -  if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config)); +  if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);  }  template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { @@ -49,13 +51,18 @@ template <class Search, class VocabularyT> 
GenericModel<Search, VocabularyT>::Ge  }  namespace { -void CheckMaxOrder(size_t order) { -  UTIL_THROW_IF(order > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << order << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ".  " << KENLM_ORDER_MESSAGE); +void CheckCounts(const std::vector<uint64_t> &counts) { +  UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ".  " << KENLM_ORDER_MESSAGE); +  if (sizeof(uint64_t) > sizeof(std::size_t)) { +    for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) { +      UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines."); +    } +  }  }  } // namespace  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) { -  CheckMaxOrder(params.counts.size()); +  CheckCounts(params.counts);    SetupMemory(start, params.counts, config);    vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);    search_.LoadedBinary(); @@ -68,11 +75,11 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT      std::vector<uint64_t> counts;      // File counts do not include pruned trigrams that extend to quadgrams etc.   These will be fixed by search_.      
ReadARPACounts(f, counts); -    CheckMaxOrder(counts.size()); +    CheckCounts(counts);      if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");      if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); -    std::size_t vocab_size = VocabularyT::Size(counts[0], config); +    std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));      // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.        vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 6dee9419..13ff864e 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -41,7 +41,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod       * does not include small non-mapped control structures, such as this class       * itself.         */ -    static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config()); +    static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());      /* Load the model from a file.  It may be an ARPA or binary file.  
Binary       * files must have the format expected by this class or you'll get an diff --git a/klm/lm/partial.hh b/klm/lm/partial.hh new file mode 100644 index 00000000..1dede359 --- /dev/null +++ b/klm/lm/partial.hh @@ -0,0 +1,167 @@ +#ifndef LM_PARTIAL__ +#define LM_PARTIAL__ + +#include "lm/return.hh" +#include "lm/state.hh" + +#include <algorithm> + +#include <assert.h> + +namespace lm { +namespace ngram { + +struct ExtendReturn { +  float adjust; +  bool make_full; +  unsigned char next_use; +}; + +template <class Model> ExtendReturn ExtendLoop( +    const Model &model, +    unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start, +    const uint64_t *pointers, const uint64_t *pointers_end, +    uint64_t *&pointers_write, +    float *backoff_write) { +  unsigned char add_length = add_rend - add_rbegin; + +  float backoff_buf[2][KENLM_MAX_ORDER - 1]; +  float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1]; +  std::copy(backoff_start, backoff_start + add_length, backoff_in); + +  ExtendReturn value; +  value.make_full = false; +  value.adjust = 0.0; +  value.next_use = add_length; + +  unsigned char i = 0; +  unsigned char length = pointers_end - pointers; +  // pointers_write is NULL means that the existing left state is full, so we should use completed probabilities.   +  if (pointers_write) { +    // Using full context, writing to new left state.    
+    for (; i < length; ++i) { +      FullScoreReturn ret(model.ExtendLeft( +          add_rbegin, add_rbegin + value.next_use, +          backoff_in, +          pointers[i], i + seen + 1, +          backoff_out, +          value.next_use)); +      std::swap(backoff_in, backoff_out); +      if (ret.independent_left) { +        value.adjust += ret.prob; +        value.make_full = true; +        ++i; +        break; +      } +      value.adjust += ret.rest; +      *pointers_write++ = ret.extend_left; +      if (value.next_use != add_length) { +        value.make_full = true; +        ++i; +        break; +      } +    } +  } +  // Using some of the new context.   +  for (; i < length && value.next_use; ++i) { +    FullScoreReturn ret(model.ExtendLeft( +        add_rbegin, add_rbegin + value.next_use, +        backoff_in, +        pointers[i], i + seen + 1, +        backoff_out, +        value.next_use)); +    std::swap(backoff_in, backoff_out); +    value.adjust += ret.prob; +  } +  float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1); +  // Using none of the new context.   +  value.adjust += unrest; + +  std::copy(backoff_in, backoff_in + value.next_use, backoff_write); +  return value; +} + +template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) { +  assert(seen < reveal.length || reveal_full); +  uint64_t *pointers_write = reveal_full ? NULL : left.pointers; +  float backoff_buffer[KENLM_MAX_ORDER - 1]; +  ExtendReturn value(ExtendLoop( +      model, +      seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen, +      left.pointers, left.pointers + left.length, +      pointers_write, +      left.full ? 
backoff_buffer : (right.backoff + right.length))); +  if (reveal_full) { +    left.length = 0; +    value.make_full = true; +  } else { +    left.length = pointers_write - left.pointers; +    value.make_full |= (left.length == model.Order() - 1); +  } +  if (left.full) { +    for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; +  } else { +    // If left wasn't full when it came in, put words into right state.   +    std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length); +    right.length += value.next_use; +    left.full = value.make_full || (right.length == model.Order() - 1); +  } +  return value.adjust; +} + +template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) { +  assert(seen < reveal.length || reveal.full); +  uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length); +  ExtendReturn value(ExtendLoop( +      model, +      seen, right.words, right.words + right.length, right.backoff, +      reveal.pointers + seen, reveal.pointers + reveal.length, +      pointers_write, +      right.backoff)); +  if (reveal.full) { +    for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i]; +    right.length = 0; +    value.make_full = true; +  } else { +    right.length = value.next_use; +    value.make_full |= (right.length == model.Order() - 1); +  } +  if (!left.full) { +    left.length = pointers_write - left.pointers; +    left.full = value.make_full || (left.length == model.Order() - 1); +  } +  return value.adjust; +} + +template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) { +  assert(first_right.length < KENLM_MAX_ORDER); +  assert(second_left.length < KENLM_MAX_ORDER); +  assert(between_length < KENLM_MAX_ORDER - 1); +  uint64_t *pointers_write 
= first_left.full ? NULL : (first_left.pointers + first_left.length); +  float backoff_buffer[KENLM_MAX_ORDER - 1]; +  ExtendReturn value(ExtendLoop( +        model, +        between_length, first_right.words, first_right.words + first_right.length, first_right.backoff, +        second_left.pointers, second_left.pointers + second_left.length, +        pointers_write, +        second_left.full ? backoff_buffer : (second_right.backoff + second_right.length))); +  if (second_left.full) { +    for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; +  } else { +    std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length); +    second_right.length += value.next_use; +    value.make_full |= (second_right.length == model.Order() - 1); +  } +  if (!first_left.full) { +    first_left.length = pointers_write - first_left.pointers; +    first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1); +  } +  assert(first_left.length < KENLM_MAX_ORDER); +  assert(second_right.length < KENLM_MAX_ORDER); +  return value.adjust; +} + +} // namespace ngram +} // namespace lm + +#endif // LM_PARTIAL__ diff --git a/klm/lm/partial_test.cc b/klm/lm/partial_test.cc new file mode 100644 index 00000000..8d309c85 --- /dev/null +++ b/klm/lm/partial_test.cc @@ -0,0 +1,199 @@ +#include "lm/partial.hh" + +#include "lm/left.hh" +#include "lm/model.hh" +#include "util/tokenize_piece.hh" + +#define BOOST_TEST_MODULE PartialTest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> + +namespace lm { +namespace ngram { +namespace { + +const char *TestLocation() { +  if (boost::unit_test::framework::master_test_suite().argc < 2) { +    return "test.arpa"; +  } +  return boost::unit_test::framework::master_test_suite().argv[1]; +} + +Config SilentConfig() { +  Config config; +  config.arpa_complain = Config::NONE; +  config.messages = NULL; +  return 
config; +} + +struct ModelFixture { +  ModelFixture() : m(TestLocation(), SilentConfig()) {} + +  RestProbingModel m; +}; + +BOOST_FIXTURE_TEST_SUITE(suite, ModelFixture) + +BOOST_AUTO_TEST_CASE(SimpleBefore) { +  Left left; +  left.full = false; +  left.length = 0; +  Right right; +  right.length = 0; + +  Right reveal; +  reveal.length = 1; +  WordIndex period = m.GetVocabulary().Index("."); +  reveal.words[0] = period; +  reveal.backoff[0] = -0.845098; + +  BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 0, false, left, right), 0.001); +  BOOST_CHECK_EQUAL(0, left.length); +  BOOST_CHECK(!left.full); +  BOOST_CHECK_EQUAL(1, right.length); +  BOOST_CHECK_EQUAL(period, right.words[0]); +  BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001); + +  WordIndex more = m.GetVocabulary().Index("more"); +  reveal.words[1] = more; +  reveal.backoff[1] =  -0.4771212; +  reveal.length = 2; +  BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 1, false, left, right), 0.001); +  BOOST_CHECK_EQUAL(0, left.length); +  BOOST_CHECK(!left.full); +  BOOST_CHECK_EQUAL(2, right.length); +  BOOST_CHECK_EQUAL(period, right.words[0]); +  BOOST_CHECK_EQUAL(more, right.words[1]); +  BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001); +  BOOST_CHECK_CLOSE(-0.4771212, right.backoff[1], 0.001); +} + +BOOST_AUTO_TEST_CASE(AlsoWouldConsider) { +  WordIndex would = m.GetVocabulary().Index("would"); +  WordIndex consider = m.GetVocabulary().Index("consider"); + +  ChartState current; +  current.left.length = 1; +  current.left.pointers[0] = would; +  current.left.full = false; +  current.right.length = 1; +  current.right.words[0] = would; +  current.right.backoff[0] = -0.30103; + +  Left after; +  after.full = false; +  after.length = 1; +  after.pointers[0] = consider; + +  // adjustment for would consider +  BOOST_CHECK_CLOSE(-1.687872 - -0.2922095 - 0.30103, RevealAfter(m, current.left, current.right, after, 0), 0.001); + +  BOOST_CHECK_EQUAL(2, current.left.length); +  
BOOST_CHECK_EQUAL(would, current.left.pointers[0]); +  BOOST_CHECK_EQUAL(false, current.left.full); + +  WordIndex also = m.GetVocabulary().Index("also"); +  Right before; +  before.length = 1; +  before.words[0] = also; +  before.backoff[0] = -0.30103; +  // r(would) = -0.2922095 [i would], r(would -> consider) = -1.988902 [b(would) + p(consider)] +  // p(also -> would) = -2, p(also would -> consider) = -3 +  BOOST_CHECK_CLOSE(-2 + 0.2922095 -3 + 1.988902, RevealBefore(m, before, 0, false, current.left, current.right), 0.001); +  BOOST_CHECK_EQUAL(0, current.left.length); +  BOOST_CHECK(current.left.full); +  BOOST_CHECK_EQUAL(2, current.right.length); +  BOOST_CHECK_EQUAL(would, current.right.words[0]); +  BOOST_CHECK_EQUAL(also, current.right.words[1]); +} + +BOOST_AUTO_TEST_CASE(EndSentence) { +  WordIndex loin = m.GetVocabulary().Index("loin"); +  WordIndex period = m.GetVocabulary().Index("."); +  WordIndex eos = m.GetVocabulary().EndSentence(); + +  ChartState between; +  between.left.length = 1; +  between.left.pointers[0] = eos; +  between.left.full = true; +  between.right.length = 0; + +  Right before; +  before.words[0] = period; +  before.words[1] = loin; +  before.backoff[0] = -0.845098; +  before.backoff[1] = 0.0; +   +  before.length = 1; +  BOOST_CHECK_CLOSE(-0.0410707, RevealBefore(m, before, 0, true, between.left, between.right), 0.001); +  BOOST_CHECK_EQUAL(0, between.left.length); +} + +float ScoreFragment(const RestProbingModel &model, unsigned int *begin, unsigned int *end, ChartState &out) { +  RuleScore<RestProbingModel> scorer(model, out); +  for (unsigned int *i = begin; i < end; ++i) { +    scorer.Terminal(*i); +  } +  return scorer.Finish(); +} + +void CheckAdjustment(const RestProbingModel &model, float expect, const Right &before_in, bool before_full, ChartState between, const Left &after_in) { +  Right before(before_in); +  Left after(after_in); +  after.full = false; +  float got = 0.0; +  for (unsigned int i = 1; i < 5; ++i) { +    
if (before_in.length >= i) { +      before.length = i; +      got += RevealBefore(model, before, i - 1, false, between.left, between.right); +    } +    if (after_in.length >= i) { +      after.length = i; +      got += RevealAfter(model, between.left, between.right, after, i - 1); +    } +  } +  if (after_in.full) { +    after.full = true; +    got += RevealAfter(model, between.left, between.right, after, after.length); +  } +  if (before_full) { +    got += RevealBefore(model, before, before.length, true, between.left, between.right); +  } +  // Sometimes they're zero and BOOST_CHECK_CLOSE fails for this.  +  BOOST_CHECK(fabs(expect - got) < 0.001); +} + +void FullDivide(const RestProbingModel &model, StringPiece str) { +  std::vector<WordIndex> indices; +  for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) { +    indices.push_back(model.GetVocabulary().Index(*i)); +  } +  ChartState full_state; +  float full = ScoreFragment(model, &indices.front(), &indices.back() + 1, full_state); + +  ChartState before_state; +  before_state.left.full = false; +  RuleScore<RestProbingModel> before_scorer(model, before_state); +  float before_score = 0.0; +  for (unsigned int before = 0; before < indices.size(); ++before) { +    for (unsigned int after = before; after <= indices.size(); ++after) { +      ChartState after_state, between_state; +      float after_score = ScoreFragment(model, &indices.front() + after, &indices.front() + indices.size(), after_state); +      float between_score = ScoreFragment(model, &indices.front() + before, &indices.front() + after, between_state); +      CheckAdjustment(model, full - before_score - after_score - between_score, before_state.right, before_state.left.full, between_state, after_state.left); +    } +    before_scorer.Terminal(indices[before]); +    before_score = before_scorer.Finish(); +  } +} + +BOOST_AUTO_TEST_CASE(Strings) { +  FullDivide(m, "also would consider"); +  FullDivide(m, "looking on a little more 
loin . </s>"); +  FullDivide(m, "in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>"); +} + +BOOST_AUTO_TEST_SUITE_END() +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index abed0112..8ce2378a 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -24,7 +24,7 @@ class DontQuantize {    public:      static const ModelType kModelTypeAdd = static_cast<ModelType>(0);      static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} -    static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } +    static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }      static uint8_t MiddleBits(const Config &/*config*/) { return 63; }      static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -138,9 +138,9 @@ class SeparatelyQuantize {      static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); -    static std::size_t Size(uint8_t order, const Config &config) { -      size_t longest_table = (static_cast<size_t>(1) << static_cast<size_t>(config.prob_bits)) * sizeof(float); -      size_t middle_table = (static_cast<size_t>(1) << static_cast<size_t>(config.backoff_bits)) * sizeof(float) + longest_table; +    static uint64_t Size(uint8_t order, const Config &config) { +      uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float); +      uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table;        // unigrams are currently not quantized so no need for a table.          
return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8;      } diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 70727e4c..174bd3a3 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -2,12 +2,13 @@  #include "lm/blank.hh" +#include <cmath>  #include <cstdlib>  #include <iostream> +#include <sstream>  #include <vector>  #include <ctype.h> -#include <math.h>  #include <string.h>  #include <stdint.h> @@ -31,6 +32,16 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) {  const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; +// strtoull isn't portable enough :-( +uint64_t ReadCount(const std::string &from) { +  std::stringstream stream(from); +  uint64_t ret; +  stream >> ret; +  UTIL_THROW_IF(!stream, FormatLoadException, "Bad count " << from); +  UTIL_THROW_IF(static_cast<std::size_t>(stream.tellg()) != from.size(), FormatLoadException, "Extra content in count: '" << from << "'"); +  return ret; +} +  } // namespace  void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { @@ -52,15 +63,11 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {      // So strtol doesn't go off the end of line.        
std::string remaining(line.data() + 6, line.size() - 6);      char *end_ptr; -    unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); +    unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10);      if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line);      if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line);      ++end_ptr; -    const char *start = end_ptr; -    long int count = std::strtol(start, &end_ptr, 10); -    if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); -    if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); -    number.push_back(count); +    number.push_back(ReadCount(end_ptr));    }  } @@ -103,7 +110,7 @@ void ReadBackoff(util::FilePiece &in, float &backoff) {  		int float_class = _fpclass(backoff);          UTIL_THROW_IF(float_class == _FPCLASS_SNAN || float_class == _FPCLASS_QNAN || float_class == _FPCLASS_NINF || float_class == _FPCLASS_PINF, FormatLoadException, "Bad backoff " << backoff);  #else -        int float_class = fpclassify(backoff); +        int float_class = std::fpclassify(backoff);          UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff);  #endif        } diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 7e8c1220..3bcde921 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -74,8 +74,8 @@ template <class Value> class HashedSearch {      // TODO: move probing_multiplier here with next binary file format update.        
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} -    static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { -      std::size_t ret = Unigram::Size(counts[0]); +    static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { +      uint64_t ret = Unigram::Size(counts[0]);        for (unsigned char n = 1; n < counts.size() - 1; ++n) {          ret += Middle::Size(counts[n], config.probing_multiplier);        } @@ -160,7 +160,7 @@ template <class Value> class HashedSearch {  #endif        {} -        static std::size_t Size(uint64_t count) { +        static uint64_t Size(uint64_t count) {            return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>          } diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 10b22ab1..1264baf5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -44,8 +44,8 @@ template <class Quant, class Bhiksha> class TrieSearch {        Bhiksha::UpdateConfigFromBinary(fd, config);      } -    static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { -      std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); +    static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { +      uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);        for (unsigned char i = 1; i < counts.size() - 1; ++i) {          ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);        } diff --git a/klm/lm/state.hh b/klm/lm/state.hh index 830e40aa..551510a8 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -47,6 +47,8 @@ class State {      unsigned char length;  }; +typedef State Right; +  inline uint64_t hash_value(const State &state, uint64_t seed = 0) {    return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed);  } diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 
0f1ca574..d9895f89 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -36,7 +36,7 @@ bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_  }  } // namespace -std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { +uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {    uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits;    // Extra entry for next pointer at the end.      // +7 then / 8 to round up bits and convert to bytes @@ -57,7 +57,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)    max_vocab_ = max_vocab;  } -template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {    return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));  } diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 034a1414..9ea3c546 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -49,7 +49,7 @@ class Unigram {        unigram_ = static_cast<UnigramValue*>(start);      } -    static std::size_t Size(uint64_t count) { +    static uint64_t Size(uint64_t count) {        // +1 in case unknown doesn't appear.  +1 for the final next.          
return (count + 2) * sizeof(UnigramValue);      } @@ -84,7 +84,7 @@ class BitPacked {      }    protected: -    static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); +    static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);      void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); @@ -99,7 +99,7 @@ class BitPacked {  template <class Bhiksha> class BitPackedMiddle : public BitPacked {    public: -    static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); +    static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);      // next_source need not be initialized.        BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); @@ -128,7 +128,7 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked {  class BitPackedLongest : public BitPacked {    public: -    static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { +    static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {        return BaseSize(entries, max_vocab, quant_bits);      } diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 5de68f16..398475be 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -87,7 +87,7 @@ void WriteWordsWrapper::Write(int fd) {  SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} -std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { +uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {    // Lead with the number of entries.      
return sizeof(uint64_t) + sizeof(uint64_t) * entries;  } @@ -165,7 +165,7 @@ struct ProbingVocabularyHeader {  ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} -std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { +uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) {    return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);  } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index a25432f9..074cd446 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -62,7 +62,7 @@ class SortedVocabulary : public base::Vocabulary {      }      // Size for purposes of file writing -    static size_t Size(std::size_t entries, const Config &config); +    static uint64_t Size(uint64_t entries, const Config &config);      // Vocab words are [0, Bound())  Only valid after FinishedLoading/LoadedBinary.        WordIndex Bound() const { return bound_; } @@ -129,7 +129,7 @@ class ProbingVocabulary : public base::Vocabulary {        return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;      } -    static size_t Size(std::size_t entries, const Config &config); +    static uint64_t Size(uint64_t entries, const Config &config);      // Vocab words are [0, Bound()).        WordIndex Bound() const { return bound_; } diff --git a/klm/util/Jamfile b/klm/util/Jamfile index 3ee2c2c2..a939265f 100644 --- a/klm/util/Jamfile +++ b/klm/util/Jamfile @@ -1,10 +1,10 @@ -lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc usage.cc ../..//z : <include>.. : : <include>.. ; +lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc usage.cc /top//z : <include>.. : : <include>.. 
;  import testing ; -unit-test bit_packing_test : bit_packing_test.cc kenutil ../..///boost_unit_test_framework ; -run file_piece_test.cc kenutil ../..///boost_unit_test_framework : : file_piece.cc ; -unit-test joint_sort_test : joint_sort_test.cc kenutil ../..///boost_unit_test_framework ; -unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil ../..///boost_unit_test_framework ; -unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil ../..///boost_unit_test_framework ; -unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil ../..///boost_unit_test_framework ; +unit-test bit_packing_test : bit_packing_test.cc kenutil /top//boost_unit_test_framework ; +run file_piece_test.cc kenutil /top//boost_unit_test_framework : : file_piece.cc ; +unit-test joint_sort_test : joint_sort_test.cc kenutil /top//boost_unit_test_framework ; +unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil /top//boost_unit_test_framework ; +unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil /top//boost_unit_test_framework ; +unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil /top//boost_unit_test_framework ; diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc index 07b14e26..eb635ad8 100644 --- a/klm/util/ersatz_progress.cc +++ b/klm/util/ersatz_progress.cc @@ -9,16 +9,16 @@ namespace util {  namespace { const unsigned char kWidth = 100; } -ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::size_t>::max()), complete_(next_), out_(NULL) {} +ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<uint64_t>::max()), complete_(next_), out_(NULL) {}  ErsatzProgress::~ErsatzProgress() {    if (out_) Finished();  } -ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std::string &message)  +ErsatzProgress::ErsatzProgress(uint64_t complete, std::ostream *to, const std::string &message)     : current_(0), next_(complete / kWidth), 
complete_(complete), stones_written_(0), out_(to) {    if (!out_) { -    next_ = std::numeric_limits<std::size_t>::max(); +    next_ = std::numeric_limits<uint64_t>::max();      return;    }    if (!message.empty()) *out_ << message << '\n'; @@ -28,14 +28,14 @@ ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std  void ErsatzProgress::Milestone() {    if (!out_) { current_ = 0; return; }    if (!complete_) return; -  unsigned char stone = std::min(static_cast<std::size_t>(kWidth), (current_ * kWidth) / complete_); +  unsigned char stone = std::min(static_cast<uint64_t>(kWidth), (current_ * kWidth) / complete_);    for (; stones_written_ < stone; ++stones_written_) {      (*out_) << '*';    }    if (stone == kWidth) {      (*out_) << std::endl; -    next_ = std::numeric_limits<std::size_t>::max(); +    next_ = std::numeric_limits<uint64_t>::max();      out_ = NULL;    } else {      next_ = std::max(next_, (stone * complete_) / kWidth); diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index f709dc51..ff4d590f 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -4,6 +4,8 @@  #include <iostream>  #include <string> +#include <inttypes.h> +  // Ersatz version of boost::progress so core language model doesn't depend on  // boost.  Also adds option to print nothing.   @@ -14,7 +16,7 @@ class ErsatzProgress {      ErsatzProgress();      // Null means no output.  The null value is useful for passing along the ostream pointer from another caller.    
-    explicit ErsatzProgress(std::size_t complete, std::ostream *to = &std::cerr, const std::string &message = ""); +    explicit ErsatzProgress(uint64_t complete, std::ostream *to = &std::cerr, const std::string &message = "");      ~ErsatzProgress(); @@ -23,12 +25,12 @@ class ErsatzProgress {        return *this;      } -    ErsatzProgress &operator+=(std::size_t amount) { +    ErsatzProgress &operator+=(uint64_t amount) {        if ((current_ += amount) >= next_) Milestone();        return *this;      } -    void Set(std::size_t to) { +    void Set(uint64_t to) {        if ((current_ = to) >= next_) Milestone();        Milestone();      } @@ -40,7 +42,7 @@ class ErsatzProgress {    private:      void Milestone(); -    std::size_t current_, next_, complete_; +    uint64_t current_, next_, complete_;      unsigned char stones_written_;      std::ostream *out_; diff --git a/klm/util/exception.cc b/klm/util/exception.cc index c4f8c04c..3806e6de 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -84,4 +84,7 @@ EndOfFileException::EndOfFileException() throw() {  }  EndOfFileException::~EndOfFileException() throw() {} +OverflowException::OverflowException() throw() {} +OverflowException::~OverflowException() throw() {} +  } // namespace util diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 6d6a37cb..83f99cd6 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -2,9 +2,12 @@  #define UTIL_EXCEPTION__  #include <exception> +#include <limits>  #include <sstream>  #include <string> +#include <inttypes.h> +  namespace util {  template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data); @@ -111,6 +114,25 @@ class EndOfFileException : public Exception {      ~EndOfFileException() throw();  }; +class OverflowException : public Exception { +  public: +    OverflowException() throw(); +    ~OverflowException() throw(); +}; + +template <unsigned len> inline 
std::size_t CheckOverflowInternal(uint64_t value) { +  UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected.  This model is too big for 32-bit code."); +  return value; +} + +template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) { +  return value; +} + +inline std::size_t CheckOverflow(uint64_t value) { +  return CheckOverflowInternal<sizeof(std::size_t)>(value); +} +  } // namespace util  #endif // UTIL_EXCEPTION__ diff --git a/klm/util/file.cc b/klm/util/file.cc index 98f13983..ff5e64c9 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -119,8 +119,13 @@ void FSyncOrThrow(int fd) {  }  namespace { -void InternalSeek(int fd, off_t off, int whence) { +void InternalSeek(int fd, int64_t off, int whence) { +#if defined(_WIN32) || defined(_WIN64) +  UTIL_THROW_IF((__int64)-1 == _lseeki64(fd, off, whence), ErrnoException, "Windows seek failed"); + +#else    UTIL_THROW_IF((off_t)-1 == lseek(fd, off, whence), ErrnoException, "Seek failed"); +#endif  }  } // namespace diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index af341d6d..19a68728 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -11,7 +11,6 @@  #include <string>  #include <limits> -#include <unistd.h>  #include <assert.h>  #include <ctype.h>  #include <fcntl.h> diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 3354b68e..770faa7e 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -8,6 +8,7 @@  #include <functional>  #include <assert.h> +#include <inttypes.h>  namespace util { @@ -42,8 +43,8 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry      typedef EqualT Equal;    public: -    static std::size_t Size(std::size_t entries, float multiplier) { -      std::size_t buckets = std::max(entries + 1, static_cast<std::size_t>(multiplier * static_cast<float>(entries))); +    static uint64_t 
Size(uint64_t entries, float multiplier) { +      uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries)));        return buckets * sizeof(Entry);      } | 
