diff options
Diffstat (limited to 'klm/lm')
37 files changed, 800 insertions, 487 deletions
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc index 088ea98d..c8a18dfd 100644 --- a/klm/lm/bhiksha.cc +++ b/klm/lm/bhiksha.cc @@ -1,4 +1,6 @@ #include "lm/bhiksha.hh" + +#include "lm/binary_format.hh" #include "lm/config.hh" #include "util/file.hh" #include "util/exception.hh" @@ -15,11 +17,11 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_ const uint8_t kArrayBhikshaVersion = 0; // TODO: put this in binary file header instead when I change the binary file format again. -void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { - uint8_t version; - uint8_t configured_bits; - util::ReadOrThrow(fd, &version, 1); - util::ReadOrThrow(fd, &configured_bits, 1); +void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { + uint8_t buffer[2]; + file.ReadForConfig(buffer, 2, offset); + uint8_t version = buffer[0]; + uint8_t configured_bits = buffer[1]; if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); config.pointer_bhiksha_bits = configured_bits; } @@ -87,9 +89,6 @@ void ArrayBhiksha::FinishedLoading(const Config &config) { *(head_write++) = config.pointer_bhiksha_bits; } -void ArrayBhiksha::LoadedBinary() { -} - } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 8ff88654..350571a6 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -24,6 +24,7 @@ namespace lm { namespace ngram { struct Config; +class BinaryFormat; namespace trie { @@ -31,7 +32,7 @@ class DontBhiksha { public: static const ModelType kModelTypeAdd = static_cast<ModelType>(0); - static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {} static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } @@ -53,8 +54,6 @@ class DontBhiksha { void FinishedLoading(const Config &/*config*/) {} - void LoadedBinary() {} - uint8_t InlineBits() const { return next_.bits; } private: @@ -65,7 +64,7 @@ class ArrayBhiksha { public: static const ModelType kModelTypeAdd = kArrayAdd; - static void UpdateConfigFromBinary(int fd, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); @@ -93,8 +92,6 @@ class ArrayBhiksha { void FinishedLoading(const Config &config); - void LoadedBinary(); - uint8_t InlineBits() const { return next_inline_.bits; } private: diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 39c4a9b6..9c744b13 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -8,11 +8,15 @@ #include <cstring> #include <limits> #include <string> +#include <cstdlib> #include <stdint.h> namespace lm { namespace ngram { + +const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; + namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; @@ -57,8 +61,6 @@ struct Sanity { } }; -const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; - std::size_t TotalHeaderSize(unsigned char order) { return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } @@ -80,83 +82,6 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace -uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { - if (config.write_mmap) { - std::size_t total = TotalHeaderSize(order) + memory_size; - backing.file.reset(util::CreateOrThrow(config.write_mmap)); - if (config.write_method == Config::WRITE_MMAP) { - backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); - } else { - util::ResizeOrThrow(backing.file.get(), 0); - util::MapAnonymous(total, backing.vocab); - } - strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); - return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); - } else { - util::MapAnonymous(memory_size, backing.vocab); - return reinterpret_cast<uint8_t*>(backing.vocab.get()); - } -} - -uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { - std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; - if (config.write_mmap) { - // Grow the file to accomodate the search, using zeros. - try { - util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); - } catch (util::ErrnoException &e) { - e << " for file " << config.write_mmap; - throw e; - } - - if (config.write_method == Config::WRITE_AFTER) { - util::MapAnonymous(memory_size, backing.search); - return reinterpret_cast<uint8_t*>(backing.search.get()); - } - // mmap it now. - // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. - std::size_t page_size = util::SizePage(); - std::size_t alignment_cruft = adjusted_vocab % page_size; - backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); - return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft; - } else { - util::MapAnonymous(memory_size, backing.search); - return reinterpret_cast<uint8_t*>(backing.search.get()); - } -} - -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { - if (!config.write_mmap) return; - switch (config.write_method) { - case Config::WRITE_MMAP: - util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); - util::SyncOrThrow(backing.search.get(), backing.search.size()); - break; - case Config::WRITE_AFTER: - util::SeekOrThrow(backing.file.get(), 0); - util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size()); - util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); - util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); - util::FSyncOrThrow(backing.file.get()); - break; - } - // header and vocab share the same mmap. The header is written here because we know the counts. - Parameters params = Parameters(); - params.counts = counts; - params.fixed.order = counts.size(); - params.fixed.probing_multiplier = config.probing_multiplier; - params.fixed.model_type = model_type; - params.fixed.has_vocabulary = config.include_vocab; - params.fixed.search_version = search_version; - WriteHeader(backing.vocab.get(), params); - if (config.write_method == Config::WRITE_AFTER) { - util::SeekOrThrow(backing.file.get(), 0); - util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size())); - } -} - -namespace detail { - bool IsBinaryFormat(int fd) { const uint64_t size = util::SizeFile(fd); if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; @@ -169,21 +94,21 @@ bool IsBinaryFormat(int fd) { } Sanity reference_header = Sanity(); reference_header.SetToReference(); - if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true; - if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) { + if (!std::memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true; + if (!std::memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) { UTIL_THROW(FormatLoadException, "This binary file did not finish building"); } - if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) { + if (!std::memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) { char *end_ptr; const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion); - long int version = strtol(begin_version, &end_ptr, 10); + long int version = std::strtol(begin_version, &end_ptr, 10); if ((end_ptr != begin_version) && version != kMagicVersion) { UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary"); } OldSanity old_sanity = OldSanity(); old_sanity.SetToReference(); - UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable."); + UTIL_THROW_IF(!std::memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable."); UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture"); } return false; @@ -208,44 +133,164 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version); } -void SeekPastHeader(int fd, const Parameters ¶ms) { - util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); +const std::size_t kInvalidSize = static_cast<std::size_t>(-1); + +BinaryFormat::BinaryFormat(const Config &config) + : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method), + header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {} + +void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms) { + file_.reset(fd); + write_mmap_ = NULL; // Ignore write requests; this is already in binary format. + ReadHeader(fd, params); + MatchCheck(model_type, search_version, params); + header_size_ = TotalHeaderSize(params.counts.size()); +} + +void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { + assert(header_size_ != kInvalidSize); + util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_); } -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { - const uint64_t file_size = util::SizeFile(backing.file.get()); +void *BinaryFormat::LoadBinary(std::size_t size) { + assert(header_size_ != kInvalidSize); + const uint64_t file_size = util::SizeFile(file_.get()); // The header is smaller than a page, so we have to map the whole header as well. - std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); - if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) - UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); + uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size); + UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); - util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search); + util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_); - if (config.enumerate_vocab && !params.fixed.has_vocabulary) - UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); + vocab_string_offset_ = total_map; + return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_; +} + +void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) { + vocab_size_ = memory_size; + if (!write_mmap_) { + header_size_ = 0; + util::MapAnonymous(memory_size, memory_vocab_); + return reinterpret_cast<uint8_t*>(memory_vocab_.get()); + } + header_size_ = TotalHeaderSize(order); + std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size)); + file_.reset(util::CreateOrThrow(write_mmap_)); + // some gccs complain about uninitialized variables even though all enum values are covered. + void *vocab_base = NULL; + switch (write_method_) { + case Config::WRITE_MMAP: + mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); + vocab_base = mapping_.get(); + break; + case Config::WRITE_AFTER: + util::ResizeOrThrow(file_.get(), 0); + util::MapAnonymous(total, memory_vocab_); + vocab_base = memory_vocab_.get(); + break; + } + strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_); + return reinterpret_cast<uint8_t*>(vocab_base) + header_size_; +} - // Seek to vocabulary words - util::SeekOrThrow(backing.file.get(), total_map); - return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size()); +void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) { + assert(vocab_size_ != kInvalidSize); + vocab_pad_ = vocab_pad; + std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size; + vocab_string_offset_ = new_size; + if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) { + util::MapAnonymous(memory_size, memory_search_); + assert(header_size_ == 0 || write_mmap_); + vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_; + return reinterpret_cast<uint8_t*>(memory_search_.get()); + } + + assert(write_method_ == Config::WRITE_MMAP); + // Also known as total size without vocab words. + // Grow the file to accomodate the search, using zeros. + // According to man mmap, behavior is undefined when the file is resized + // underneath a mmap that is not a multiple of the page size. So to be + // safe, we'll unmap it and map it again. + mapping_.reset(); + util::ResizeOrThrow(file_.get(), new_size); + void *ret; + MapFile(vocab_base, ret); + return ret; } -void ComplainAboutARPA(const Config &config, ModelType model_type) { - if (config.write_mmap || !config.messages) return; - if (config.arpa_complain == Config::ALL) { - *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; - } else if (config.arpa_complain == Config::EXPENSIVE && - (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { - *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; +void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) { + // Checking Config's include_vocab is the responsibility of the caller. + assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize); + if (!write_mmap_) { + // Unchanged base. + vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()); + search_base = reinterpret_cast<uint8_t*>(memory_search_.get()); + return; + } + if (write_method_ == Config::WRITE_MMAP) { + mapping_.reset(); + } + util::SeekOrThrow(file_.get(), VocabStringReadingOffset()); + util::WriteOrThrow(file_.get(), &buffer[0], buffer.size()); + if (write_method_ == Config::WRITE_MMAP) { + MapFile(vocab_base, search_base); + } else { + vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_; + search_base = reinterpret_cast<uint8_t*>(memory_search_.get()); + } +} + +void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) { + if (!write_mmap_) return; + switch (write_method_) { + case Config::WRITE_MMAP: + util::SyncOrThrow(mapping_.get(), mapping_.size()); + break; + case Config::WRITE_AFTER: + util::SeekOrThrow(file_.get(), 0); + util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size()); + util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_); + util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size()); + util::FSyncOrThrow(file_.get()); + break; + } + // header and vocab share the same mmap. + Parameters params = Parameters(); + memset(¶ms, 0, sizeof(Parameters)); + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + params.fixed.search_version = search_version; + switch (write_method_) { + case Config::WRITE_MMAP: + WriteHeader(mapping_.get(), params); + util::SyncOrThrow(mapping_.get(), mapping_.size()); + break; + case Config::WRITE_AFTER: + { + std::vector<uint8_t> buffer(TotalHeaderSize(counts.size())); + WriteHeader(&buffer[0], params); + util::SeekOrThrow(file_.get(), 0); + util::WriteOrThrow(file_.get(), &buffer[0], buffer.size()); + } + break; } } -} // namespace detail +void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) { + mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED); + vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_; + search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_; +} bool RecognizeBinary(const char *file, ModelType &recognized) { util::scoped_fd fd(util::OpenReadOrThrow(file)); - if (!detail::IsBinaryFormat(fd.get())) return false; + if (!IsBinaryFormat(fd.get())) { + return false; + } Parameters params; - detail::ReadHeader(fd.get(), params); + ReadHeader(fd.get(), params); recognized = params.fixed.model_type; return true; } diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index bf699d5f..f33f88d7 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -17,6 +17,8 @@ namespace lm { namespace ngram { +extern const char *kModelNames[6]; + /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in * this header designed for use by decoder authors. @@ -42,67 +44,63 @@ struct Parameters { std::vector<uint64_t> counts; }; -struct Backing { - // File behind memory, if any. - util::scoped_fd file; - // Vocabulary lookup table. Not to be confused with the vocab words themselves. - util::scoped_memory vocab; - // Raw block of memory backing the language model data structures - util::scoped_memory search; -}; - -// Create just enough of a binary file to write vocabulary to it. -uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); -// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. -uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing); - -// Write header to binary file. This is done last to prevent incomplete files -// from loading. -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing); +class BinaryFormat { + public: + explicit BinaryFormat(const Config &config); + + // Reading a binary file: + // Takes ownership of fd + void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms); + // Used to read parts of the file to update the config object before figuring out full size. + void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const; + // Actually load the binary file and return a pointer to the beginning of the search area. + void *LoadBinary(std::size_t size); + + uint64_t VocabStringReadingOffset() const { + assert(vocab_string_offset_ != kInvalidOffset); + return vocab_string_offset_; + } -namespace detail { + // Writing a binary file or initializing in RAM from ARPA: + // Size for vocabulary. + void *SetupJustVocab(std::size_t memory_size, uint8_t order); + // Warning: can change the vocaulary base pointer. + void *GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base); + // Warning: can change vocabulary and search base addresses. + void WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base); + // Write the header at the beginning of the file. + void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts); + + private: + void MapFile(void *&vocab_base, void *&search_base); + + // Copied from configuration. + const Config::WriteMethod write_method_; + const char *write_mmap_; + util::LoadMethod load_method_; + + // File behind memory, if any. + util::scoped_fd file_; + + // If there is a file involved, a single mapping. + util::scoped_memory mapping_; + + // If the data is only in memory, separately allocate each because the trie + // knows vocab's size before it knows search's size (because SRILM might + // have pruned). + util::scoped_memory memory_vocab_, memory_search_; + + // Memory ranges. Note that these may not be contiguous and may not all + // exist. + std::size_t header_size_, vocab_size_, vocab_pad_; + // aka end of search. + uint64_t vocab_string_offset_; + + static const uint64_t kInvalidOffset = (uint64_t)-1; +}; bool IsBinaryFormat(int fd); -void ReadHeader(int fd, Parameters ¶ms); - -void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms); - -void SeekPastHeader(int fd, const Parameters ¶ms); - -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing); - -void ComplainAboutARPA(const Config &config, ModelType model_type); - -} // namespace detail - -template <class To> void LoadLM(const char *file, const Config &config, To &to) { - Backing &backing = to.MutableBacking(); - backing.file.reset(util::OpenReadOrThrow(file)); - - try { - if (detail::IsBinaryFormat(backing.file.get())) { - Parameters params; - detail::ReadHeader(backing.file.get(), params); - detail::MatchCheck(To::kModelType, To::kVersion, params); - // Replace the run-time configured probing_multiplier with the one in the file. - Config new_config(config); - new_config.probing_multiplier = params.fixed.probing_multiplier; - detail::SeekPastHeader(backing.file.get(), params); - To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); - uint64_t memory_size = To::Size(params.counts, new_config); - uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); - to.InitializeFromBinary(start, params, new_config, backing.file.get()); - } else { - detail::ComplainAboutARPA(config, To::kModelType); - to.InitializeFromARPA(file, config); - } - } catch (util::Exception &e) { - e << " File: " << file; - throw; - } -} - } // namespace ngram } // namespace lm #endif // LM_BINARY_FORMAT__ diff --git a/klm/lm/build_binary_main.cc b/klm/lm/build_binary_main.cc index ab2c0c32..15b421e9 100644 --- a/klm/lm/build_binary_main.cc +++ b/klm/lm/build_binary_main.cc @@ -52,6 +52,7 @@ void Usage(const char *name, const char *default_mem) { "-a compresses pointers using an array of offsets. The parameter is the\n" " maximum number of bits encoded by the array. Memory is minimized subject\n" " to the maximum, so pick 255 to minimize memory.\n\n" +"-h print this help message.\n\n" "Get a memory estimate by passing an ARPA file without an output file name.\n"; exit(1); } @@ -104,12 +105,15 @@ int main(int argc, char *argv[]) { const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G"; + if (argc == 2 && !strcmp(argv[1], "--help")) + Usage(argv[0], default_mem); + try { bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false; lm::ngram::Config config; config.building_memory = util::ParseSize(default_mem); int opt; - while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) { + while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:h")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -161,6 +165,7 @@ int main(int argc, char *argv[]) { ParseFileList(optarg, config.rest_lower_files); config.rest_function = Config::REST_LOWER; break; + case 'h': // help default: Usage(argv[0], default_mem); } @@ -186,6 +191,7 @@ int main(int argc, char *argv[]) { config.write_mmap = argv[optind + 2]; } else { Usage(argv[0], default_mem); + return 1; } if (!strcmp(model_type, "probing")) { if (!set_write_method) config.write_method = Config::WRITE_AFTER; diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc index aea93ad1..ccc06efc 100644 --- a/klm/lm/builder/corpus_count.cc +++ b/klm/lm/builder/corpus_count.cc @@ -238,12 +238,17 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) { const WordIndex end_sentence = vocab.Lookup("</s>"); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; - StringPiece delimiters("\0\t\r ", 4); + bool delimiters[256]; + memset(delimiters, 0, sizeof(delimiters)); + const char kDelimiterSet[] = "\0\t\n\r "; + for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) { + delimiters[static_cast<unsigned char>(*i)] = true; + } try { while(true) { StringPiece line(from_.ReadLine()); writer.StartSentence(); - for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) { + for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) { WordIndex word = vocab.Lookup(*w); UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future."); writer.Append(word); diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc index c87abdb8..2563deed 100644 --- a/klm/lm/builder/lmplz_main.cc +++ b/klm/lm/builder/lmplz_main.cc @@ -33,7 +33,10 @@ int main(int argc, char *argv[]) { po::options_description options("Language model building options"); lm::builder::PipelineConfig pipeline; + std::string text, arpa; + options.add_options() + ("help", po::bool_switch(), "Show this help message") ("order,o", po::value<std::size_t>(&pipeline.order) #if BOOST_VERSION >= 104200 ->required() @@ -47,8 +50,13 @@ int main(int argc, char *argv[]) { ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table") ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)") ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") - ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc."); - if (argc == 1) { + ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.") + ("text", po::value<std::string>(&text), "Read text from a file instead of stdin") + ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout"); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, options), vm); + + if (argc == 1 || vm["help"].as<bool>()) { std::cerr << "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n" "Please cite:\n" @@ -66,12 +74,17 @@ int main(int argc, char *argv[]) { "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n" "Memory sizes are specified like GNU sort: a number followed by a unit character.\n" "Valid units are \% for percentage of memory (supported platforms only) and (in\n" - "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n"; + "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n"; + uint64_t mem = util::GuessPhysicalMemory(); + if (mem) { + std::cerr << "This machine has " << mem << " bytes of memory.\n\n"; + } else { + std::cerr << "Unable to determine the amount of memory on this machine.\n\n"; + } std::cerr << options << std::endl; return 1; } - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, options), vm); + po::notify(vm); // required() appeared in Boost 1.42.0. @@ -92,9 +105,17 @@ int main(int argc, char *argv[]) { initial.adder_out.block_count = 2; pipeline.read_backoffs = initial.adder_out; + util::scoped_fd in(0), out(1); + if (vm.count("text")) { + in.reset(util::OpenReadOrThrow(text.c_str())); + } + if (vm.count("arpa")) { + out.reset(util::CreateOrThrow(arpa.c_str())); + } + // Read from stdin try { - lm::builder::Pipeline(pipeline, 0, 1); + lm::builder::Pipeline(pipeline, in.release(), out.release()); } catch (const util::MallocException &e) { std::cerr << e.what() << std::endl; std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl; diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc index b89ea6ba..44a2313c 100644 --- a/klm/lm/builder/pipeline.cc +++ b/klm/lm/builder/pipeline.cc @@ -226,6 +226,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); chain.Wait(true); + std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl; std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl; master.InitForAdjust(sorter, type_count); } diff --git a/klm/lm/facade.hh b/klm/lm/facade.hh index 8b186017..de1551f1 100644 --- a/klm/lm/facade.hh +++ b/klm/lm/facade.hh @@ -16,19 +16,28 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ typedef StateT State; typedef VocabularyT Vocabulary; - // Default Score function calls FullScore. Model can override this. - float Score(const State &in_state, const WordIndex new_word, State &out_state) const { - return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob; - } - /* Translate from void* to State */ - FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const { + FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const { return static_cast<const Child*>(this)->FullScore( *reinterpret_cast<const State*>(in_state), new_word, *reinterpret_cast<State*>(out_state)); } - float Score(const void *in_state, const WordIndex new_word, void *out_state) const { + + FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const { + return static_cast<const Child*>(this)->FullScoreForgotState( + context_rbegin, + context_rend, + new_word, + *reinterpret_cast<State*>(out_state)); + } + + // Default Score function calls FullScore. Model can override this. + float Score(const State &in_state, const WordIndex new_word, State &out_state) const { + return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob; + } + + float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const { return static_cast<const Child*>(this)->Score( *reinterpret_cast<const State*>(in_state), new_word, diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh index 5b31620b..602b5b31 100644 --- a/klm/lm/filter/arpa_io.hh +++ b/klm/lm/filter/arpa_io.hh @@ -14,7 +14,6 @@ #include <string> #include <vector> -#include <err.h> #include <string.h> #include <stdint.h> diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh index 97c0fa25..d992026f 100644 --- a/klm/lm/filter/count_io.hh +++ b/klm/lm/filter/count_io.hh @@ -5,20 +5,18 @@ #include <iostream> #include <string> -#include <err.h> - +#include "util/fake_ofstream.hh" +#include "util/file.hh" #include "util/file_piece.hh" namespace lm { class CountOutput : boost::noncopyable { public: - explicit CountOutput(const char *name) : file_(name, std::ios::out) {} + explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {} void AddNGram(const StringPiece &line) { - if (!(file_ << line << '\n')) { - err(3, "Writing counts file failed"); - } + file_ << line << '\n'; } template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { @@ -30,7 +28,7 @@ class CountOutput : boost::noncopyable { } private: - std::fstream file_; + util::FakeOFStream file_; }; class CountBatch { diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc index 1736bc40..82fdc1ef 100644 --- a/klm/lm/filter/filter_main.cc +++ b/klm/lm/filter/filter_main.cc @@ -6,6 +6,7 @@ #endif #include "lm/filter/vocab.hh" #include "lm/filter/wrapper.hh" +#include "util/exception.hh" #include "util/file_piece.hh" #include <boost/ptr_container/ptr_vector.hpp> @@ -157,92 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr } // namespace lm int main(int argc, char *argv[]) { - if (argc < 4) { - lm::DisplayHelp(argv[0]); - return 1; - } + try { + if (argc < 4) { + lm::DisplayHelp(argv[0]); + return 1; + } - // I used to have boost::program_options, but some users didn't want to compile boost. - lm::Config config; - config.mode = lm::MODE_UNSET; - for (int i = 1; i < argc - 2; ++i) { - const char *str = argv[i]; - if (!std::strcmp(str, "copy")) { - config.mode = lm::MODE_COPY; - } else if (!std::strcmp(str, "single")) { - config.mode = lm::MODE_SINGLE; - } else if (!std::strcmp(str, "multiple")) { - config.mode = lm::MODE_MULTIPLE; - } else if (!std::strcmp(str, "union")) { - config.mode = lm::MODE_UNION; - } else if (!std::strcmp(str, "phrase")) { - config.phrase = true; - } else if (!std::strcmp(str, "context")) { - config.context = true; - } else if (!std::strcmp(str, "arpa")) { - config.format = lm::FORMAT_ARPA; - } else if (!std::strcmp(str, "raw")) { - config.format = lm::FORMAT_COUNT; + // I used to have boost::program_options, but some users didn't want to compile boost. + lm::Config config; + config.mode = lm::MODE_UNSET; + for (int i = 1; i < argc - 2; ++i) { + const char *str = argv[i]; + if (!std::strcmp(str, "copy")) { + config.mode = lm::MODE_COPY; + } else if (!std::strcmp(str, "single")) { + config.mode = lm::MODE_SINGLE; + } else if (!std::strcmp(str, "multiple")) { + config.mode = lm::MODE_MULTIPLE; + } else if (!std::strcmp(str, "union")) { + config.mode = lm::MODE_UNION; + } else if (!std::strcmp(str, "phrase")) { + config.phrase = true; + } else if (!std::strcmp(str, "context")) { + config.context = true; + } else if (!std::strcmp(str, "arpa")) { + config.format = lm::FORMAT_ARPA; + } else if (!std::strcmp(str, "raw")) { + config.format = lm::FORMAT_COUNT; #ifndef NTHREAD - } else if (!std::strncmp(str, "threads:", 8)) { - config.threads = boost::lexical_cast<size_t>(str + 8); - if (!config.threads) { - std::cerr << "Specify at least one thread." << std::endl; + } else if (!std::strncmp(str, "threads:", 8)) { + config.threads = boost::lexical_cast<size_t>(str + 8); + if (!config.threads) { + std::cerr << "Specify at least one thread." << std::endl; + return 1; + } + } else if (!std::strncmp(str, "batch_size:", 11)) { + config.batch_size = boost::lexical_cast<size_t>(str + 11); + if (config.batch_size < 5000) { + std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; + if (!config.batch_size) return 1; + } +#endif + } else { + lm::DisplayHelp(argv[0]); return 1; } - } else if (!std::strncmp(str, "batch_size:", 11)) { - config.batch_size = boost::lexical_cast<size_t>(str + 11); - if (config.batch_size < 5000) { - std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; - if (!config.batch_size) return 1; - } -#endif - } else { + } + + if (config.mode == lm::MODE_UNSET) { lm::DisplayHelp(argv[0]); return 1; } - } - - if (config.mode == lm::MODE_UNSET) { - lm::DisplayHelp(argv[0]); - return 1; - } - if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { - std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; - return 1; - } + if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { + std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; + return 1; + } - bool cmd_is_model = true; - const char *cmd_input = argv[argc - 2]; - if (!strncmp(cmd_input, "vocab:", 6)) { - cmd_is_model = false; - cmd_input += 6; - } else if (!strncmp(cmd_input, "model:", 6)) { - cmd_input += 6; - } else if (strchr(cmd_input, ':')) { - errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); - } else { - std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; - } - std::ifstream cmd_file; - std::istream *vocab; - if (cmd_is_model) { - vocab = &std::cin; - } else { - cmd_file.open(cmd_input, std::ios::in); - if (!cmd_file) { - err(2, "Could not open input file %s", cmd_input); + bool cmd_is_model = true; + const char *cmd_input = argv[argc - 2]; + if (!strncmp(cmd_input, "vocab:", 6)) { + cmd_is_model = false; + cmd_input += 6; + } else if (!strncmp(cmd_input, "model:", 6)) { + cmd_input += 6; + } else if (strchr(cmd_input, ':')) { + std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl; + return 1; + } else { + std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; + } + std::ifstream cmd_file; + std::istream *vocab; + if (cmd_is_model) { + vocab = &std::cin; + } else { + cmd_file.open(cmd_input, std::ios::in); + UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input); + vocab = &cmd_file; } - vocab = &cmd_file; - } - util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); + util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); - if (config.format == lm::FORMAT_ARPA) { - lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); - } else if (config.format == lm::FORMAT_COUNT) { - lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + if (config.format == lm::FORMAT_ARPA) { + lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); + } else if (config.format == lm::FORMAT_COUNT) { + lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + } + return 0; + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; } - return 0; } diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh index 7f945b0d..7d8c28db 100644 --- a/klm/lm/filter/format.hh +++ b/klm/lm/filter/format.hh @@ -1,5 +1,5 @@ #ifndef LM_FILTER_FORMAT_H__ -#define LM_FITLER_FORMAT_H__ +#define LM_FILTER_FORMAT_H__ #include "lm/filter/arpa_io.hh" #include "lm/filter/count_io.hh" diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc index 1bef2a3f..e2946b14 100644 --- a/klm/lm/filter/phrase.cc +++ b/klm/lm/filter/phrase.cc @@ -48,21 +48,21 @@ unsigned int ReadMultiple(std::istream &in, Substrings &out) { return sentence_id + sentence_content; } -namespace detail { const StringPiece kEndSentence("</s>"); } - namespace { - typedef unsigned int Sentence; typedef std::vector<Sentence> Sentences; +} // namespace -class Vertex; +namespace detail { + +const StringPiece kEndSentence("</s>"); class Arc { public: Arc() {} // For arcs from one vertex to another. - void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) { + void SetPhrase(detail::Vertex &from, detail::Vertex &to, const Sentences &intersect) { Set(to, intersect); from_ = &from; } @@ -71,7 +71,7 @@ class Arc { * aligned). These have no from_ vertex; it implictly matches every * sentence. This also handles when the n-gram is a substring of a phrase. */ - void SetRight(Vertex &to, const Sentences &complete) { + void SetRight(detail::Vertex &to, const Sentences &complete) { Set(to, complete); from_ = NULL; } @@ -97,11 +97,11 @@ class Arc { void LowerBound(const Sentence to); private: - void Set(Vertex &to, const Sentences &sentences); + void Set(detail::Vertex &to, const Sentences &sentences); const Sentence *current_; const Sentence *last_; - Vertex *from_; + detail::Vertex *from_; }; struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> { @@ -183,7 +183,13 @@ void Vertex::LowerBound(const Sentence to) { } } -void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) { +} // namespace detail + +namespace { + +void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, detail::Vertex *const vertices, detail::Arc *free_arc) { + using detail::Vertex; + using detail::Arc; assert(!hashes.empty()); const Hash *const first_word = &*hashes.begin(); @@ -231,17 +237,29 @@ void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Verte namespace detail { -} // namespace detail +// Here instead of header due to forward declaration. +ConditionCommon::ConditionCommon(const Substrings &substrings) : substrings_(substrings) {} -bool Union::Evaluate() { +// Rest of the variables are temporaries anyway +ConditionCommon::ConditionCommon(const ConditionCommon &from) : substrings_(from.substrings_) {} + +ConditionCommon::~ConditionCommon() {} + +detail::Vertex &ConditionCommon::MakeGraph() { assert(!hashes_.empty()); - // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. - Vertex vertices[hashes_.size()]; + vertices_.clear(); + vertices_.resize(hashes_.size()); + arcs_.clear(); // One for every substring. - Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; - BuildGraph(substrings_, hashes_, vertices, arcs); - Vertex &last_vertex = vertices[hashes_.size() - 1]; + arcs_.resize(((hashes_.size() + 1) * hashes_.size()) / 2); + BuildGraph(substrings_, hashes_, &*vertices_.begin(), &*arcs_.begin()); + return vertices_[hashes_.size() - 1]; +} + +} // namespace detail +bool Union::Evaluate() { + detail::Vertex &last_vertex = MakeGraph(); unsigned int lower = 0; while (true) { last_vertex.LowerBound(lower); @@ -252,14 +270,7 @@ bool Union::Evaluate() { } template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) { - assert(!hashes_.empty()); - // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. - Vertex vertices[hashes_.size()]; - // One for every substring. - Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; - BuildGraph(substrings_, hashes_, vertices, arcs); - Vertex &last_vertex = vertices[hashes_.size() - 1]; - + detail::Vertex &last_vertex = MakeGraph(); unsigned int lower = 0; while (true) { last_vertex.LowerBound(lower); diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh index b4edff41..e8e85835 100644 --- a/klm/lm/filter/phrase.hh +++ b/klm/lm/filter/phrase.hh @@ -103,11 +103,33 @@ template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std:: } } +class Vertex; +class Arc; + +class ConditionCommon { + protected: + ConditionCommon(const Substrings &substrings); + ConditionCommon(const ConditionCommon &from); + + ~ConditionCommon(); + + detail::Vertex &MakeGraph(); + + // Temporaries in PassNGram and Evaluate to avoid reallocation. + std::vector<Hash> hashes_; + + private: + std::vector<detail::Vertex> vertices_; + std::vector<detail::Arc> arcs_; + + const Substrings &substrings_; +}; + } // namespace detail -class Union { +class Union : public detail::ConditionCommon { public: - explicit Union(const Substrings &substrings) : substrings_(substrings) {} + explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {} template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) { detail::MakeHashes(begin, end, hashes_); @@ -116,23 +138,19 @@ class Union { private: bool Evaluate(); - - std::vector<Hash> hashes_; - - const Substrings &substrings_; }; -class Multiple { +class Multiple : public detail::ConditionCommon { public: - explicit Multiple(const Substrings &substrings) : substrings_(substrings) {} + explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {} template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { detail::MakeHashes(begin, end, hashes_); if (hashes_.empty()) { output.AddNGram(line); - return; + } else { + Evaluate(line, output); } - Evaluate(line, output); } template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { @@ -143,10 +161,6 @@ class Multiple { private: template <class Output> void Evaluate(const StringPiece &line, Output &output); - - std::vector<Hash> hashes_; - - const Substrings &substrings_; }; } // namespace phrase diff --git a/klm/lm/filter/phrase_table_vocab_main.cc b/klm/lm/filter/phrase_table_vocab_main.cc new file mode 100644 index 00000000..e0f47d89 --- /dev/null +++ b/klm/lm/filter/phrase_table_vocab_main.cc @@ -0,0 +1,165 @@ +#include "util/fake_ofstream.hh" +#include "util/file_piece.hh" +#include "util/murmur_hash.hh" +#include "util/pool.hh" +#include "util/string_piece.hh" +#include "util/string_piece_hash.hh" +#include "util/tokenize_piece.hh" + +#include <boost/unordered_map.hpp> +#include <boost/unordered_set.hpp> + +#include <cstddef> +#include <vector> + +namespace { + +struct MutablePiece { + mutable StringPiece behind; + bool operator==(const MutablePiece &other) const { + return behind == other.behind; + } +}; + +std::size_t hash_value(const MutablePiece &m) { + return hash_value(m.behind); +} + +class InternString { + public: + const char *Add(StringPiece str) { + MutablePiece mut; + mut.behind = str; + std::pair<boost::unordered_set<MutablePiece>::iterator, bool> res(strs_.insert(mut)); + if (res.second) { + void *mem = backing_.Allocate(str.size() + 1); + memcpy(mem, str.data(), str.size()); + static_cast<char*>(mem)[str.size()] = 0; + res.first->behind = StringPiece(static_cast<char*>(mem), str.size()); + } + return res.first->behind.data(); + } + + private: + util::Pool backing_; + boost::unordered_set<MutablePiece> strs_; +}; + +class TargetWords { + public: + void Introduce(StringPiece source) { + vocab_.resize(vocab_.size() + 1); + std::vector<unsigned int> temp(1, vocab_.size() - 1); + Add(temp, source); + } + + void Add(const std::vector<unsigned int> &sentences, StringPiece target) { + if (sentences.empty()) return; + interns_.clear(); + for (util::TokenIter<util::SingleCharacter, true> i(target, ' '); i; ++i) { + interns_.push_back(intern_.Add(*i)); + } + for (std::vector<unsigned int>::const_iterator i(sentences.begin()); i != sentences.end(); ++i) { + boost::unordered_set<const char *> &vocab = vocab_[*i]; + for (std::vector<const char *>::const_iterator j = interns_.begin(); j != interns_.end(); ++j) { + vocab.insert(*j); + } + } + } + + void Print() const { + util::FakeOFStream out(1); + for (std::vector<boost::unordered_set<const char *> >::const_iterator i = vocab_.begin(); i != vocab_.end(); ++i) { + for (boost::unordered_set<const char *>::const_iterator j = i->begin(); j != i->end(); ++j) { + out << *j << ' '; + } + out << '\n'; + } + } + + private: + InternString intern_; + + std::vector<boost::unordered_set<const char *> > vocab_; + + // Temporary in Add. + std::vector<const char *> interns_; +}; + +class Input { + public: + explicit Input(std::size_t max_length) + : max_length_(max_length), sentence_id_(0), empty_() {} + + void AddSentence(StringPiece sentence, TargetWords &targets) { + canonical_.clear(); + starts_.clear(); + starts_.push_back(0); + for (util::TokenIter<util::AnyCharacter, true> i(sentence, StringPiece("\0 \t", 3)); i; ++i) { + canonical_.append(i->data(), i->size()); + canonical_ += ' '; + starts_.push_back(canonical_.size()); + } + targets.Introduce(canonical_); + for (std::size_t i = 0; i < starts_.size() - 1; ++i) { + std::size_t subtract = starts_[i]; + const char *start = &canonical_[subtract]; + for (std::size_t j = i + 1; j < std::min(starts_.size(), i + max_length_ + 1); ++j) { + map_[util::MurmurHash64A(start, &canonical_[starts_[j]] - start - 1)].push_back(sentence_id_); + } + } + ++sentence_id_; + } + + // Assumes single space-delimited phrase with no space at the beginning or end. + const std::vector<unsigned int> &Matches(StringPiece phrase) const { + Map::const_iterator i = map_.find(util::MurmurHash64A(phrase.data(), phrase.size())); + return i == map_.end() ? empty_ : i->second; + } + + private: + const std::size_t max_length_; + + // hash of phrase is the key, array of sentences is the value. + typedef boost::unordered_map<uint64_t, std::vector<unsigned int> > Map; + Map map_; + + std::size_t sentence_id_; + + // Temporaries in AddSentence. + std::string canonical_; + std::vector<std::size_t> starts_; + + const std::vector<unsigned int> empty_; +}; + +} // namespace + +int main(int argc, char *argv[]) { + if (argc != 2) { + std::cerr << "Expected source text on the command line" << std::endl; + return 1; + } + Input input(7); + TargetWords targets; + try { + util::FilePiece inputs(argv[1], &std::cerr); + while (true) + input.AddSentence(inputs.ReadLine(), targets); + } catch (const util::EndOfFileException &e) {} + + util::FilePiece table(0, NULL, &std::cerr); + StringPiece line; + const StringPiece pipes("|||"); + while (true) { + try { + line = table.ReadLine(); + } catch (const util::EndOfFileException &e) { break; } + util::TokenIter<util::MultiCharacter> it(line, pipes); + StringPiece source(*it); + if (!source.empty() && source[source.size() - 1] == ' ') + source.remove_suffix(1); + targets.Add(input.Matches(source), *++it); + } + targets.Print(); +} diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc index 7ee4e84b..011ab599 100644 --- a/klm/lm/filter/vocab.cc +++ b/klm/lm/filter/vocab.cc @@ -4,7 +4,6 @@ #include <iostream> #include <ctype.h> -#include <err.h> namespace lm { namespace vocab { diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh index 90b07a08..eb657501 100644 --- a/klm/lm/filter/wrapper.hh +++ b/klm/lm/filter/wrapper.hh @@ -39,17 +39,15 @@ template <class FilterT> class ContextFilter { explicit ContextFilter(Filter &backend) : backend_(backend) {} template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { - pieces_.clear(); - // TODO: this copy could be avoided by a lookahead iterator. - std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_)); - backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output); + // Find beginning of string or last space. + const char *last_space; + for (last_space = ngram.data() + ngram.size() - 1; last_space > ngram.data() && *last_space != ' '; --last_space) {} + backend_.AddNGram(StringPiece(ngram.data(), last_space - ngram.data()), line, output); } void Flush() const {} private: - std::vector<StringPiece> pieces_; - Filter backend_; }; diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a26654a6..a5a16bf8 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -34,23 +34,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size); } -template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { - LoadLM(file, config, *this); - - // g++ prints warnings unless these are fully initialized. - State begin_sentence = State(); - begin_sentence.length = 1; - begin_sentence.words[0] = vocab_.BeginSentence(); - typename Search::Node ignored_node; - bool ignored_independent_left; - uint64_t ignored_extend_left; - begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); - State null_context = State(); - null_context.length = 0; - P::Init(begin_sentence, null_context, vocab_, search_.Order()); +namespace { +void ComplainAboutARPA(const Config &config, ModelType model_type) { + if (config.write_mmap || !config.messages) return; + if (config.arpa_complain == Config::ALL) { + *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; + } else if (config.arpa_complain == Config::EXPENSIVE && + (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { + *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; + } } -namespace { void CheckCounts(const std::vector<uint64_t> &counts) { UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE); if (sizeof(uint64_t) > sizeof(std::size_t)) { @@ -59,18 +53,45 @@ void CheckCounts(const std::vector<uint64_t> &counts) { } } } + } // namespace -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { - CheckCounts(params.counts); - SetupMemory(start, params.counts, config); - vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); - search_.LoadedBinary(); +template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) { + util::scoped_fd fd(util::OpenReadOrThrow(file)); + if (IsBinaryFormat(fd.get())) { + Parameters parameters; + int fd_shallow = fd.release(); + backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters); + CheckCounts(parameters.counts); + + Config new_config(init_config); + new_config.probing_multiplier = parameters.fixed.probing_multiplier; + Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config); + UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); + + SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config); + vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset()); + } else { + ComplainAboutARPA(init_config, kModelType); + InitializeFromARPA(fd.release(), file, init_config); + } + + // g++ prints warnings unless these are fully initialized. + State begin_sentence = State(); + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + typename Search::Node ignored_node; + bool ignored_independent_left; + uint64_t ignored_extend_left; + begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); + State null_context = State(); + null_context.length = 0; + P::Init(begin_sentence, null_context, vocab_, search_.Order()); } -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { - // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. - util::FilePiece f(backing_.file.release(), file, config.ProgressMessages()); +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) { + // Backing file is the ARPA. + util::FilePiece f(fd, file, config.ProgressMessages()); try { std::vector<uint64_t> counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. @@ -81,13 +102,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. - vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config); - if (config.write_mmap) { + if (config.write_mmap && config.include_vocab) { WriteWordsWrapper wrap(config.enumerate_vocab); vocab_.ConfigureEnumerate(&wrap, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config)); + void *vocab_rebase, *search_rebase; + backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase); + // Due to writing at the end of file, mmap may have relocated data. So remap. + vocab_.Relocate(vocab_rebase); + search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config); } else { vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -99,18 +124,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.UnknownUnigram().backoff = 0.0; search_.UnknownUnigram().prob = config.unknown_missing_logprob; } - FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_); + backing_.FinishFile(config, kModelType, kVersion, counts); } catch (util::Exception &e) { e << " Byte: " << f.Offset(); throw; } } -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { - util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); - Search::UpdateConfigFromBinary(fd, counts, config); -} - template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) { diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 60f55110..e75da93b 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -67,7 +67,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; /* Get the state for a context. Don't use this if you can avoid it. Use - * BeginSentenceState or EmptyContextState and extend from those. If + * BeginSentenceState or NullContextState and extend from those. If * you're only going to use this state to call FullScore once, use * FullScoreForgotState. * To use this function, make an array of WordIndex containing the context @@ -104,10 +104,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod } private: - friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); - - static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); - FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const; // Score bigrams and above. Do not include backoff. @@ -116,15 +112,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod // Appears after Size in the cc file. void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config); - void InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd); - - void InitializeFromARPA(const char *file, const Config &config); + void InitializeFromARPA(int fd, const char *file, const Config &config); float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const; - Backing &MutableBacking() { return backing_; } - - Backing backing_; + BinaryFormat backing_; VocabularyT vocab_; diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index eb159094..7005b05e 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -360,10 +360,11 @@ BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { LoadingTest<QuantArrayTrieModel>(); } -template <class ModelT> void BinaryTest() { +template <class ModelT> void BinaryTest(Config::WriteMethod write_method) { Config config; config.write_mmap = "test.binary"; config.messages = NULL; + config.write_method = write_method; ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; @@ -406,6 +407,11 @@ template <class ModelT> void BinaryTest() { unlink("test_nounk.binary"); } +template <class ModelT> void BinaryTest() { + BinaryTest<ModelT>(Config::WRITE_MMAP); + BinaryTest<ModelT>(Config::WRITE_AFTER); +} + BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest<ProbingModel>(); } diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh index dfcda170..ec2590f4 100644 --- a/klm/lm/ngram_query.hh +++ b/klm/lm/ngram_query.hh @@ -11,21 +11,25 @@ #include <istream> #include <string> +#include <math.h> + namespace lm { namespace ngram { template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { - std::cerr << "Loading statistics:\n"; - util::PrintUsage(std::cerr); typename Model::State state, out; lm::FullScoreReturn ret; std::string word; + double corpus_total = 0.0; + uint64_t corpus_oov = 0; + uint64_t corpus_tokens = 0; + while (in_stream) { state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); float total = 0.0; bool got = false; - unsigned int oov = 0; + uint64_t oov = 0; while (in_stream >> word) { got = true; lm::WordIndex vocab = model.GetVocabulary().Index(word); @@ -33,6 +37,7 @@ template <class Model> void Query(const Model &model, bool sentence_context, std ret = model.FullScore(state, vocab, out); total += ret.prob; out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; + ++corpus_tokens; state = out; char c; while (true) { @@ -50,12 +55,14 @@ template <class Model> void Query(const Model &model, bool sentence_context, std if (sentence_context) { ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; + ++corpus_tokens; out_stream << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; } out_stream << "Total: " << total << " OOV: " << oov << '\n'; + corpus_total += total; + corpus_oov += oov; } - std::cerr << "After queries:\n"; - util::PrintUsage(std::cerr); + out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl; } template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index b58c3f3f..273ea398 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -38,13 +38,13 @@ const char kSeparatelyQuantizeVersion = 2; } // namespace -void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) { - char version; - util::ReadOrThrow(fd, &version, 1); - util::ReadOrThrow(fd, &config.prob_bits, 1); - util::ReadOrThrow(fd, &config.backoff_bits, 1); +void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { + unsigned char buffer[3]; + file.ReadForConfig(buffer, 3, offset); + char version = buffer[0]; + config.prob_bits = buffer[1]; + config.backoff_bits = buffer[2]; if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); - util::AdvanceOrThrow(fd, -3); } void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 8ce2378a..9d3a2f43 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -18,12 +18,13 @@ namespace lm { namespace ngram { struct Config; +class BinaryFormat; /* Store values directly and don't quantize. */ class DontQuantize { public: static const ModelType kModelTypeAdd = static_cast<ModelType>(0); - static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {} static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -136,7 +137,7 @@ class SeparatelyQuantize { public: static const ModelType kModelTypeAdd = kQuantAdd; - static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); static uint64_t Size(uint8_t order, const Config &config) { uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float); diff --git a/klm/lm/query_main.cc b/klm/lm/query_main.cc index 27d3a1a5..bd4fde62 100644 --- a/klm/lm/query_main.cc +++ b/klm/lm/query_main.cc @@ -1,42 +1,65 @@ #include "lm/ngram_query.hh" +#ifdef WITH_NPLM +#include "lm/wrappers/nplm.hh" +#endif + +#include <stdlib.h> + +void Usage(const char *name) { + std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; + std::cerr << "Usage: " << name << " [-n] lm_file" << std::endl; + std::cerr << "Input is wrapped in <s> and </s> unless -n is passed." << std::endl; + exit(1); +} + int main(int argc, char *argv[]) { - if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) { - std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; - std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl; - std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl; - return 1; + bool sentence_context = true; + const char *file = NULL; + for (char **arg = argv + 1; arg != argv + argc; ++arg) { + if (!strcmp(*arg, "-n")) { + sentence_context = false; + } else if (!strcmp(*arg, "-h") || !strcmp(*arg, "--help") || file) { + Usage(argv[0]); + } else { + file = *arg; + } } + if (!file) Usage(argv[0]); try { - bool sentence_context = (argc == 2); using namespace lm::ngram; ModelType model_type; - if (RecognizeBinary(argv[1], model_type)) { + if (RecognizeBinary(file, model_type)) { switch(model_type) { case PROBING: - Query<lm::ngram::ProbingModel>(argv[1], sentence_context, std::cin, std::cout); + Query<lm::ngram::ProbingModel>(file, sentence_context, std::cin, std::cout); break; case REST_PROBING: - Query<lm::ngram::RestProbingModel>(argv[1], sentence_context, std::cin, std::cout); + Query<lm::ngram::RestProbingModel>(file, sentence_context, std::cin, std::cout); break; case TRIE: - Query<TrieModel>(argv[1], sentence_context, std::cin, std::cout); + Query<TrieModel>(file, sentence_context, std::cin, std::cout); break; case QUANT_TRIE: - Query<QuantTrieModel>(argv[1], sentence_context, std::cin, std::cout); + Query<QuantTrieModel>(file, sentence_context, std::cin, std::cout); break; case ARRAY_TRIE: - Query<ArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout); + Query<ArrayTrieModel>(file, sentence_context, std::cin, std::cout); break; case QUANT_ARRAY_TRIE: - Query<QuantArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout); + Query<QuantArrayTrieModel>(file, sentence_context, std::cin, std::cout); break; default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; abort(); } +#ifdef WITH_NPLM + } else if (lm::np::Model::Recognize(file)) { + lm::np::Model model(file); + Query(model, sentence_context, std::cin, std::cout); +#endif } else { - Query<ProbingModel>(argv[1], sentence_context, std::cin, std::cout); + Query<ProbingModel>(file, sentence_context, std::cin, std::cout); } std::cerr << "Total time including destruction:\n"; util::PrintUsage(std::cerr); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 9ea08798..fb8bbfa2 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -19,7 +19,7 @@ namespace lm { -// 1 for '\t', '\n', and ' '. This is stricter than isspace. +// 1 for '\t', '\n', and ' '. This is stricter than isspace. const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; namespace { @@ -50,7 +50,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { // In general, ARPA files can have arbitrary text before "\data\" // But in KenLM, we require such lines to start with "#", so that // we can do stricter error checking - while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) { + while (IsEntirelyWhiteSpace(line) || starts_with(line, "#")) { line = in.ReadLine(); } @@ -58,7 +58,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) { UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); } - if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) + if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); UTIL_THROW_IF(line.size() >= 4 && StringPiece(line.data(), 4) == "blmt", FormatLoadException, "This looks like an IRSTLM binary file. Did you forget to pass --text yes to compile-lm?"); UTIL_THROW_IF(line == "iARPA", FormatLoadException, "This looks like an IRSTLM iARPA file. You need an ARPA file. Run\n compile-lm --text yes " << in.FileName() << " " << in.FileName() << ".arpa\nfirst."); @@ -66,7 +66,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { } while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); - // So strtol doesn't go off the end of line. + // So strtol doesn't go off the end of line. std::string remaining(line.data() + 6, line.size() - 6); char *end_ptr; unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10); @@ -102,8 +102,8 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) { } void ReadBackoff(util::FilePiece &in, float &backoff) { - // Always make zero negative. - // Negative zero means that no (n+1)-gram has this n-gram as context. + // Always make zero negative. + // Negative zero means that no (n+1)-gram has this n-gram as context. // Therefore the hypothesis state can be shorter. Of course, many n-grams // are context for (n+1)-grams. An algorithm in the data structure will go // back and set the backoff to positive zero in these cases. @@ -150,7 +150,7 @@ void PositiveProbWarn::Warn(float prob) { case THROW_UP: UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. Error"); case COMPLAIN: - std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl; + std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapped to 0 log probability." << std::endl; action_ = SILENT; break; case SILENT: diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 62275d27..354a56b4 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -204,9 +204,10 @@ template <class Build, class Activate, class Store> void ReadNGrams( namespace detail { template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { - std::size_t allocated = Unigram::Size(counts[0]); - unigram_ = Unigram(start, counts[0], allocated); - start += allocated; + unigram_ = Unigram(start, counts[0]); + start += Unigram::Size(counts[0]); + std::size_t allocated; + middle_.clear(); for (unsigned int n = 2; n < counts.size(); ++n) { allocated = Middle::Size(counts[n - 1], config.probing_multiplier); middle_.push_back(Middle(start, allocated)); @@ -218,9 +219,21 @@ template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, return start; } -template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) { - // TODO: fix sorted. - SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config); +/*template <class Value> void HashedSearch<Value>::Relocate(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { + unigram_ = Unigram(start, counts[0]); + start += Unigram::Size(counts[0]); + for (unsigned int n = 2; n < counts.size(); ++n) { + middle[n-2].Relocate(start); + start += Middle::Size(counts[n - 1], config.probing_multiplier) + } + longest_.Relocate(start); +}*/ + +template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) { + void *vocab_rebase; + void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase); + vocab.Relocate(vocab_rebase); + SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config); PositiveProbWarn warn(config.positive_log_probability); Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn); @@ -277,14 +290,6 @@ template <class Value> template <class Build> void HashedSearch<Value>::ApplyBui ReadEnd(f); } -template <class Value> void HashedSearch<Value>::LoadedBinary() { - unigram_.LoadedBinary(); - for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) { - i->LoadedBinary(); - } - longest_.LoadedBinary(); -} - template class HashedSearch<BackoffValue>; template class HashedSearch<RestValue>; diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 9d067bc2..8193262b 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -18,7 +18,7 @@ namespace util { class FilePiece; } namespace lm { namespace ngram { -struct Backing; +class BinaryFormat; class ProbingVocabulary; namespace detail { @@ -72,7 +72,7 @@ template <class Value> class HashedSearch { static const unsigned int kVersion = 0; // TODO: move probing_multiplier here with next binary file format update. - static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {} static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { uint64_t ret = Unigram::Size(counts[0]); @@ -84,9 +84,7 @@ template <class Value> class HashedSearch { uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); - void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing); - - void LoadedBinary(); + void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing); unsigned char Order() const { return middle_.size() + 2; @@ -148,7 +146,7 @@ template <class Value> class HashedSearch { public: Unigram() {} - Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : + Unigram(void *start, uint64_t count) : unigram_(static_cast<typename Value::Weights*>(start)) #ifdef DEBUG , count_(count) @@ -168,8 +166,6 @@ template <class Value> class HashedSearch { typename Value::Weights &Unknown() { return unigram_[0]; } - void LoadedBinary() {} - // For building. typename Value::Weights *Raw() { return unigram_; } diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1b0d9b26..4a88194e 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -253,11 +253,6 @@ class FindBlanks { ++counts_.back(); } - // Unigrams wrote one past. - void Cleanup() { - --counts_[0]; - } - const std::vector<uint64_t> &Counts() const { return counts_; } @@ -310,8 +305,6 @@ template <class Quant, class Bhiksha> class WriteEntries { typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob); } - void Cleanup() {} - private: RecordReader *contexts_; const Quant &quant_; @@ -385,14 +378,14 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con util::ErsatzProgress progress(unigram_count + 1, progress_out, message); WordIndex unigram = 0; std::priority_queue<Gram> grams; - grams.push(Gram(&unigram, 1)); + if (unigram_count) grams.push(Gram(&unigram, 1)); for (unsigned char i = 2; i <= total_order; ++i) { if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i)); } BlankManager<Doing> blank(total_order, doing); - while (true) { + while (!grams.empty()) { Gram top = grams.top(); grams.pop(); unsigned char order = top.end - top.begin; @@ -400,8 +393,7 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con blank.Visit(&unigram, 1, doing.UnigramProb(unigram)); doing.Unigram(unigram); progress.Set(unigram); - if (++unigram == unigram_count + 1) break; - grams.push(top); + if (++unigram < unigram_count) grams.push(top); } else { if (order == total_order) { blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob); @@ -414,8 +406,6 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con if (++reader) grams.push(top); } } - assert(grams.empty()); - doing.Cleanup(); } void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) { @@ -469,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c } // namespace -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) { RecordReader inputs[KENLM_MAX_ORDER - 1]; RecordReader contexts[KENLM_MAX_ORDER - 1]; @@ -498,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); - out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); + void *vocab_relocate; + void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate); + vocab.Relocate(vocab_relocate); + out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config); for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -524,6 +517,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve { WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); + // Write the last unigram entry, which is the end pointer for the bigrams. + writer.Unigram(counts[0]); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -579,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() { - unigram_.LoadedBinary(); - for (Middle *i = middle_begin_; i != middle_end_; ++i) { - i->LoadedBinary(); - } - longest_.LoadedBinary(); -} - -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) { std::string temporary_prefix; if (config.temporary_directory_prefix) { temporary_prefix = config.temporary_directory_prefix; diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 60be416b..299262a5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -11,18 +11,19 @@ #include "util/file_piece.hh" #include <vector> +#include <cstdlib> #include <assert.h> namespace lm { namespace ngram { -struct Backing; +class BinaryFormat; class SortedVocabulary; namespace trie { template <class Quant, class Bhiksha> class TrieSearch; class SortedFiles; -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing); template <class Quant, class Bhiksha> class TrieSearch { public: @@ -38,11 +39,11 @@ template <class Quant, class Bhiksha> class TrieSearch { static const unsigned int kVersion = 1; - static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { - Quant::UpdateConfigFromBinary(fd, counts, config); - util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); + static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) { + Quant::UpdateConfigFromBinary(file, offset, config); // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2. - if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config); + if (counts.size() > 2) + Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config); } static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { @@ -59,9 +60,7 @@ template <class Quant, class Bhiksha> class TrieSearch { uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); - void LoadedBinary(); - - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing); unsigned char Order() const { return middle_end_ - middle_begin_ + 2; @@ -102,14 +101,14 @@ template <class Quant, class Bhiksha> class TrieSearch { } private: - friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); + friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing); - // Middles are managed manually so we can delay construction and they don't have to be copyable. + // Middles are managed manually so we can delay construction and they don't have to be copyable. void FreeMiddles() { for (const Middle *i = middle_begin_; i != middle_end_; ++i) { i->~Middle(); } - free(middle_begin_); + std::free(middle_begin_); } typedef trie::BitPackedMiddle<Bhiksha> Middle; diff --git a/klm/lm/state.hh b/klm/lm/state.hh index a6b9accb..543df37c 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -102,7 +102,7 @@ struct ChartState { } bool operator<(const ChartState &other) const { - return Compare(other) == -1; + return Compare(other) < 0; } void ZeroRemaining() { diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 9ea3c546..d858ab5e 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -62,8 +62,6 @@ class Unigram { return unigram_; } - void LoadedBinary() {} - UnigramPointer Find(WordIndex word, NodeRange &next) const { UnigramValue *val = unigram_ + word; next.begin = val->next; @@ -108,8 +106,6 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked { void FinishedLoading(uint64_t next_end, const Config &config); - void LoadedBinary() { bhiksha_.LoadedBinary(); } - util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const; util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) { @@ -138,14 +134,9 @@ class BitPackedLongest : public BitPacked { BaseInit(base, max_vocab, quant_bits); } - void LoadedBinary() {} - util::BitAddress Insert(WordIndex word); util::BitAddress Find(WordIndex word, const NodeRange &node) const; - - private: - uint8_t quant_bits_; }; } // namespace trie diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index dc542bb3..126d43ab 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -50,6 +50,10 @@ class PartialViewProxy { const void *Data() const { return inner_.Data(); } void *Data() { return inner_.Data(); } + friend void swap(PartialViewProxy first, PartialViewProxy second) { + std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data())); + } + private: friend class util::ProxyIterator<PartialViewProxy>; diff --git a/klm/lm/value_build.cc b/klm/lm/value_build.cc index 6124f8da..3ec3dce2 100644 --- a/klm/lm/value_build.cc +++ b/klm/lm/value_build.cc @@ -9,6 +9,7 @@ namespace ngram { template <class Model> LowerRestBuild<Model>::LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab) { UTIL_THROW_IF(config.rest_lower_files.size() != order - 1, ConfigException, "This model has order " << order << " so there should be " << (order - 1) << " lower-order models for rest cost purposes."); Config for_lower = config; + for_lower.write_mmap = NULL; for_lower.rest_lower_files.clear(); // Unigram models aren't supported, so this is a custom loader. diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 17f064b2..7a3e2379 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -125,10 +125,13 @@ class Model { void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); } // Requires in_state != out_state - virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + virtual float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; // Requires in_state != out_state - virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + virtual FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + + // Prefer to use FullScore. The context words should be provided in reverse order. + virtual FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0; unsigned char Order() const { return order_; } diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index fd7f96dc..7f0878f4 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); // Sadly some LMs have <UNK>. const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); -void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) { +void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) { + util::SeekOrThrow(fd, offset); // Check that we're at the right place by reading <unk> which is always first. char check_unk[6]; util::ReadOrThrow(fd, check_unk, 6); @@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { buffer_.push_back(0); } -void WriteWordsWrapper::Write(int fd, uint64_t start) { - util::SeekOrThrow(fd, start); - util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); -} - SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { @@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size saw_unk_ = false; } +void SortedVocabulary::Relocate(void *new_start) { + std::size_t delta = end_ - begin_; + begin_ = reinterpret_cast<uint64_t*>(new_start) + 1; + end_ = begin_ + delta; +} + void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) { enumerate_ = to; if (enumerate_) { @@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { bound_ = end_ - begin_ + 1; } -void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); SetSpecial(Index("<s>"), Index("</s>"), 0); bound_ = end_ - begin_ + 1; - if (have_words) ReadWords(fd, to, bound_); + if (have_words) ReadWords(fd, to, bound_, offset); } namespace { @@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz saw_unk_ = false; } +void ProbingVocabulary::Relocate(void *new_start) { + header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start); + lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader))); +} + void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) { enumerate_ = to; if (enumerate_) { @@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() { SetSpecial(Index("<s>"), Index("</s>"), 0); } -void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); - lookup_.LoadedBinary(); bound_ = header_->bound; SetSpecial(Index("<s>"), Index("</s>"), 0); - if (have_words) ReadWords(fd, to, bound_); + if (have_words) ReadWords(fd, to, bound_, offset); } void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 226ae438..074b74d8 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -36,7 +36,7 @@ class WriteWordsWrapper : public EnumerateVocab { void Add(WordIndex index, const StringPiece &str); - void Write(int fd, uint64_t start); + const std::string &Buffer() const { return buffer_; } private: EnumerateVocab *inner_; @@ -71,6 +71,8 @@ class SortedVocabulary : public base::Vocabulary { // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); + void Relocate(void *new_start); + void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); WordIndex Insert(const StringPiece &str); @@ -83,15 +85,13 @@ class SortedVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } - void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); private: uint64_t *begin_, *end_; WordIndex bound_; - WordIndex highest_value_; - bool saw_unk_; EnumerateVocab *enumerate_; @@ -140,6 +140,8 @@ class ProbingVocabulary : public base::Vocabulary { // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); + void Relocate(void *new_start); + void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); WordIndex Insert(const StringPiece &str); @@ -152,7 +154,7 @@ class ProbingVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } - void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); private: void InternalFinishedLoading(); |