diff options
Diffstat (limited to 'klm/lm')
-rw-r--r-- | klm/lm/bhiksha.cc | 2 | ||||
-rw-r--r-- | klm/lm/bhiksha.hh | 4 | ||||
-rw-r--r-- | klm/lm/binary_format.cc | 20 | ||||
-rw-r--r-- | klm/lm/binary_format.hh | 4 | ||||
-rw-r--r-- | klm/lm/build_binary.cc | 11 | ||||
-rw-r--r-- | klm/lm/max_order.hh | 2 | ||||
-rw-r--r-- | klm/lm/model.cc | 21 | ||||
-rw-r--r-- | klm/lm/model.hh | 2 | ||||
-rw-r--r-- | klm/lm/quantize.hh | 8 | ||||
-rw-r--r-- | klm/lm/read_arpa.cc | 22 | ||||
-rw-r--r-- | klm/lm/search_hashed.cc | 2 | ||||
-rw-r--r-- | klm/lm/search_hashed.hh | 8 | ||||
-rw-r--r-- | klm/lm/search_trie.cc | 2 | ||||
-rw-r--r-- | klm/lm/search_trie.hh | 4 | ||||
-rw-r--r-- | klm/lm/sri_test.cc | 65 | ||||
-rw-r--r-- | klm/lm/state.hh | 2 | ||||
-rw-r--r-- | klm/lm/trie.cc | 4 | ||||
-rw-r--r-- | klm/lm/trie.hh | 8 | ||||
-rw-r--r-- | klm/lm/trie_sort.cc | 20 | ||||
-rw-r--r-- | klm/lm/trie_sort.hh | 2 | ||||
-rw-r--r-- | klm/lm/vocab.cc | 4 | ||||
-rw-r--r-- | klm/lm/vocab.hh | 4 | ||||
-rw-r--r-- | klm/lm/word_index.hh | 3 |
23 files changed, 91 insertions, 133 deletions
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc index 870a4eee..088ea98d 100644 --- a/klm/lm/bhiksha.cc +++ b/klm/lm/bhiksha.cc @@ -50,7 +50,7 @@ std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &con } } // namespace -std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { +uint64_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */; } diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 9734f3ab..8ff88654 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -33,7 +33,7 @@ class DontBhiksha { static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} - static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } + static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) { return util::RequiredBits(max_next); @@ -67,7 +67,7 @@ class ArrayBhiksha { static void UpdateConfigFromBinary(int fd, Config &config); - static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); + static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index a56e998e..efa67056 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -83,7 +83,13 @@ void WriteHeader(void *to, const Parameters ¶ms) { uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; - backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED); + backing.file.reset(util::CreateOrThrow(config.write_mmap)); + if (config.write_method == Config::WRITE_MMAP) { + backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); + } else { + util::ResizeOrThrow(backing.file.get(), 0); + util::MapAnonymous(total, backing.vocab); + } strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); } else { @@ -121,12 +127,14 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { if (!config.write_mmap) return; - util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); switch (config.write_method) { case Config::WRITE_MMAP: + util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); util::SyncOrThrow(backing.search.get(), backing.search.size()); break; case Config::WRITE_AFTER: + util::SeekOrThrow(backing.file.get(), 0); + util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size()); util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); util::FSyncOrThrow(backing.file.get()); @@ -141,6 +149,10 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_ params.fixed.has_vocabulary = config.include_vocab; params.fixed.search_version = search_version; WriteHeader(backing.vocab.get(), params); + if (config.write_method == Config::WRITE_AFTER) { + util::SeekOrThrow(backing.file.get(), 0); + util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size())); + } } namespace detail { @@ -200,10 +212,10 @@ void SeekPastHeader(int fd, const Parameters ¶ms) { util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); } -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { +uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { const uint64_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. - std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; + std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index dd795f62..bf699d5f 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -70,7 +70,7 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet void SeekPastHeader(int fd, const Parameters ¶ms); -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); +uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing); void ComplainAboutARPA(const Config &config, ModelType model_type); @@ -90,7 +90,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to) new_config.probing_multiplier = params.fixed.probing_multiplier; detail::SeekPastHeader(backing.file.get(), params); To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); - std::size_t memory_size = To::Size(params.counts, new_config); + uint64_t memory_size = To::Size(params.counts, new_config); uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); to.InitializeFromBinary(start, params, new_config, backing.file.get()); } else { diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index c2ca1101..2b8c9d5b 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -8,10 +8,11 @@ #include <math.h> #include <stdlib.h> -#include <unistd.h> #ifdef WIN32 #include "util/getopt.hh" +#else +#include <unistd.h> #endif namespace lm { @@ -86,16 +87,16 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector<uint64_t> counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t sizes[6]; + uint64_t sizes[6]; sizes[0] = ProbingModel::Size(counts, config); sizes[1] = RestProbingModel::Size(counts, config); sizes[2] = TrieModel::Size(counts, config); sizes[3] = QuantTrieModel::Size(counts, config); sizes[4] = ArrayTrieModel::Size(counts, config); sizes[5] = QuantArrayTrieModel::Size(counts, config); - std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); - std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); - std::size_t divide; + uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t divide; char prefix; if (min_length < (1 << 10) * 10) { prefix = ' '; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index bc8687cd..989f8324 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -8,5 +8,5 @@ #define KENLM_MAX_ORDER 6 #endif #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "Edit klm/lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif diff --git a/klm/lm/model.cc b/klm/lm/model.cc index b46333a4..40af8a63 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -12,6 +12,7 @@ #include <functional> #include <numeric> #include <cmath> +#include <limits> namespace lm { namespace ngram { @@ -19,17 +20,18 @@ namespace detail { template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType; -template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { +template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) { + size_t goal_size = util::CheckOverflow(Size(counts, config)); uint8_t *start = static_cast<uint8_t*>(base); size_t allocated = VocabularyT::Size(counts[0], config); vocab_.SetupMemory(start, allocated, counts[0], config); start += allocated; start = search_.SetupMemory(start, counts, config); - if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config)); + if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size); } template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { @@ -49,13 +51,18 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge } namespace { -void CheckMaxOrder(size_t order) { - UTIL_THROW_IF(order > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << order << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE); +void CheckCounts(const std::vector<uint64_t> &counts) { + UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE); + if (sizeof(uint64_t) > sizeof(std::size_t)) { + for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) { + UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines."); + } + } } } // namespace template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { - CheckMaxOrder(params.counts.size()); + CheckCounts(params.counts); SetupMemory(start, params.counts, config); vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); search_.LoadedBinary(); @@ -68,11 +75,11 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT std::vector<uint64_t> counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. ReadARPACounts(f, counts); - CheckMaxOrder(counts.size()); + CheckCounts(counts); if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - std::size_t vocab_size = VocabularyT::Size(counts[0], config); + std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 6dee9419..13ff864e 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -41,7 +41,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod * does not include small non-mapped control structures, such as this class * itself. */ - static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config()); + static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config()); /* Load the model from a file. It may be an ARPA or binary file. Binary * files must have the format expected by this class or you'll get an diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index abed0112..8ce2378a 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -24,7 +24,7 @@ class DontQuantize { public: static const ModelType kModelTypeAdd = static_cast<ModelType>(0); static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} - static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } + static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -138,9 +138,9 @@ class SeparatelyQuantize { static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); - static std::size_t Size(uint8_t order, const Config &config) { - size_t longest_table = (static_cast<size_t>(1) << static_cast<size_t>(config.prob_bits)) * sizeof(float); - size_t middle_table = (static_cast<size_t>(1) << static_cast<size_t>(config.backoff_bits)) * sizeof(float) + longest_table; + static uint64_t Size(uint8_t order, const Config &config) { + uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float); + uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table; // unigrams are currently not quantized so no need for a table. return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; } diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 70727e4c..b709fef9 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -2,12 +2,13 @@ #include "lm/blank.hh" +#include <cmath> #include <cstdlib> #include <iostream> +#include <sstream> #include <vector> #include <ctype.h> -#include <math.h> #include <string.h> #include <stdint.h> @@ -31,6 +32,15 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) { const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; +// strtoull isn't portable enough :-( +uint64_t ReadCount(const std::string &from) { + std::stringstream stream(from); + uint64_t ret; + stream >> ret; + UTIL_THROW_IF(!stream, FormatLoadException, "Bad count " << from); + return ret; +} + } // namespace void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { @@ -52,15 +62,11 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { // So strtol doesn't go off the end of line. std::string remaining(line.data() + 6, line.size() - 6); char *end_ptr; - unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); + unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10); if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line); if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line); ++end_ptr; - const char *start = end_ptr; - long int count = std::strtol(start, &end_ptr, 10); - if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); - if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); - number.push_back(count); + number.push_back(ReadCount(end_ptr)); } } @@ -103,7 +109,7 @@ void ReadBackoff(util::FilePiece &in, float &backoff) { int float_class = _fpclass(backoff); UTIL_THROW_IF(float_class == _FPCLASS_SNAN || float_class == _FPCLASS_QNAN || float_class == _FPCLASS_NINF || float_class == _FPCLASS_PINF, FormatLoadException, "Bad backoff " << backoff); #else - int float_class = fpclassify(backoff); + int float_class = std::fpclassify(backoff); UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff); #endif } diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 13942309..a1623834 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -234,7 +234,7 @@ template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, c ApplyBuild(f, counts, config, vocab, warn, build); } -template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { +template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { switch (config.rest_function) { case Config::REST_MAX: { diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 7e8c1220..a52f107b 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -74,8 +74,8 @@ template <class Value> class HashedSearch { // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} - static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { - std::size_t ret = Unigram::Size(counts[0]); + static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { + uint64_t ret = Unigram::Size(counts[0]); for (unsigned char n = 1; n < counts.size() - 1; ++n) { ret += Middle::Size(counts[n], config.probing_multiplier); } @@ -160,8 +160,8 @@ template <class Value> class HashedSearch { #endif {} - static std::size_t Size(uint64_t count) { - return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk> + static uint64_t Size(uint64_t count) { + return (count + 1) * sizeof(typename Value::Weights); // +1 for hallucinate <unk> } const typename Value::Weights &Lookup(WordIndex index) const { diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 832cc9f7..debcfd07 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -89,7 +89,7 @@ class BackoffMessages { if (!HasExtension(weights.backoff)) { weights.backoff = kExtensionBackoff; UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed."); - WriteOrThrow(unigrams, &weights, sizeof(weights)); + util::WriteOrThrow(unigrams, &weights, sizeof(weights)); } const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex)); base[write_to.array][write_to.index] += weights.backoff; diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 10b22ab1..1264baf5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -44,8 +44,8 @@ template <class Quant, class Bhiksha> class TrieSearch { Bhiksha::UpdateConfigFromBinary(fd, config); } - static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { - std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); + static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { + uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); for (unsigned char i = 1; i < counts.size() - 1; ++i) { ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); } diff --git a/klm/lm/sri_test.cc b/klm/lm/sri_test.cc deleted file mode 100644 index e697d722..00000000 --- a/klm/lm/sri_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include "lm/sri.hh" - -#include <stdlib.h> - -#define BOOST_TEST_MODULE SRITest -#include <boost/test/unit_test.hpp> - -namespace lm { -namespace sri { -namespace { - -#define StartTest(word, ngram, score) \ - ret = model.FullScore( \ - state, \ - model.GetVocabulary().Index(word), \ - out);\ - BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ - BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ - BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); - -#define AppendTest(word, ngram, score) \ - StartTest(word, ngram, score) \ - state = out; - -template <class M> void Starters(M &model) { - FullScoreReturn ret; - Model::State state(model.BeginSentenceState()); - Model::State out; - - StartTest("looking", 2, -0.4846522); - - // , probability plus <s> backoff - StartTest(",", 1, -1.383514 + -0.4149733); - // <unk> probability plus <s> backoff - StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); -} - -template <class M> void Continuation(M &model) { - FullScoreReturn ret; - Model::State state(model.BeginSentenceState()); - Model::State out; - - AppendTest("looking", 2, -0.484652); - AppendTest("on", 3, -0.348837); - AppendTest("a", 4, -0.0155266); - AppendTest("little", 5, -0.00306122); - State preserve = state; - AppendTest("the", 1, -4.04005); - AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 0, -2.29666); - AppendTest("more", 1, -1.20632); - AppendTest(".", 2, -0.51363); - AppendTest("</s>", 3, -0.0191651); - - state = preserve; - AppendTest("more", 5, -0.00181395); - AppendTest("loin", 5, -0.0432557); -} - -BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); } -BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); } - -} // namespace -} // namespace sri -} // namespace lm diff --git a/klm/lm/state.hh b/klm/lm/state.hh index 830e40aa..551510a8 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -47,6 +47,8 @@ class State { unsigned char length; }; +typedef State Right; + inline uint64_t hash_value(const State &state, uint64_t seed = 0) { return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed); } diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 0f1ca574..d9895f89 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -36,7 +36,7 @@ bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_ } } // namespace -std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { +uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits; // Extra entry for next pointer at the end. // +7 then / 8 to round up bits and convert to bytes @@ -57,7 +57,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 034a1414..9ea3c546 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -49,7 +49,7 @@ class Unigram { unigram_ = static_cast<UnigramValue*>(start); } - static std::size_t Size(uint64_t count) { + static uint64_t Size(uint64_t count) { // +1 in case unknown doesn't appear. +1 for the final next. return (count + 2) * sizeof(UnigramValue); } @@ -84,7 +84,7 @@ class BitPacked { } protected: - static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); + static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); @@ -99,7 +99,7 @@ class BitPacked { template <class Bhiksha> class BitPackedMiddle : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); + static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); @@ -128,7 +128,7 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked { class BitPackedLongest : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { + static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { return BaseSize(entries, max_vocab, quant_bits); } diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 0d83221e..8663e94e 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -22,12 +22,6 @@ namespace lm { namespace ngram { namespace trie { - -void WriteOrThrow(FILE *to, const void *data, size_t size) { - assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); -} - namespace { typedef util::SizedIterator NGramIter; @@ -95,12 +89,12 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. if (context_begin == context_end) return out.release(); PartialIter i(context_begin); - WriteOrThrow(out.get(), i->Data(), context_size); + util::WriteOrThrow(out.get(), i->Data(), context_size); const void *previous = i->Data(); ++i; for (; i != context_end; ++i) { if (memcmp(previous, i->Data(), context_size)) { - WriteOrThrow(out.get(), i->Data(), context_size); + util::WriteOrThrow(out.get(), i->Data(), context_size); previous = i->Data(); } } @@ -116,7 +110,7 @@ struct ThrowCombine { // Useful for context files that just contain records with no value. struct FirstCombine { void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { - WriteOrThrow(out, first, entry_size); + util::WriteOrThrow(out, first, entry_size); } }; @@ -129,10 +123,10 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { - WriteOrThrow(out_file.get(), first.Data(), entry_size); + util::WriteOrThrow(out_file.get(), first.Data(), entry_size); ++first; } else if (less(second.Data(), first.Data())) { - WriteOrThrow(out_file.get(), second.Data(), entry_size); + util::WriteOrThrow(out_file.get(), second.Data(), entry_size); ++second; } else { combine(entry_size, first.Data(), second.Data(), out_file.get()); @@ -140,7 +134,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f } } for (RecordReader &remains = (first ? first : second); remains; ++remains) { - WriteOrThrow(out_file.get(), remains.Data(), entry_size); + util::WriteOrThrow(out_file.get(), remains.Data(), entry_size); } return out_file.release(); } @@ -164,7 +158,7 @@ void RecordReader::Init(FILE *file, std::size_t entry_size) { void RecordReader::Overwrite(const void *start, std::size_t amount) { long internal = (uint8_t*)start - (uint8_t*)data_.get(); UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); - WriteOrThrow(file_, start, amount); + util::WriteOrThrow(file_, start, amount); long forward = entry_size_ - internal - amount; #if !defined(_WIN32) && !defined(_WIN64) if (forward) diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index 1e6fce51..2197b80c 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -29,8 +29,6 @@ struct Config; namespace trie { -void WriteOrThrow(FILE *to, const void *data, size_t size); - class EntryCompare : public std::binary_function<const void*, const void*, bool> { public: explicit EntryCompare(unsigned char order) : order_(order) {} diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 5de68f16..398475be 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -87,7 +87,7 @@ void WriteWordsWrapper::Write(int fd) { SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} -std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { +uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { // Lead with the number of entries. return sizeof(uint64_t) + sizeof(uint64_t) * entries; } @@ -165,7 +165,7 @@ struct ProbingVocabularyHeader { ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} -std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { +uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) { return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index a25432f9..074cd446 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -62,7 +62,7 @@ class SortedVocabulary : public base::Vocabulary { } // Size for purposes of file writing - static size_t Size(std::size_t entries, const Config &config); + static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. WordIndex Bound() const { return bound_; } @@ -129,7 +129,7 @@ class ProbingVocabulary : public base::Vocabulary { return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } - static size_t Size(std::size_t entries, const Config &config); + static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()). WordIndex Bound() const { return bound_; } diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh index 67841c30..e09557a7 100644 --- a/klm/lm/word_index.hh +++ b/klm/lm/word_index.hh @@ -2,8 +2,11 @@ #ifndef LM_WORD_INDEX__ #define LM_WORD_INDEX__ +#include <limits.h> + namespace lm { typedef unsigned int WordIndex; +const WordIndex kMaxWordIndex = UINT_MAX; } // namespace lm typedef lm::WordIndex LMWordIndex; |