From 89238977fc9d8f8d9a6421b0d4f35afc200f08e7 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 28 Feb 2012 17:23:55 -0500 Subject: Subject: where's my kenlm update?? From: Chris Dyer --- klm/lm/bhiksha.cc | 7 +- klm/lm/bhiksha.hh | 2 +- klm/lm/binary_format.cc | 139 +++++++++++++++++-------------- klm/lm/binary_format.hh | 14 +--- klm/lm/blank.hh | 2 +- klm/lm/build_binary.cc | 35 ++++++-- klm/lm/config.cc | 1 + klm/lm/config.hh | 8 ++ klm/lm/left_test.cc | 11 ++- klm/lm/model.cc | 17 ++-- klm/lm/model.hh | 11 +-- klm/lm/model_test.cc | 24 ++++-- klm/lm/ngram_query.cc | 145 ++++++++------------------------ klm/lm/ngram_query.hh | 103 +++++++++++++++++++++++ klm/lm/quantize.cc | 38 ++++----- klm/lm/quantize.hh | 2 +- klm/lm/read_arpa.cc | 2 +- klm/lm/return.hh | 2 +- klm/lm/search_hashed.cc | 22 ++--- klm/lm/search_hashed.hh | 63 +++++++++++--- klm/lm/search_trie.cc | 88 +++++++------------- klm/lm/search_trie.hh | 10 ++- klm/lm/trie.hh | 2 +- klm/lm/trie_sort.cc | 217 +++++++++++++++++++++++++++--------------------- klm/lm/trie_sort.hh | 55 ++++++++---- klm/lm/vocab.cc | 53 +++++++----- klm/lm/vocab.hh | 36 +++++--- 27 files changed, 648 insertions(+), 461 deletions(-) create mode 100644 klm/lm/ngram_query.hh (limited to 'klm/lm') diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc index bf86fd4b..cdeafb47 100644 --- a/klm/lm/bhiksha.cc +++ b/klm/lm/bhiksha.cc @@ -1,5 +1,6 @@ #include "lm/bhiksha.hh" #include "lm/config.hh" +#include "util/file.hh" #include @@ -12,12 +13,12 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_ const uint8_t kArrayBhikshaVersion = 0; +// TODO: put this in binary file header instead when I change the binary file format again. void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { uint8_t version; uint8_t configured_bits; - if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) { - UTIL_THROW(util::ErrnoException, "Could not read from binary file"); - } + util::ReadOrThrow(fd, &version, 1); + util::ReadOrThrow(fd, &configured_bits, 1); if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); config.pointer_bhiksha_bits = configured_bits; } diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 3df43dda..5182ee2e 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -13,7 +13,7 @@ #ifndef LM_BHIKSHA__ #define LM_BHIKSHA__ -#include +#include #include #include "lm/model_type.hh" diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 27cada13..4796f6d1 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -1,19 +1,15 @@ #include "lm/binary_format.hh" #include "lm/lm_exception.hh" +#include "util/file.hh" #include "util/file_piece.hh" +#include +#include #include #include -#include -#include -#include -#include -#include -#include -#include -#include +#include namespace lm { namespace ngram { @@ -24,14 +20,16 @@ const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; const long int kMagicVersion = 5; -// Test values. -struct Sanity { +// Old binary files built on 32-bit machines have this header. +// TODO: eliminate with next binary release. +struct OldSanity { char magic[sizeof(kMagicBytes)]; float zero_f, one_f, minus_half_f; WordIndex one_word_index, max_word_index; uint64_t one_uint64; void SetToReference() { + std::memset(this, 0, sizeof(OldSanity)); std::memcpy(magic, kMagicBytes, sizeof(magic)); zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; one_word_index = 1; @@ -40,27 +38,35 @@ struct Sanity { } }; -const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; -std::size_t TotalHeaderSize(unsigned char order) { - return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); -} +// Test values aligned to 8 bytes. +struct Sanity { + char magic[ALIGN8(sizeof(kMagicBytes))]; + float zero_f, one_f, minus_half_f; + WordIndex one_word_index, max_word_index, padding_to_8; + uint64_t one_uint64; -void ReadLoop(int fd, void *to_void, std::size_t size) { - uint8_t *to = static_cast(to_void); - while (size) { - ssize_t ret = read(fd, to, size); - if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file"); - if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short"); - to += ret; - size -= ret; + void SetToReference() { + std::memset(this, 0, sizeof(Sanity)); + std::memcpy(magic, kMagicBytes, sizeof(kMagicBytes)); + zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; + one_word_index = 1; + max_word_index = std::numeric_limits::max(); + padding_to_8 = 0; + one_uint64 = 1; } +}; + +const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; + +std::size_t TotalHeaderSize(unsigned char order) { + return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } void WriteHeader(void *to, const Parameters ¶ms) { Sanity header = Sanity(); header.SetToReference(); - memcpy(to, &header, sizeof(Sanity)); + std::memcpy(to, &header, sizeof(Sanity)); char *out = reinterpret_cast(to) + sizeof(Sanity); *reinterpret_cast(out) = params.fixed; @@ -74,14 +80,6 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace -void SeekOrThrow(int fd, off_t off) { - if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed"); -} - -void AdvanceOrThrow(int fd, off_t off) { - if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed"); -} - uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; @@ -89,7 +87,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ strncpy(reinterpret_cast(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); return reinterpret_cast(backing.vocab.get()) + TotalHeaderSize(order); } else { - backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + util::MapAnonymous(memory_size, backing.vocab); return reinterpret_cast(backing.vocab.get()); } } @@ -98,42 +96,58 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { // Grow the file to accomodate the search, using zeros. - if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size)) - UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed"); + try { + util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); + } catch (util::ErrnoException &e) { + e << " for file " << config.write_mmap; + throw e; + } + if (config.write_method == Config::WRITE_AFTER) { + util::MapAnonymous(memory_size, backing.search); + return reinterpret_cast(backing.search.get()); + } + // mmap it now. // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. - off_t page_size = sysconf(_SC_PAGE_SIZE); - off_t alignment_cruft = adjusted_vocab % page_size; + std::size_t page_size = util::SizePage(); + std::size_t alignment_cruft = adjusted_vocab % page_size; backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); - return reinterpret_cast(backing.search.get()) + alignment_cruft; } else { - backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + util::MapAnonymous(memory_size, backing.search); return reinterpret_cast(backing.search.get()); } } -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, Backing &backing) { - if (config.write_mmap) { - if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) - UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); - // header and vocab share the same mmap. The header is written here because we know the counts. - Parameters params; - params.counts = counts; - params.fixed.order = counts.size(); - params.fixed.probing_multiplier = config.probing_multiplier; - params.fixed.model_type = model_type; - params.fixed.has_vocabulary = config.include_vocab; - params.fixed.search_version = search_version; - WriteHeader(backing.vocab.get(), params); +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, std::size_t vocab_pad, Backing &backing) { + if (!config.write_mmap) return; + util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); + switch (config.write_method) { + case Config::WRITE_MMAP: + util::SyncOrThrow(backing.search.get(), backing.search.size()); + break; + case Config::WRITE_AFTER: + util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); + util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); + util::FSyncOrThrow(backing.file.get()); + break; } + // header and vocab share the same mmap. The header is written here because we know the counts. + Parameters params = Parameters(); + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + params.fixed.search_version = search_version; + WriteHeader(backing.vocab.get(), params); } namespace detail { bool IsBinaryFormat(int fd) { - const off_t size = util::SizeFile(fd); - if (size == util::kBadSize || (size <= static_cast(sizeof(Sanity)))) return false; + const uint64_t size = util::SizeFile(fd); + if (size == util::kBadSize || (size <= static_cast(sizeof(Sanity)))) return false; // Try reading the header. util::scoped_memory memory; try { @@ -154,19 +168,23 @@ bool IsBinaryFormat(int fd) { if ((end_ptr != begin_version) && version != kMagicVersion) { UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary"); } + + OldSanity old_sanity = OldSanity(); + old_sanity.SetToReference(); + UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable."); UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture"); } return false; } void ReadHeader(int fd, Parameters &out) { - SeekOrThrow(fd, sizeof(Sanity)); - ReadLoop(fd, &out.fixed, sizeof(out.fixed)); + util::SeekOrThrow(fd, sizeof(Sanity)); + util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed)); if (out.fixed.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); out.counts.resize(static_cast(out.fixed.order)); - ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); + if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); } void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms) { @@ -179,11 +197,11 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet } void SeekPastHeader(int fd, const Parameters ¶ms) { - SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); + util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); } uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { - const off_t file_size = util::SizeFile(backing.file.get()); + const uint64_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; if (file_size != util::kBadSize && static_cast(file_size) < total_map) @@ -194,9 +212,8 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t if (config.enumerate_vocab && !params.fixed.has_vocabulary) UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); - if (config.enumerate_vocab) { - SeekOrThrow(backing.file.get(), total_map); - } + // Seek to vocabulary words + util::SeekOrThrow(backing.file.get(), total_map); return reinterpret_cast(backing.search.get()) + TotalHeaderSize(params.counts.size()); } diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index e9df0892..dd795f62 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -12,7 +12,7 @@ #include #include -#include +#include namespace lm { namespace ngram { @@ -33,10 +33,8 @@ struct FixedWidthParameters { unsigned int search_version; }; -inline std::size_t Align8(std::size_t in) { - std::size_t off = in % 8; - return off ? (in + 8 - off) : in; -} +// This is a macro instead of an inline function so constants can be assigned using it. +#define ALIGN8(a) ((std::ptrdiff_t(((a)-1)/8)+1)*8) // Parameters stored in the header of a binary file. struct Parameters { @@ -53,10 +51,6 @@ struct Backing { util::scoped_memory search; }; -void SeekOrThrow(int fd, off_t off); -// Seek forward -void AdvanceOrThrow(int fd, off_t off); - // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. @@ -64,7 +58,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t // Write header to binary file. This is done last to prevent incomplete files // from loading. -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, Backing &backing); +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, std::size_t vocab_pad, Backing &backing); namespace detail { diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 2fb64cd0..4da81209 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -3,7 +3,7 @@ #include -#include +#include #include namespace lm { diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index fdb62a71..8cbb69d0 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -8,18 +8,24 @@ #include #include -#include + +#ifdef WIN32 +#include "util/getopt.hh" +#endif namespace lm { namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" "-u sets the log10 probability for if the ARPA file does not have one.\n" " Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have and .\n" -"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" +"-w mmap|after determines how writing is done.\n" +" mmap maps the binary file and writes to it. Default for trie.\n" +" after allocates anonymous memory, builds, and writes. Default for probing.\n\n" "type is either probing or trie. Default is probing.\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" @@ -55,7 +61,7 @@ uint8_t ParseBitCount(const char *from) { unsigned long val = ParseUInt(from); if (val > 25) { util::ParseNumberException e(from); - e << " bit counts are limited to 256."; + e << " bit counts are limited to 25."; } return val; } @@ -87,7 +93,7 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { prefix = 'G'; divide = 1 << 30; } - long int length = std::max(2, lrint(ceil(log10(max_length / divide)))); + long int length = std::max(2, static_cast(ceil(log10((double) max_length / divide)))); std::cout << "Memory estimate:\ntype "; // right align bytes. for (long int i = 0; i < length - 2; ++i) std::cout << ' '; @@ -112,10 +118,10 @@ int main(int argc, char *argv[]) { using namespace lm::ngram; try { - bool quantize = false, set_backoff_bits = false, bhiksha = false; + bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false; lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) { + while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:si")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -129,6 +135,7 @@ int main(int argc, char *argv[]) { case 'a': config.pointer_bhiksha_bits = ParseBitCount(optarg); bhiksha = true; + break; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -141,6 +148,16 @@ int main(int argc, char *argv[]) { case 'm': config.building_memory = ParseUInt(optarg) * 1048576; break; + case 'w': + set_write_method = true; + if (!strcmp(optarg, "mmap")) { + config.write_method = Config::WRITE_MMAP; + } else if (!strcmp(optarg, "after")) { + config.write_method = Config::WRITE_AFTER; + } else { + Usage(argv[0]); + } + break; case 's': config.sentence_marker_missing = lm::SILENT; break; @@ -166,9 +183,11 @@ int main(int argc, char *argv[]) { const char *from_file = argv[optind + 1]; config.write_mmap = argv[optind + 2]; if (!strcmp(model_type, "probing")) { + if (!set_write_method) config.write_method = Config::WRITE_AFTER; if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { + if (!set_write_method) config.write_method = Config::WRITE_MMAP; if (quantize) { if (bhiksha) { QuantArrayTrieModel(from_file, config); @@ -191,7 +210,9 @@ int main(int argc, char *argv[]) { } catch (const std::exception &e) { std::cerr << e.what() << std::endl; + std::cerr << "ERROR" << std::endl; return 1; } + std::cerr << "SUCCESS" << std::endl; return 0; } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 297589a4..dbe762b3 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -17,6 +17,7 @@ Config::Config() : temporary_directory_prefix(NULL), arpa_complain(ALL), write_mmap(NULL), + write_method(WRITE_AFTER), include_vocab(true), prob_bits(8), backoff_bits(8), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 8564661b..01b75632 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -70,9 +70,17 @@ struct Config { // to NULL to disable. const char *write_mmap; + typedef enum { + WRITE_MMAP, // Map the file directly. + WRITE_AFTER // Write after we're done. + } WriteMethod; + WriteMethod write_method; + // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; + + // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used // to quantize (and one of the remaining backoffs will be 0). diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc index 8bb91cb3..c85e5efa 100644 --- a/klm/lm/left_test.cc +++ b/klm/lm/left_test.cc @@ -142,7 +142,7 @@ template float TreeMiddle(const M &m, const std::vector &wo template void LookupVocab(const M &m, const StringPiece &str, std::vector &out) { out.clear(); - for (util::PieceIterator<' '> i(str); i; ++i) { + for (util::TokenIter i(str, ' '); i; ++i) { out.push_back(m.GetVocabulary().Index(*i)); } } @@ -326,10 +326,17 @@ template void FullGrow(const M &m) { } } +const char *FileLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 2) { + return "test.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[1]; +} + template void Everything() { Config config; config.messages = NULL; - M m("test.arpa", config); + M m(FileLocation(), config); Short(m); Charge(m); diff --git a/klm/lm/model.cc b/klm/lm/model.cc index e4c1ec1d..478ebed1 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -46,7 +46,7 @@ template GenericModel::Ge template void GenericModel::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { SetupMemory(start, params.counts, config); - vocab_.LoadedBinary(fd, config.enumerate_vocab); + vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); search_.LoadedBinary(); } @@ -82,13 +82,18 @@ template void GenericModel void GenericModel::UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { + util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); + Search::UpdateConfigFromBinary(fd, counts, config); +} + template FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) { @@ -114,7 +119,7 @@ template FullScoreReturn GenericModel void GenericModel FullScoreReturn GenericModel FullScoreReturn GenericModel class GenericModel : public base::Mod * TrieModel. To classify binary files, call RecognizeBinary in * lm/binary_format.hh. */ - GenericModel(const char *file, const Config &config = Config()); + explicit GenericModel(const char *file, const Config &config = Config()); /* Score p(new_word | in_state) and incorporate new_word into out_state. * Note that in_state and out_state must be different references: @@ -137,14 +137,9 @@ template class GenericModel : public base::Mod unsigned char &next_use) const; private: - friend void LoadLM<>(const char *file, const Config &config, GenericModel &to); + friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel &to); - static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { - AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); - Search::UpdateConfigFromBinary(fd, counts, config); - } - - float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const; + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 2654071f..461704d4 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -19,6 +19,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) { namespace { +const char *TestLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 2) { + return "test.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[1]; +} +const char *TestNoUnkLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 3) { + return "test_nounk.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[2]; + +} + #define StartTest(word, ngram, score, indep_left) \ ret = model.FullScore( \ state, \ @@ -307,7 +321,7 @@ template void LoadingTest() { { ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; - ModelT m("test.arpa", config); + ModelT m(TestLocation(), config); enumerate.Check(m.GetVocabulary()); BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); Everything(m); @@ -315,7 +329,7 @@ template void LoadingTest() { { ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; - ModelT m("test_nounk.arpa", config); + ModelT m(TestNoUnkLocation(), config); enumerate.Check(m.GetVocabulary()); BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); NoUnkCheck(m); @@ -346,7 +360,7 @@ template void BinaryTest() { config.enumerate_vocab = &enumerate; { - ModelT copy_model("test.arpa", config); + ModelT copy_model(TestLocation(), config); enumerate.Check(copy_model.GetVocabulary()); enumerate.Clear(); Everything(copy_model); @@ -370,14 +384,14 @@ template void BinaryTest() { config.messages = NULL; enumerate.Clear(); { - ModelT copy_model("test_nounk.arpa", config); + ModelT copy_model(TestNoUnkLocation(), config); enumerate.Check(copy_model.GetVocabulary()); enumerate.Clear(); NoUnkCheck(copy_model); } config.write_mmap = NULL; { - ModelT binary("test_nounk.binary", config); + ModelT binary(TestNoUnkLocation(), config); enumerate.Check(binary.GetVocabulary()); NoUnkCheck(binary); } diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index d9db4aa2..8f7a0e1c 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -1,87 +1,4 @@ -#include "lm/enumerate_vocab.hh" -#include "lm/model.hh" - -#include -#include -#include -#include - -#include - -#include -#include - -float FloatSec(const struct timeval &tv) { - return static_cast(tv.tv_sec) + (static_cast(tv.tv_usec) / 1000000000.0); -} - -void PrintUsage(const char *message) { - struct rusage usage; - if (getrusage(RUSAGE_SELF, &usage)) { - perror("getrusage"); - return; - } - std::cerr << message; - std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n'; - - // Linux doesn't set memory usage :-(. - std::ifstream status("/proc/self/status", std::ios::in); - std::string line; - while (getline(status, line)) { - if (!strncmp(line.c_str(), "VmRSS:\t", 7)) { - std::cerr << "rss " << (line.c_str() + 7) << '\n'; - break; - } - } -} - -template void Query(const Model &model, bool sentence_context) { - PrintUsage("Loading statistics:\n"); - typename Model::State state, out; - lm::FullScoreReturn ret; - std::string word; - - while (std::cin) { - state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); - float total = 0.0; - bool got = false; - unsigned int oov = 0; - while (std::cin >> word) { - got = true; - lm::WordIndex vocab = model.GetVocabulary().Index(word); - if (vocab == 0) ++oov; - ret = model.FullScore(state, vocab, out); - total += ret.prob; - std::cout << word << '=' << vocab << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; - state = out; - char c; - while (true) { - c = std::cin.get(); - if (!std::cin) break; - if (c == '\n') break; - if (!isspace(c)) { - std::cin.unget(); - break; - } - } - if (c == '\n') break; - } - if (!got && !std::cin) break; - if (sentence_context) { - ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); - total += ret.prob; - std::cout << "=" << model.GetVocabulary().EndSentence() << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; - } - std::cout << "Total: " << total << " OOV: " << oov << '\n'; - } - PrintUsage("After queries:\n"); -} - -template void Query(const char *name) { - lm::ngram::Config config; - Model model(name, config); - Query(model); -} +#include "lm/ngram_query.hh" int main(int argc, char *argv[]) { if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) { @@ -89,34 +6,40 @@ int main(int argc, char *argv[]) { std::cerr << "Input is wrapped in and unless null is passed." << std::endl; return 1; } - bool sentence_context = (argc == 2); - lm::ngram::ModelType model_type; - if (lm::ngram::RecognizeBinary(argv[1], model_type)) { - switch(model_type) { - case lm::ngram::HASH_PROBING: - Query(argv[1], sentence_context); - break; - case lm::ngram::TRIE_SORTED: - Query(argv[1], sentence_context); - break; - case lm::ngram::QUANT_TRIE_SORTED: - Query(argv[1], sentence_context); - break; - case lm::ngram::ARRAY_TRIE_SORTED: - Query(argv[1], sentence_context); - break; - case lm::ngram::QUANT_ARRAY_TRIE_SORTED: - Query(argv[1], sentence_context); - break; - case lm::ngram::HASH_SORTED: - default: - std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; - abort(); + try { + bool sentence_context = (argc == 2); + using namespace lm::ngram; + ModelType model_type; + if (RecognizeBinary(argv[1], model_type)) { + switch(model_type) { + case HASH_PROBING: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case TRIE_SORTED: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case QUANT_TRIE_SORTED: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case QUANT_ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case HASH_SORTED: + default: + std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; + abort(); + } + } else { + Query(argv[1], sentence_context, std::cin, std::cout); } - } else { - Query(argv[1], sentence_context); - } - PrintUsage("Total time including destruction:\n"); + PrintUsage("Total time including destruction:\n"); + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; + } return 0; } diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh new file mode 100644 index 00000000..4990df22 --- /dev/null +++ b/klm/lm/ngram_query.hh @@ -0,0 +1,103 @@ +#ifndef LM_NGRAM_QUERY__ +#define LM_NGRAM_QUERY__ + +#include "lm/enumerate_vocab.hh" +#include "lm/model.hh" + +#include +#include +#include +#include + +#include +#if !defined(_WIN32) && !defined(_WIN64) +#include +#include +#endif + +namespace lm { +namespace ngram { + +#if !defined(_WIN32) && !defined(_WIN64) +float FloatSec(const struct timeval &tv) { + return static_cast(tv.tv_sec) + (static_cast(tv.tv_usec) / 1000000000.0); +} +#endif + +void PrintUsage(const char *message) { +#if !defined(_WIN32) && !defined(_WIN64) + struct rusage usage; + if (getrusage(RUSAGE_SELF, &usage)) { + perror("getrusage"); + return; + } + std::cerr << message; + std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n'; + + // Linux doesn't set memory usage :-(. + std::ifstream status("/proc/self/status", std::ios::in); + std::string line; + while (getline(status, line)) { + if (!strncmp(line.c_str(), "VmRSS:\t", 7)) { + std::cerr << "rss " << (line.c_str() + 7) << '\n'; + break; + } + } +#endif +} + +template void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { + PrintUsage("Loading statistics:\n"); + typename Model::State state, out; + lm::FullScoreReturn ret; + std::string word; + + while (in_stream) { + state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); + float total = 0.0; + bool got = false; + unsigned int oov = 0; + while (in_stream >> word) { + got = true; + lm::WordIndex vocab = model.GetVocabulary().Index(word); + if (vocab == 0) ++oov; + ret = model.FullScore(state, vocab, out); + total += ret.prob; + out_stream << word << '=' << vocab << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; + state = out; + char c; + while (true) { + c = in_stream.get(); + if (!in_stream) break; + if (c == '\n') break; + if (!isspace(c)) { + in_stream.unget(); + break; + } + } + if (c == '\n') break; + } + if (!got && !in_stream) break; + if (sentence_context) { + ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); + total += ret.prob; + out_stream << "=" << model.GetVocabulary().EndSentence() << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; + } + out_stream << "Total: " << total << " OOV: " << oov << '\n'; + } + PrintUsage("After queries:\n"); +} + +template void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { + Config config; +// config.load_method = util::LAZY; + M model(file, config); + Query(model, sentence_context, in_stream, out_stream); +} + +} // namespace ngram +} // namespace lm + +#endif // LM_NGRAM_QUERY__ + + diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index 98a5d048..a8e0cb21 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -1,31 +1,30 @@ +/* Quantize into bins of equal size as described in + * M. Federico and N. Bertoldi. 2006. How many bits are needed + * to store probabilities for phrase-based translation? In Proc. + * of the Workshop on Statistical Machine Translation, pages + * 94–101, New York City, June. Association for Computa- + * tional Linguistics. + */ + #include "lm/quantize.hh" #include "lm/binary_format.hh" #include "lm/lm_exception.hh" +#include "util/file.hh" #include #include -#include - namespace lm { namespace ngram { -/* Quantize into bins of equal size as described in - * M. Federico and N. Bertoldi. 2006. How many bits are needed - * to store probabilities for phrase-based translation? In Proc. - * of the Workshop on Statistical Machine Translation, pages - * 94–101, New York City, June. Association for Computa- - * tional Linguistics. - */ - namespace { -void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) { - std::sort(values, values_end); - const float *start = values, *finish; +void MakeBins(std::vector &values, float *centers, uint32_t bins) { + std::sort(values.begin(), values.end()); + std::vector::const_iterator start = values.begin(), finish; for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) { - finish = values + (((values_end - values) * static_cast(i + 1)) / bins); + finish = values.begin() + ((values.size() * static_cast(i + 1)) / bins); if (finish == start) { // zero length bucket. *centers = i ? *(centers - 1) : -std::numeric_limits::infinity(); @@ -41,10 +40,11 @@ const char kSeparatelyQuantizeVersion = 2; void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector &/*counts*/, Config &config) { char version; - if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1) - UTIL_THROW(util::ErrnoException, "Failed to read header for quantization."); + util::ReadOrThrow(fd, &version, 1); + util::ReadOrThrow(fd, &config.prob_bits, 1); + util::ReadOrThrow(fd, &config.backoff_bits, 1); if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); - AdvanceOrThrow(fd, -3); + util::AdvanceOrThrow(fd, -3); } void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { @@ -66,12 +66,12 @@ void SeparatelyQuantize::Train(uint8_t order, std::vector &prob, std::vec float *centers = start_ + TableStart(order) + ProbTableLength(); *(centers++) = kNoExtensionBackoff; *(centers++) = kExtensionBackoff; - MakeBins(&*backoff.begin(), &*backoff.end(), centers, (1ULL << backoff_bits_) - 2); + MakeBins(backoff, centers, (1ULL << backoff_bits_) - 2); } void SeparatelyQuantize::TrainProb(uint8_t order, std::vector &prob) { float *centers = start_ + TableStart(order); - MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_)); + MakeBins(prob, centers, (1ULL << prob_bits_)); } void SeparatelyQuantize::FinishedLoading(const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 4cf4236e..6d130a57 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -9,7 +9,7 @@ #include #include -#include +#include #include diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index dce73f77..05f761be 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -8,7 +8,7 @@ #include #include -#include +#include namespace lm { diff --git a/klm/lm/return.hh b/klm/lm/return.hh index 15571960..1b55091b 100644 --- a/klm/lm/return.hh +++ b/klm/lm/return.hh @@ -1,7 +1,7 @@ #ifndef LM_RETURN__ #define LM_RETURN__ -#include +#include namespace lm { /* Structure returned by scoring routines. */ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 247832b0..1d6fb5be 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -30,7 +30,7 @@ template class ActivateLowerMiddle { // TODO: somehow get text of n-gram for this error message. if (!modify_.UnsafeMutableFind(hash, i)) UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram"); - SetExtension(i->MutableValue().backoff); + SetExtension(i->value.backoff); } private: @@ -65,7 +65,7 @@ template void FixSRI(int lower, float negative_lower_prob, unsign blank.prob -= unigrams[vocab_ids[1]].backoff; SetExtension(unigrams[vocab_ids[1]].backoff); // Bigram including a unigram's backoff - middle[0].Insert(Middle::Packing::Make(keys[0], blank)); + middle[0].Insert(detail::ProbBackoffEntry::Make(keys[0], blank)); fix = 1; } else { for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); @@ -74,22 +74,24 @@ template void FixSRI(int lower, float negative_lower_prob, unsign for (; fix <= n - 3; ++fix) { typename Middle::MutableIterator gotit; if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) { - float &backoff = gotit->MutableValue().backoff; + float &backoff = gotit->value.backoff; SetExtension(backoff); blank.prob -= backoff; } - middle[fix].Insert(Middle::Packing::Make(keys[fix], blank)); + middle[fix].Insert(detail::ProbBackoffEntry::Make(keys[fix], blank)); backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]); } } template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector &middle, Activate activate, Store &store, PositiveProbWarn &warn) { + assert(n >= 2); ReadNGramHeader(f, n); - // vocab ids of words in reverse order + // Both vocab_ids and keys are non-empty because n >= 2. + // vocab ids of words in reverse order. std::vector vocab_ids(n); std::vector keys(n-1); - typename Store::Packing::Value value; + typename Store::Entry::Value value; typename Middle::MutableIterator found; for (size_t i = 0; i < count; ++i) { ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn); @@ -100,7 +102,7 @@ template void ReadNGrams( } // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. util::SetSign(value.prob); - store.Insert(Store::Packing::Make(keys[n-2], value)); + store.Insert(Store::Entry::Make(keys[n-2], value)); // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. int lower; util::FloatEnc fix_prob; @@ -113,9 +115,9 @@ template void ReadNGrams( } if (middle[lower].UnsafeMutableFind(keys[lower], found)) { // Turn off sign bit to indicate that it extends left. - fix_prob.f = found->MutableValue().prob; + fix_prob.f = found->value.prob; fix_prob.i &= ~util::kSignBit; - found->MutableValue().prob = fix_prob.f; + found->value.prob = fix_prob.f; // We don't need to recurse further down because this entry already set the bits for lower entries. break; } @@ -147,7 +149,7 @@ template uint8_t *TemplateHashedSearch template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. - SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config); + SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config); PositiveProbWarn warn(config.positive_log_probability); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index e289fd11..4352c72d 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -8,7 +8,6 @@ #include "lm/weights.hh" #include "util/bit_packing.hh" -#include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include @@ -92,8 +91,10 @@ template class TemplateHashedSearch : public Has template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing); - const Middle *MiddleBegin() const { return &*middle_.begin(); } - const Middle *MiddleEnd() const { return &*middle_.end(); } + typedef typename std::vector::const_iterator MiddleIter; + + MiddleIter MiddleBegin() const { return middle_.begin(); } + MiddleIter MiddleEnd() const { return middle_.end(); } Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { util::FloatEnc val; @@ -105,7 +106,7 @@ template class TemplateHashedSearch : public Has std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; abort(); } - val.f = found->GetValue().prob; + val.f = found->value.prob; } val.i |= util::kSignBit; prob = val.f; @@ -117,12 +118,12 @@ template class TemplateHashedSearch : public Has typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; util::FloatEnc enc; - enc.f = found->GetValue().prob; + enc.f = found->value.prob; ret.independent_left = (enc.i & util::kSignBit); ret.extend_left = node; enc.i |= util::kSignBit; ret.prob = enc.f; - backoff = found->GetValue().backoff; + backoff = found->value.backoff; return true; } @@ -132,7 +133,7 @@ template class TemplateHashedSearch : public Has node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; - backoff = found->GetValue().backoff; + backoff = found->value.backoff; return true; } @@ -141,7 +142,7 @@ template class TemplateHashedSearch : public Has node = CombineWordHash(node, word); typename Longest::ConstIterator found; if (!longest.Find(node, found)) return false; - prob = found->GetValue().prob; + prob = found->value.prob; return true; } @@ -160,14 +161,50 @@ template class TemplateHashedSearch : public Has std::vector middle_; }; -// std::identity is an SGI extension :-( -struct IdentityHash : public std::unary_function { - size_t operator()(uint64_t arg) const { return static_cast(arg); } +/* These look like perfect candidates for a template, right? Ancient gcc (4.1 + * on RedHat stale linux) doesn't pack templates correctly. ProbBackoffEntry + * is a multiple of 8 bytes anyway. ProbEntry is 12 bytes so it's set to pack. + */ +struct ProbBackoffEntry { + uint64_t key; + ProbBackoff value; + typedef uint64_t Key; + typedef ProbBackoff Value; + uint64_t GetKey() const { + return key; + } + static ProbBackoffEntry Make(uint64_t key, ProbBackoff value) { + ProbBackoffEntry ret; + ret.key = key; + ret.value = value; + return ret; + } }; +#pragma pack(push) +#pragma pack(4) +struct ProbEntry { + uint64_t key; + Prob value; + typedef uint64_t Key; + typedef Prob Value; + uint64_t GetKey() const { + return key; + } + static ProbEntry Make(uint64_t key, Prob value) { + ProbEntry ret; + ret.key = key; + ret.value = value; + return ret; + } +}; + +#pragma pack(pop) + + struct ProbingHashedSearch : public TemplateHashedSearch< - util::ProbingHashTable, IdentityHash>, - util::ProbingHashTable, IdentityHash> > { + util::ProbingHashTable, + util::ProbingHashTable > { static const ModelType kModelType = HASH_PROBING; }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 4bd3f4ee..ffadfa94 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -13,6 +13,7 @@ #include "lm/weights.hh" #include "lm/word_index.hh" #include "util/ersatz_progress.hh" +#include "util/mmap.hh" #include "util/proxy_iterator.hh" #include "util/scoped.hh" #include "util/sized_iterator.hh" @@ -20,14 +21,15 @@ #include #include #include +#include #include #include #include #include -#include -#include -#include +#if defined(_WIN32) || defined(_WIN64) +#include +#endif namespace lm { namespace ngram { @@ -195,7 +197,7 @@ class SRISucks { void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) { for (unsigned char i = 0; i < kMaxOrder - 1; ++i) { - it_[i] = &*values_[i].begin(); + it_[i] = values_[i].empty() ? NULL : &*values_[i].begin(); } messages_[0].Apply(it_, unigram_file); BackoffMessages *messages = messages_ + 1; @@ -227,8 +229,8 @@ class SRISucks { class FindBlanks { public: - FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) - : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {} + FindBlanks(unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) + : counts_(order), unigrams_(unigrams), sri_(messages) {} float UnigramProb(WordIndex index) const { return unigrams_[index].prob; @@ -248,7 +250,7 @@ class FindBlanks { } void Longest(const void * /*data*/) { - ++*longest_counts_; + ++counts_.back(); } // Unigrams wrote one past. @@ -256,8 +258,12 @@ class FindBlanks { --counts_[0]; } + const std::vector &Counts() const { + return counts_; + } + private: - uint64_t *const counts_, *const longest_counts_; + std::vector counts_; const ProbBackoff *unigrams_; @@ -375,7 +381,7 @@ template class BlankManager { template void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { util::ErsatzProgress progress(progress_out, message, unigram_count + 1); - unsigned int unigram = 0; + WordIndex unigram = 0; std::priority_queue grams; grams.push(Gram(&unigram, 1)); for (unsigned char i = 2; i <= total_order; ++i) { @@ -461,42 +467,33 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c } // namespace -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template void BuildTrie(SortedFiles &files, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { RecordReader inputs[kMaxOrder - 1]; RecordReader contexts[kMaxOrder - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { - std::stringstream assembled; - assembled << file_prefix << static_cast(i) << "_merged"; - inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); - util::RemoveOrThrow(assembled.str().c_str()); - assembled << kContextSuffix; - contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex)); - util::RemoveOrThrow(assembled.str().c_str()); + inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); + contexts[i-2].Init(files.Context(i), (i-1) * sizeof(WordIndex)); } SRISucks sri; - std::vector fixed_counts(counts.size()); + std::vector fixed_counts; + util::scoped_FILE unigram_file; + util::scoped_fd unigram_fd(files.StealUnigram()); { - std::string temp(file_prefix); temp += "unigrams"; - util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str())); util::scoped_memory unigrams; - MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); - FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast(unigrams.get()), sri); + MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); + FindBlanks finder(counts.size(), reinterpret_cast(unigrams.get()), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); + fixed_counts = finder.Counts(); } + unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) { if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - util::scoped_FILE unigram_file; - { - std::string name(file_prefix + "unigrams"); - unigram_file.reset(OpenOrThrow(name.c_str(), "r+")); - util::RemoveOrThrow(name.c_str()); - } sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); @@ -587,42 +584,19 @@ template void TrieSearch::LoadedBin longest.LoadedBinary(); } -namespace { -bool IsDirectory(const char *path) { - struct stat info; - if (0 != stat(path, &info)) return false; - return S_ISDIR(info.st_mode); -} -} // namespace - template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { - std::string temporary_directory; + std::string temporary_prefix; if (config.temporary_directory_prefix) { - temporary_directory = config.temporary_directory_prefix; - if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str())) - temporary_directory += '/'; + temporary_prefix = config.temporary_directory_prefix; } else if (config.write_mmap) { - temporary_directory = config.write_mmap; + temporary_prefix = config.write_mmap; } else { - temporary_directory = file; - } - // Null on end is kludge to ensure null termination. - temporary_directory += "_trie_tmp_XXXXXX"; - temporary_directory += '\0'; - if (!mkdtemp(&temporary_directory[0])) { - UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str()); + temporary_prefix = file; } - // Chop off null kludge. - temporary_directory.resize(strlen(temporary_directory.c_str())); - // Add directory delimiter. Assumes a real operating system. - temporary_directory += '/'; // At least 1MB sorting memory. - ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); + SortedFiles sorted(config, f, counts, std::max(config.building_memory, 1048576), temporary_prefix, vocab); - BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing); - if (rmdir(temporary_directory.c_str()) && config.messages) { - *config.messages << "Failed to delete " << temporary_directory << std::endl; - } + BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); } template class TrieSearch; diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 33ae8cff..5155ca02 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -7,6 +7,7 @@ #include "lm/trie.hh" #include "lm/weights.hh" +#include "util/file.hh" #include "util/file_piece.hh" #include @@ -20,7 +21,8 @@ class SortedVocabulary; namespace trie { template class TrieSearch; -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); +class SortedFiles; +template void BuildTrie(SortedFiles &files, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); template class TrieSearch { public: @@ -40,7 +42,7 @@ template class TrieSearch { static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); - AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); + util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); Bhiksha::UpdateConfigFromBinary(fd, config); } @@ -60,6 +62,8 @@ template class TrieSearch { void LoadedBinary(); + typedef const Middle *MiddleIter; + const Middle *MiddleBegin() const { return middle_begin_; } const Middle *MiddleEnd() const { return middle_end_; } @@ -108,7 +112,7 @@ template class TrieSearch { } private: - friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); + friend void BuildTrie(SortedFiles &files, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); // Middles are managed manually so we can delay construction and they don't have to be copyable. void FreeMiddles() { diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 06cc96ac..ebe9910f 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -1,7 +1,7 @@ #ifndef LM_TRIE__ #define LM_TRIE__ -#include +#include #include diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index bb126f18..b80fed02 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -22,14 +23,6 @@ namespace lm { namespace ngram { namespace trie { -const char *kContextSuffix = "_contexts"; - -FILE *OpenOrThrow(const char *name, const char *mode) { - FILE *ret = fopen(name, mode); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); - return ret; -} - void WriteOrThrow(FILE *to, const void *data, size_t size) { assert(size); if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); @@ -78,28 +71,29 @@ class PartialViewProxy { typedef util::ProxyIterator PartialIter; -std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order) { - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << '_' << batch; - std::string ret(assembled.str()); - util::scoped_fd out(util::CreateOrThrow(ret.c_str())); - util::WriteOrThrow(out.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); - return ret; +FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) { + util::scoped_fd file(maker.Make()); + util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); + return util::FDOpenOrThrow(file); } -void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { +FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) { const size_t context_size = sizeof(WordIndex) * (order - 1); // Sort just the contexts using the same memory. PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); - std::sort(context_begin, context_end, util::SizedCompare(EntryCompare(order - 1))); +#if defined(_WIN32) || defined(_WIN64) + std::stable_sort +#else + std::sort +#endif + (context_begin, context_end, util::SizedCompare(EntryCompare(order - 1))); - std::string name(ngram_file_name + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); + util::scoped_FILE out(maker.MakeFile()); // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. - if (context_begin == context_end) return; + if (context_begin == context_end) return out.release(); PartialIter i(context_begin); WriteOrThrow(out.get(), i->Data(), context_size); const void *previous = i->Data(); @@ -110,6 +104,7 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil previous = i->Data(); } } + return out.release(); } struct ThrowCombine { @@ -125,14 +120,12 @@ struct FirstCombine { } }; -template void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order, const Combine &combine = ThrowCombine()) { +template FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) { std::size_t entry_size = sizeof(WordIndex) * order + weights_size; RecordReader first, second; - first.Init(first_name.c_str(), entry_size); - util::RemoveOrThrow(first_name.c_str()); - second.Init(second_name.c_str(), entry_size); - util::RemoveOrThrow(second_name.c_str()); - util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); + first.Init(first_file, entry_size); + second.Init(second_file, entry_size); + util::scoped_FILE out_file(maker.MakeFile()); EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { @@ -149,67 +142,14 @@ template void MergeSortedFiles(const std::string &first_name, co for (RecordReader &remains = (first ? first : second); remains; ++remains) { WriteOrThrow(out_file.get(), remains.Data(), entry_size); } -} - -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { - ReadNGramHeader(f, order); - const size_t count = counts[order - 1]; - // Size of weights. Does it include backoff? - const size_t words_size = sizeof(WordIndex) * order; - const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); - const size_t entry_size = words_size + weights_size; - const size_t batch_size = std::min(count, mem.size() / entry_size); - uint8_t *const begin = reinterpret_cast(mem.get()); - std::deque files; - for (std::size_t batch = 0, done = 0; done < count; ++batch) { - uint8_t *out = begin; - uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; - if (order == counts.size()) { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); - } - } else { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); - } - } - // Sort full records by full n-gram. - util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - // parallel_sort uses too much RAM - std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare(EntryCompare(order))); - files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order)); - WriteContextFile(begin, out_end, files.back(), entry_size, order); - - done += (out_end - begin) / entry_size; - } - - // All individual files created. Merge them. - - std::size_t merge_count = 0; - while (files.size() > 1) { - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); - files.push_back(assembled.str()); - MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); - MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine()); - files.pop_front(); - files.pop_front(); - } - if (!files.empty()) { - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << "_merged"; - std::string merged_name(assembled.str()); - if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); - std::string context_name = files[0] + kContextSuffix; - merged_name += kContextSuffix; - if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); - } + return out_file.release(); } } // namespace -void RecordReader::Init(const std::string &name, std::size_t entry_size) { - file_.reset(OpenOrThrow(name.c_str(), "r+")); +void RecordReader::Init(FILE *file, std::size_t entry_size) { + rewind(file); + file_ = file; data_.reset(malloc(entry_size)); UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer"); remains_ = true; @@ -219,20 +159,29 @@ void RecordReader::Init(const std::string &name, std::size_t entry_size) { void RecordReader::Overwrite(const void *start, std::size_t amount) { long internal = (uint8_t*)start - (uint8_t*)data_.get(); - UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); - WriteOrThrow(file_.get(), start, amount); + UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); + WriteOrThrow(file_, start, amount); long forward = entry_size_ - internal - amount; - if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); +#if !defined(_WIN32) && !defined(_WIN64) + if (forward) +#endif + UTIL_THROW_IF(fseek(file_, forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); } -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { +void RecordReader::Rewind() { + rewind(file_); + remains_ = true; + ++*this; +} + +SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + util::TempMaker maker(file_prefix); PositiveProbWarn warn(config.positive_log_probability); + unigram_.reset(maker.Make()); { - std::string unigram_name = file_prefix + "unigrams"; - util::scoped_fd unigram_file; // In case appears. - size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); + size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_.get(), size_out), size_out); Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; @@ -246,16 +195,96 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector(buffer_use, static_cast((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); buffer = std::min(buffer, buffer_use); - util::scoped_memory mem; - mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); + util::scoped_malloc mem; + mem.reset(malloc(buffer)); if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); + ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer); } ReadEnd(f); } +namespace { +class Closer { + public: + explicit Closer(std::deque &files) : files_(files) {} + + ~Closer() { + for (std::deque::iterator i = files_.begin(); i != files_.end(); ++i) { + util::scoped_FILE deleter(*i); + } + } + + void PopFront() { + util::scoped_FILE deleter(files_.front()); + files_.pop_front(); + } + private: + std::deque &files_; +}; +} // namespace + +void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { + ReadNGramHeader(f, order); + const size_t count = counts[order - 1]; + // Size of weights. Does it include backoff? + const size_t words_size = sizeof(WordIndex) * order; + const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); + const size_t entry_size = words_size + weights_size; + const size_t batch_size = std::min(count, mem_size / entry_size); + uint8_t *const begin = reinterpret_cast(mem); + + std::deque files, contexts; + Closer files_closer(files), contexts_closer(contexts); + + for (std::size_t batch = 0, done = 0; done < count; ++batch) { + uint8_t *out = begin; + uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; + if (order == counts.size()) { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + } + } else { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + } + } + // Sort full records by full n-gram. + util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); + // parallel_sort uses too much RAM. TODO: figure out why windows sort doesn't like my proxies. +#if defined(_WIN32) || defined(_WIN64) + std::stable_sort +#else + std::sort +#endif + (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare(EntryCompare(order))); + files.push_back(DiskFlush(begin, out_end, maker)); + contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order)); + + done += (out_end - begin) / entry_size; + } + + // All individual files created. Merge them. + + while (files.size() > 1) { + files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine())); + files_closer.PopFront(); + files_closer.PopFront(); + contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine())); + contexts_closer.PopFront(); + contexts_closer.PopFront(); + } + + if (!files.empty()) { + // Steal from closers. + full_[order - 2].reset(files.front()); + files.pop_front(); + context_[order - 2].reset(contexts.front()); + contexts.pop_front(); + } +} + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index a6916483..3036319d 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -1,6 +1,9 @@ +// Step of trie builder: create sorted files. + #ifndef LM_TRIE_SORT__ #define LM_TRIE_SORT__ +#include "lm/max_order.hh" #include "lm/word_index.hh" #include "util/file.hh" @@ -11,20 +14,21 @@ #include #include -#include +#include -namespace util { class FilePiece; } +namespace util { +class FilePiece; +class TempMaker; +} // namespace util -// Step of trie builder: create sorted files. namespace lm { +class PositiveProbWarn; namespace ngram { class SortedVocabulary; class Config; namespace trie { -extern const char *kContextSuffix; -FILE *OpenOrThrow(const char *name, const char *mode); void WriteOrThrow(FILE *to, const void *data, size_t size); class EntryCompare : public std::binary_function { @@ -49,15 +53,15 @@ class RecordReader { public: RecordReader() : remains_(true) {} - void Init(const std::string &name, std::size_t entry_size); + void Init(FILE *file, std::size_t entry_size); void *Data() { return data_.get(); } const void *Data() const { return data_.get(); } RecordReader &operator++() { - std::size_t ret = fread(data_.get(), entry_size_, 1, file_.get()); + std::size_t ret = fread(data_.get(), entry_size_, 1, file_); if (!ret) { - UTIL_THROW_IF(!feof(file_.get()), util::ErrnoException, "Error reading temporary file"); + UTIL_THROW_IF(!feof(file_), util::ErrnoException, "Error reading temporary file"); remains_ = false; } return *this; @@ -65,27 +69,46 @@ class RecordReader { operator bool() const { return remains_; } - void Rewind() { - rewind(file_.get()); - remains_ = true; - ++*this; - } + void Rewind(); std::size_t EntrySize() const { return entry_size_; } void Overwrite(const void *start, std::size_t amount); private: + FILE *file_; + util::scoped_malloc data_; bool remains_; std::size_t entry_size_; - - util::scoped_FILE file_; }; -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab); +class SortedFiles { + public: + // Build from ARPA + SortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab); + + int StealUnigram() { + return unigram_.release(); + } + + FILE *Full(unsigned char order) { + return full_[order - 2].get(); + } + + FILE *Context(unsigned char of_order) { + return context_[of_order - 2].get(); + } + + private: + void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size); + + util::scoped_fd unigram_; + + util::scoped_FILE full_[kMaxOrder - 1], context_[kMaxOrder - 1]; +}; } // namespace trie } // namespace ngram diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index ffec41ca..9fd698bb 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -6,12 +6,15 @@ #include "lm/config.hh" #include "lm/weights.hh" #include "util/exception.hh" +#include "util/file.hh" #include "util/joint_sort.hh" #include "util/murmur_hash.hh" #include "util/probing_hash_table.hh" #include +#include + namespace lm { namespace ngram { @@ -29,23 +32,30 @@ const uint64_t kUnknownHash = detail::HashForVocab("", 5); // Sadly some LMs have . const uint64_t kUnknownCapHash = detail::HashForVocab("", 5); -WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { - if (!enumerate) return std::numeric_limits::max(); +void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) { + // Check that we're at the right place by reading which is always first. + char check_unk[6]; + util::ReadOrThrow(fd, check_unk, 6); + UTIL_THROW_IF( + memcmp(check_unk, "", 6), + FormatLoadException, + "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure."); + if (!enumerate) return; + enumerate->Add(0, ""); + + // Read all the words after unk. const std::size_t kInitialRead = 16384; std::string buf; buf.reserve(kInitialRead + 100); buf.resize(kInitialRead); - WordIndex index = 0; + WordIndex index = 1; // Read already. while (true) { - ssize_t got = read(fd, &buf[0], kInitialRead); - UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words"); - if (got == 0) return index; + std::size_t got = util::ReadOrEOF(fd, &buf[0], kInitialRead); + if (got == 0) break; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; - ssize_t ret = read(fd, &next_char, 1); - UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words"); - UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word."); + util::ReadOrThrow(fd, &next_char, 1); buf.push_back(next_char); } // Ok now we have null terminated strings. @@ -55,6 +65,8 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { i += length + 1 /* null byte */; } } + + UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file."); } } // namespace @@ -69,8 +81,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { } void WriteWordsWrapper::Write(int fd) { - if ((off_t)-1 == lseek(fd, 0, SEEK_END)) - UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words"); + util::SeekEnd(fd); util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); } @@ -114,8 +125,10 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { if (enumerate_) { - util::PairedIterator values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); - util::JointSort(begin_, end_, values); + if (!strings_to_enumerate_.empty()) { + util::PairedIterator values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::JointSort(begin_, end_, values); + } for (WordIndex i = 0; i < static_cast(end_ - begin_); ++i) { // strikes again: +1 here. enumerate_->Add(i + 1, strings_to_enumerate_[i]); @@ -131,11 +144,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { bound_ = end_ - begin_ + 1; } -void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { +void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { end_ = begin_ + *(reinterpret_cast(begin_) - 1); - ReadWords(fd, to); SetSpecial(Index(""), Index(""), 0); bound_ = end_ - begin_ + 1; + if (have_words) ReadWords(fd, to, bound_); } namespace { @@ -153,12 +166,12 @@ struct ProbingVocabularyHeader { ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { - return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); + return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { header_ = static_cast(start); - lookup_ = Lookup(static_cast(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated); + lookup_ = Lookup(static_cast(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated); bound_ = 1; saw_unk_ = false; } @@ -178,7 +191,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { return 0; } else { if (enumerate_) enumerate_->Add(bound_, str); - lookup_.Insert(Lookup::Packing::Make(hashed, bound_)); + lookup_.Insert(ProbingVocabuaryEntry::Make(hashed, bound_)); return bound_++; } } @@ -190,12 +203,12 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { SetSpecial(Index(""), Index(""), 0); } -void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { +void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); lookup_.LoadedBinary(); - ReadWords(fd, to); bound_ = header_->bound; SetSpecial(Index(""), Index(""), 0); + if (have_words) ReadWords(fd, to, bound_); } void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 3c3414fb..06fdefe4 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,7 +4,6 @@ #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" -#include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" #include "util/string_piece.hh" @@ -83,7 +82,7 @@ class SortedVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } - void LoadedBinary(int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); private: uint64_t *begin_, *end_; @@ -100,6 +99,26 @@ class SortedVocabulary : public base::Vocabulary { std::vector strings_to_enumerate_; }; +#pragma pack(push) +#pragma pack(4) +struct ProbingVocabuaryEntry { + uint64_t key; + WordIndex value; + + typedef uint64_t Key; + uint64_t GetKey() const { + return key; + } + + static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { + ProbingVocabuaryEntry ret; + ret.key = key; + ret.value = value; + return ret; + } +}; +#pragma pack(pop) + // Vocabulary storing a map from uint64_t to WordIndex. class ProbingVocabulary : public base::Vocabulary { public: @@ -107,7 +126,7 @@ class ProbingVocabulary : public base::Vocabulary { WordIndex Index(const StringPiece &str) const { Lookup::ConstIterator i; - return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0; + return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } static size_t Size(std::size_t entries, const Config &config); @@ -124,17 +143,14 @@ class ProbingVocabulary : public base::Vocabulary { void FinishedLoading(ProbBackoff *reorder_vocab); + std::size_t UnkCountChangePadding() const { return 0; } + bool SawUnk() const { return saw_unk_; } - void LoadedBinary(int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); private: - // std::identity is an SGI extension :-( - struct IdentityHash : public std::unary_function { - std::size_t operator()(uint64_t arg) const { return static_cast(arg); } - }; - - typedef util::ProbingHashTable, IdentityHash> Lookup; + typedef util::ProbingHashTable Lookup; Lookup lookup_; -- cgit v1.2.3