diff options
author | Patrick Simianer <p@simianer.de> | 2012-03-13 09:24:47 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2012-03-13 09:24:47 +0100 |
commit | c3a9ea64251605532c7954959662643a6a927bb7 (patch) | |
tree | fed6048a5acdaf3834740107771c2bc48f26fd4d /klm | |
parent | 867bca3e5fa0cdd63bf032e5859fb5092d9a4ca1 (diff) | |
parent | a45af4a3704531a8382cd231f6445b3a33b598a3 (diff) |
merge with upstream
Diffstat (limited to 'klm')
50 files changed, 1381 insertions, 1029 deletions
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc index bf86fd4b..cdeafb47 100644 --- a/klm/lm/bhiksha.cc +++ b/klm/lm/bhiksha.cc @@ -1,5 +1,6 @@ #include "lm/bhiksha.hh" #include "lm/config.hh" +#include "util/file.hh" #include <limits> @@ -12,12 +13,12 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_ const uint8_t kArrayBhikshaVersion = 0; +// TODO: put this in binary file header instead when I change the binary file format again. void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { uint8_t version; uint8_t configured_bits; - if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) { - UTIL_THROW(util::ErrnoException, "Could not read from binary file"); - } + util::ReadOrThrow(fd, &version, 1); + util::ReadOrThrow(fd, &configured_bits, 1); if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); config.pointer_bhiksha_bits = configured_bits; } diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 3df43dda..5182ee2e 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -13,7 +13,7 @@ #ifndef LM_BHIKSHA__ #define LM_BHIKSHA__ -#include <inttypes.h> +#include <stdint.h> #include <assert.h> #include "lm/model_type.hh" diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 27cada13..4796f6d1 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -1,19 +1,15 @@ #include "lm/binary_format.hh" #include "lm/lm_exception.hh" +#include "util/file.hh" #include "util/file_piece.hh" +#include <cstddef> +#include <cstring> #include <limits> #include <string> -#include <fcntl.h> -#include <errno.h> -#include <stdlib.h> -#include <string.h> -#include <sys/mman.h> -#include <sys/types.h> -#include <sys/stat.h> -#include <unistd.h> +#include <stdint.h> namespace lm { namespace ngram { @@ -24,14 +20,16 @@ const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; const long int kMagicVersion = 5; -// Test values. -struct Sanity { +// Old binary files built on 32-bit machines have this header. +// TODO: eliminate with next binary release. +struct OldSanity { char magic[sizeof(kMagicBytes)]; float zero_f, one_f, minus_half_f; WordIndex one_word_index, max_word_index; uint64_t one_uint64; void SetToReference() { + std::memset(this, 0, sizeof(OldSanity)); std::memcpy(magic, kMagicBytes, sizeof(magic)); zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; one_word_index = 1; @@ -40,27 +38,35 @@ struct Sanity { } }; -const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; -std::size_t TotalHeaderSize(unsigned char order) { - return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); -} +// Test values aligned to 8 bytes. +struct Sanity { + char magic[ALIGN8(sizeof(kMagicBytes))]; + float zero_f, one_f, minus_half_f; + WordIndex one_word_index, max_word_index, padding_to_8; + uint64_t one_uint64; -void ReadLoop(int fd, void *to_void, std::size_t size) { - uint8_t *to = static_cast<uint8_t*>(to_void); - while (size) { - ssize_t ret = read(fd, to, size); - if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file"); - if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short"); - to += ret; - size -= ret; + void SetToReference() { + std::memset(this, 0, sizeof(Sanity)); + std::memcpy(magic, kMagicBytes, sizeof(kMagicBytes)); + zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; + one_word_index = 1; + max_word_index = std::numeric_limits<WordIndex>::max(); + padding_to_8 = 0; + one_uint64 = 1; } +}; + +const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; + +std::size_t TotalHeaderSize(unsigned char order) { + return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } void WriteHeader(void *to, const Parameters ¶ms) { Sanity header = Sanity(); header.SetToReference(); - memcpy(to, &header, sizeof(Sanity)); + std::memcpy(to, &header, sizeof(Sanity)); char *out = reinterpret_cast<char*>(to) + sizeof(Sanity); *reinterpret_cast<FixedWidthParameters*>(out) = params.fixed; @@ -74,14 +80,6 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace -void SeekOrThrow(int fd, off_t off) { - if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed"); -} - -void AdvanceOrThrow(int fd, off_t off) { - if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed"); -} - uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; @@ -89,7 +87,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); } else { - backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + util::MapAnonymous(memory_size, backing.vocab); return reinterpret_cast<uint8_t*>(backing.vocab.get()); } } @@ -98,42 +96,58 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { // Grow the file to accomodate the search, using zeros. - if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size)) - UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed"); + try { + util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); + } catch (util::ErrnoException &e) { + e << " for file " << config.write_mmap; + throw e; + } + if (config.write_method == Config::WRITE_AFTER) { + util::MapAnonymous(memory_size, backing.search); + return reinterpret_cast<uint8_t*>(backing.search.get()); + } + // mmap it now. // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. - off_t page_size = sysconf(_SC_PAGE_SIZE); - off_t alignment_cruft = adjusted_vocab % page_size; + std::size_t page_size = util::SizePage(); + std::size_t alignment_cruft = adjusted_vocab % page_size; backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); - return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft; } else { - backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + util::MapAnonymous(memory_size, backing.search); return reinterpret_cast<uint8_t*>(backing.search.get()); } } -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing) { - if (config.write_mmap) { - if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) - UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); - // header and vocab share the same mmap. The header is written here because we know the counts. - Parameters params; - params.counts = counts; - params.fixed.order = counts.size(); - params.fixed.probing_multiplier = config.probing_multiplier; - params.fixed.model_type = model_type; - params.fixed.has_vocabulary = config.include_vocab; - params.fixed.search_version = search_version; - WriteHeader(backing.vocab.get(), params); +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { + if (!config.write_mmap) return; + util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); + switch (config.write_method) { + case Config::WRITE_MMAP: + util::SyncOrThrow(backing.search.get(), backing.search.size()); + break; + case Config::WRITE_AFTER: + util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); + util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); + util::FSyncOrThrow(backing.file.get()); + break; } + // header and vocab share the same mmap. The header is written here because we know the counts. + Parameters params = Parameters(); + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + params.fixed.search_version = search_version; + WriteHeader(backing.vocab.get(), params); } namespace detail { bool IsBinaryFormat(int fd) { - const off_t size = util::SizeFile(fd); - if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(Sanity)))) return false; + const uint64_t size = util::SizeFile(fd); + if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; // Try reading the header. util::scoped_memory memory; try { @@ -154,19 +168,23 @@ bool IsBinaryFormat(int fd) { if ((end_ptr != begin_version) && version != kMagicVersion) { UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary"); } + + OldSanity old_sanity = OldSanity(); + old_sanity.SetToReference(); + UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable."); UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture"); } return false; } void ReadHeader(int fd, Parameters &out) { - SeekOrThrow(fd, sizeof(Sanity)); - ReadLoop(fd, &out.fixed, sizeof(out.fixed)); + util::SeekOrThrow(fd, sizeof(Sanity)); + util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed)); if (out.fixed.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); out.counts.resize(static_cast<std::size_t>(out.fixed.order)); - ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); + if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); } void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms) { @@ -179,11 +197,11 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet } void SeekPastHeader(int fd, const Parameters ¶ms) { - SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); + util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); } uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { - const off_t file_size = util::SizeFile(backing.file.get()); + const uint64_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) @@ -194,9 +212,8 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t if (config.enumerate_vocab && !params.fixed.has_vocabulary) UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); - if (config.enumerate_vocab) { - SeekOrThrow(backing.file.get(), total_map); - } + // Seek to vocabulary words + util::SeekOrThrow(backing.file.get(), total_map); return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size()); } diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index e9df0892..dd795f62 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -12,7 +12,7 @@ #include <cstddef> #include <vector> -#include <inttypes.h> +#include <stdint.h> namespace lm { namespace ngram { @@ -33,10 +33,8 @@ struct FixedWidthParameters { unsigned int search_version; }; -inline std::size_t Align8(std::size_t in) { - std::size_t off = in % 8; - return off ? (in + 8 - off) : in; -} +// This is a macro instead of an inline function so constants can be assigned using it. +#define ALIGN8(a) ((std::ptrdiff_t(((a)-1)/8)+1)*8) // Parameters stored in the header of a binary file. struct Parameters { @@ -53,10 +51,6 @@ struct Backing { util::scoped_memory search; }; -void SeekOrThrow(int fd, off_t off); -// Seek forward -void AdvanceOrThrow(int fd, off_t off); - // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. @@ -64,7 +58,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t // Write header to binary file. This is done last to prevent incomplete files // from loading. -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing); +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing); namespace detail { diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 2fb64cd0..4da81209 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -3,7 +3,7 @@ #include <limits> -#include <inttypes.h> +#include <stdint.h> #include <math.h> namespace lm { diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index fdb62a71..8cbb69d0 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -8,18 +8,24 @@ #include <math.h> #include <stdlib.h> -#include <unistd.h> + +#ifdef WIN32 +#include "util/getopt.hh" +#endif namespace lm { namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" "-u sets the log10 probability for <unk> if the ARPA file does not have one.\n" " Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have <s> and </s>.\n" -"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" +"-w mmap|after determines how writing is done.\n" +" mmap maps the binary file and writes to it. Default for trie.\n" +" after allocates anonymous memory, builds, and writes. Default for probing.\n\n" "type is either probing or trie. Default is probing.\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" @@ -55,7 +61,7 @@ uint8_t ParseBitCount(const char *from) { unsigned long val = ParseUInt(from); if (val > 25) { util::ParseNumberException e(from); - e << " bit counts are limited to 256."; + e << " bit counts are limited to 25."; } return val; } @@ -87,7 +93,7 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { prefix = 'G'; divide = 1 << 30; } - long int length = std::max<long int>(2, lrint(ceil(log10(max_length / divide)))); + long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide)))); std::cout << "Memory estimate:\ntype "; // right align bytes. for (long int i = 0; i < length - 2; ++i) std::cout << ' '; @@ -112,10 +118,10 @@ int main(int argc, char *argv[]) { using namespace lm::ngram; try { - bool quantize = false, set_backoff_bits = false, bhiksha = false; + bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false; lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) { + while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:si")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -129,6 +135,7 @@ int main(int argc, char *argv[]) { case 'a': config.pointer_bhiksha_bits = ParseBitCount(optarg); bhiksha = true; + break; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -141,6 +148,16 @@ int main(int argc, char *argv[]) { case 'm': config.building_memory = ParseUInt(optarg) * 1048576; break; + case 'w': + set_write_method = true; + if (!strcmp(optarg, "mmap")) { + config.write_method = Config::WRITE_MMAP; + } else if (!strcmp(optarg, "after")) { + config.write_method = Config::WRITE_AFTER; + } else { + Usage(argv[0]); + } + break; case 's': config.sentence_marker_missing = lm::SILENT; break; @@ -166,9 +183,11 @@ int main(int argc, char *argv[]) { const char *from_file = argv[optind + 1]; config.write_mmap = argv[optind + 2]; if (!strcmp(model_type, "probing")) { + if (!set_write_method) config.write_method = Config::WRITE_AFTER; if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { + if (!set_write_method) config.write_method = Config::WRITE_MMAP; if (quantize) { if (bhiksha) { QuantArrayTrieModel(from_file, config); @@ -191,7 +210,9 @@ int main(int argc, char *argv[]) { } catch (const std::exception &e) { std::cerr << e.what() << std::endl; + std::cerr << "ERROR" << std::endl; return 1; } + std::cerr << "SUCCESS" << std::endl; return 0; } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 297589a4..dbe762b3 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -17,6 +17,7 @@ Config::Config() : temporary_directory_prefix(NULL), arpa_complain(ALL), write_mmap(NULL), + write_method(WRITE_AFTER), include_vocab(true), prob_bits(8), backoff_bits(8), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 8564661b..01b75632 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -70,9 +70,17 @@ struct Config { // to NULL to disable. const char *write_mmap; + typedef enum { + WRITE_MMAP, // Map the file directly. + WRITE_AFTER // Write after we're done. + } WriteMethod; + WriteMethod write_method; + // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; + + // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used // to quantize (and one of the remaining backoffs will be 0). diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 41f71f84..a07f9803 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -112,7 +112,7 @@ inline size_t hash_value(const ChartState &state) { size_t hashes[2]; hashes[0] = hash_value(state.left); hashes[1] = hash_value(state.right); - return util::MurmurHashNative(hashes, sizeof(size_t), state.full); + return util::MurmurHashNative(hashes, sizeof(size_t) * 2, state.full); } template <class M> class RuleScore { diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc index 8bb91cb3..c85e5efa 100644 --- a/klm/lm/left_test.cc +++ b/klm/lm/left_test.cc @@ -142,7 +142,7 @@ template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &wo template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vector<WordIndex> &out) { out.clear(); - for (util::PieceIterator<' '> i(str); i; ++i) { + for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) { out.push_back(m.GetVocabulary().Index(*i)); } } @@ -326,10 +326,17 @@ template <class M> void FullGrow(const M &m) { } } +const char *FileLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 2) { + return "test.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[1]; +} + template <class M> void Everything() { Config config; config.messages = NULL; - M m("test.arpa", config); + M m(FileLocation(), config); Short(m); Charge(m); diff --git a/klm/lm/model.cc b/klm/lm/model.cc index e4c1ec1d..478ebed1 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -46,7 +46,7 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { SetupMemory(start, params.counts, config); - vocab_.LoadedBinary(fd, config.enumerate_vocab); + vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); search_.LoadedBinary(); } @@ -82,13 +82,18 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.unigram.Unknown().backoff = 0.0; search_.unigram.Unknown().prob = config.unknown_missing_logprob; } - FinishFile(config, kModelType, kVersion, counts, backing_); + FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_); } catch (util::Exception &e) { e << " Byte: " << f.Offset(); throw; } } +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { + util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); + Search::UpdateConfigFromBinary(fd, counts, config); +} + template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) { @@ -114,7 +119,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, } float backoff; // i is the order of the backoff we're looking for. - const Middle *mid_iter = search_.MiddleBegin() + start - 2; + typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2; for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) { if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break; ret.prob += backoff; @@ -134,7 +139,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored); out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; float *backoff_out = out_state.backoff + 1; - const typename Search::Middle *mid = search_.MiddleBegin(); + typename Search::MiddleIter mid(search_.MiddleBegin()); for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) { if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) { std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words); @@ -161,7 +166,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, // If this function is called, then it does depend on left words. ret.independent_left = false; ret.extend_left = extend_pointer; - const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1; + typename Search::MiddleIter mid_iter(search_.MiddleBegin() + extend_length - 1); const WordIndex *i = add_rbegin; for (; ; ++i, ++backoff_out, ++mid_iter) { if (i == add_rend) { @@ -230,7 +235,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, // Ok start by looking up the bigram. const WordIndex *hist_iter = context_rbegin; - const typename Search::Middle *mid_iter = search_.MiddleBegin(); + typename Search::MiddleIter mid_iter(search_.MiddleBegin()); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { if (hist_iter == context_rend) { // Ran out of history. Typically no backoff, but this could be a blank. diff --git a/klm/lm/model.hh b/klm/lm/model.hh index c278acd6..6ea62a78 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -90,7 +90,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod * TrieModel. To classify binary files, call RecognizeBinary in * lm/binary_format.hh. */ - GenericModel(const char *file, const Config &config = Config()); + explicit GenericModel(const char *file, const Config &config = Config()); /* Score p(new_word | in_state) and incorporate new_word into out_state. * Note that in_state and out_state must be different references: @@ -137,14 +137,9 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod unsigned char &next_use) const; private: - friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); + friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); - static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { - AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); - Search::UpdateConfigFromBinary(fd, counts, config); - } - - float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const; + static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 2654071f..461704d4 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -19,6 +19,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) { namespace { +const char *TestLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 2) { + return "test.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[1]; +} +const char *TestNoUnkLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 3) { + return "test_nounk.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[2]; + +} + #define StartTest(word, ngram, score, indep_left) \ ret = model.FullScore( \ state, \ @@ -307,7 +321,7 @@ template <class ModelT> void LoadingTest() { { ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; - ModelT m("test.arpa", config); + ModelT m(TestLocation(), config); enumerate.Check(m.GetVocabulary()); BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); Everything(m); @@ -315,7 +329,7 @@ template <class ModelT> void LoadingTest() { { ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; - ModelT m("test_nounk.arpa", config); + ModelT m(TestNoUnkLocation(), config); enumerate.Check(m.GetVocabulary()); BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); NoUnkCheck(m); @@ -346,7 +360,7 @@ template <class ModelT> void BinaryTest() { config.enumerate_vocab = &enumerate; { - ModelT copy_model("test.arpa", config); + ModelT copy_model(TestLocation(), config); enumerate.Check(copy_model.GetVocabulary()); enumerate.Clear(); Everything(copy_model); @@ -370,14 +384,14 @@ template <class ModelT> void BinaryTest() { config.messages = NULL; enumerate.Clear(); { - ModelT copy_model("test_nounk.arpa", config); + ModelT copy_model(TestNoUnkLocation(), config); enumerate.Check(copy_model.GetVocabulary()); enumerate.Clear(); NoUnkCheck(copy_model); } config.write_mmap = NULL; { - ModelT binary("test_nounk.binary", config); + ModelT binary(TestNoUnkLocation(), config); enumerate.Check(binary.GetVocabulary()); NoUnkCheck(binary); } diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index d9db4aa2..8f7a0e1c 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -1,87 +1,4 @@ -#include "lm/enumerate_vocab.hh" -#include "lm/model.hh" - -#include <cstdlib> -#include <fstream> -#include <iostream> -#include <string> - -#include <ctype.h> - -#include <sys/resource.h> -#include <sys/time.h> - -float FloatSec(const struct timeval &tv) { - return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000000.0); -} - -void PrintUsage(const char *message) { - struct rusage usage; - if (getrusage(RUSAGE_SELF, &usage)) { - perror("getrusage"); - return; - } - std::cerr << message; - std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n'; - - // Linux doesn't set memory usage :-(. - std::ifstream status("/proc/self/status", std::ios::in); - std::string line; - while (getline(status, line)) { - if (!strncmp(line.c_str(), "VmRSS:\t", 7)) { - std::cerr << "rss " << (line.c_str() + 7) << '\n'; - break; - } - } -} - -template <class Model> void Query(const Model &model, bool sentence_context) { - PrintUsage("Loading statistics:\n"); - typename Model::State state, out; - lm::FullScoreReturn ret; - std::string word; - - while (std::cin) { - state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); - float total = 0.0; - bool got = false; - unsigned int oov = 0; - while (std::cin >> word) { - got = true; - lm::WordIndex vocab = model.GetVocabulary().Index(word); - if (vocab == 0) ++oov; - ret = model.FullScore(state, vocab, out); - total += ret.prob; - std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; - state = out; - char c; - while (true) { - c = std::cin.get(); - if (!std::cin) break; - if (c == '\n') break; - if (!isspace(c)) { - std::cin.unget(); - break; - } - } - if (c == '\n') break; - } - if (!got && !std::cin) break; - if (sentence_context) { - ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); - total += ret.prob; - std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; - } - std::cout << "Total: " << total << " OOV: " << oov << '\n'; - } - PrintUsage("After queries:\n"); -} - -template <class Model> void Query(const char *name) { - lm::ngram::Config config; - Model model(name, config); - Query(model); -} +#include "lm/ngram_query.hh" int main(int argc, char *argv[]) { if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) { @@ -89,34 +6,40 @@ int main(int argc, char *argv[]) { std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl; return 1; } - bool sentence_context = (argc == 2); - lm::ngram::ModelType model_type; - if (lm::ngram::RecognizeBinary(argv[1], model_type)) { - switch(model_type) { - case lm::ngram::HASH_PROBING: - Query<lm::ngram::ProbingModel>(argv[1], sentence_context); - break; - case lm::ngram::TRIE_SORTED: - Query<lm::ngram::TrieModel>(argv[1], sentence_context); - break; - case lm::ngram::QUANT_TRIE_SORTED: - Query<lm::ngram::QuantTrieModel>(argv[1], sentence_context); - break; - case lm::ngram::ARRAY_TRIE_SORTED: - Query<lm::ngram::ArrayTrieModel>(argv[1], sentence_context); - break; - case lm::ngram::QUANT_ARRAY_TRIE_SORTED: - Query<lm::ngram::QuantArrayTrieModel>(argv[1], sentence_context); - break; - case lm::ngram::HASH_SORTED: - default: - std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; - abort(); + try { + bool sentence_context = (argc == 2); + using namespace lm::ngram; + ModelType model_type; + if (RecognizeBinary(argv[1], model_type)) { + switch(model_type) { + case HASH_PROBING: + Query<lm::ngram::ProbingModel>(argv[1], sentence_context, std::cin, std::cout); + break; + case TRIE_SORTED: + Query<TrieModel>(argv[1], sentence_context, std::cin, std::cout); + break; + case QUANT_TRIE_SORTED: + Query<QuantTrieModel>(argv[1], sentence_context, std::cin, std::cout); + break; + case ARRAY_TRIE_SORTED: + Query<ArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout); + break; + case QUANT_ARRAY_TRIE_SORTED: + Query<QuantArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout); + break; + case HASH_SORTED: + default: + std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; + abort(); + } + } else { + Query<ProbingModel>(argv[1], sentence_context, std::cin, std::cout); } - } else { - Query<lm::ngram::ProbingModel>(argv[1], sentence_context); - } - PrintUsage("Total time including destruction:\n"); + PrintUsage("Total time including destruction:\n"); + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; + } return 0; } diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh new file mode 100644 index 00000000..4990df22 --- /dev/null +++ b/klm/lm/ngram_query.hh @@ -0,0 +1,103 @@ +#ifndef LM_NGRAM_QUERY__ +#define LM_NGRAM_QUERY__ + +#include "lm/enumerate_vocab.hh" +#include "lm/model.hh" + +#include <cstdlib> +#include <fstream> +#include <iostream> +#include <string> + +#include <ctype.h> +#if !defined(_WIN32) && !defined(_WIN64) +#include <sys/resource.h> +#include <sys/time.h> +#endif + +namespace lm { +namespace ngram { + +#if !defined(_WIN32) && !defined(_WIN64) +float FloatSec(const struct timeval &tv) { + return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000000.0); +} +#endif + +void PrintUsage(const char *message) { +#if !defined(_WIN32) && !defined(_WIN64) + struct rusage usage; + if (getrusage(RUSAGE_SELF, &usage)) { + perror("getrusage"); + return; + } + std::cerr << message; + std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n'; + + // Linux doesn't set memory usage :-(. + std::ifstream status("/proc/self/status", std::ios::in); + std::string line; + while (getline(status, line)) { + if (!strncmp(line.c_str(), "VmRSS:\t", 7)) { + std::cerr << "rss " << (line.c_str() + 7) << '\n'; + break; + } + } +#endif +} + +template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { + PrintUsage("Loading statistics:\n"); + typename Model::State state, out; + lm::FullScoreReturn ret; + std::string word; + + while (in_stream) { + state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); + float total = 0.0; + bool got = false; + unsigned int oov = 0; + while (in_stream >> word) { + got = true; + lm::WordIndex vocab = model.GetVocabulary().Index(word); + if (vocab == 0) ++oov; + ret = model.FullScore(state, vocab, out); + total += ret.prob; + out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; + state = out; + char c; + while (true) { + c = in_stream.get(); + if (!in_stream) break; + if (c == '\n') break; + if (!isspace(c)) { + in_stream.unget(); + break; + } + } + if (c == '\n') break; + } + if (!got && !in_stream) break; + if (sentence_context) { + ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); + total += ret.prob; + out_stream << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; + } + out_stream << "Total: " << total << " OOV: " << oov << '\n'; + } + PrintUsage("After queries:\n"); +} + +template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { + Config config; +// config.load_method = util::LAZY; + M model(file, config); + Query(model, sentence_context, in_stream, out_stream); +} + +} // namespace ngram +} // namespace lm + +#endif // LM_NGRAM_QUERY__ + + diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index 98a5d048..a8e0cb21 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -1,31 +1,30 @@ +/* Quantize into bins of equal size as described in + * M. Federico and N. Bertoldi. 2006. How many bits are needed + * to store probabilities for phrase-based translation? In Proc. + * of the Workshop on Statistical Machine Translation, pages + * 94–101, New York City, June. Association for Computa- + * tional Linguistics. + */ + #include "lm/quantize.hh" #include "lm/binary_format.hh" #include "lm/lm_exception.hh" +#include "util/file.hh" #include <algorithm> #include <numeric> -#include <unistd.h> - namespace lm { namespace ngram { -/* Quantize into bins of equal size as described in - * M. Federico and N. Bertoldi. 2006. How many bits are needed - * to store probabilities for phrase-based translation? In Proc. - * of the Workshop on Statistical Machine Translation, pages - * 94–101, New York City, June. Association for Computa- - * tional Linguistics. - */ - namespace { -void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) { - std::sort(values, values_end); - const float *start = values, *finish; +void MakeBins(std::vector<float> &values, float *centers, uint32_t bins) { + std::sort(values.begin(), values.end()); + std::vector<float>::const_iterator start = values.begin(), finish; for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) { - finish = values + (((values_end - values) * static_cast<uint64_t>(i + 1)) / bins); + finish = values.begin() + ((values.size() * static_cast<uint64_t>(i + 1)) / bins); if (finish == start) { // zero length bucket. *centers = i ? *(centers - 1) : -std::numeric_limits<float>::infinity(); @@ -41,10 +40,11 @@ const char kSeparatelyQuantizeVersion = 2; void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) { char version; - if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1) - UTIL_THROW(util::ErrnoException, "Failed to read header for quantization."); + util::ReadOrThrow(fd, &version, 1); + util::ReadOrThrow(fd, &config.prob_bits, 1); + util::ReadOrThrow(fd, &config.backoff_bits, 1); if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); - AdvanceOrThrow(fd, -3); + util::AdvanceOrThrow(fd, -3); } void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { @@ -66,12 +66,12 @@ void SeparatelyQuantize::Train(uint8_t order, std::vector<float> &prob, std::vec float *centers = start_ + TableStart(order) + ProbTableLength(); *(centers++) = kNoExtensionBackoff; *(centers++) = kExtensionBackoff; - MakeBins(&*backoff.begin(), &*backoff.end(), centers, (1ULL << backoff_bits_) - 2); + MakeBins(backoff, centers, (1ULL << backoff_bits_) - 2); } void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) { float *centers = start_ + TableStart(order); - MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_)); + MakeBins(prob, centers, (1ULL << prob_bits_)); } void SeparatelyQuantize::FinishedLoading(const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 4cf4236e..6d130a57 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -9,7 +9,7 @@ #include <algorithm> #include <vector> -#include <inttypes.h> +#include <stdint.h> #include <iostream> diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index dce73f77..05f761be 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -8,7 +8,7 @@ #include <ctype.h> #include <string.h> -#include <inttypes.h> +#include <stdint.h> namespace lm { diff --git a/klm/lm/return.hh b/klm/lm/return.hh index 15571960..1b55091b 100644 --- a/klm/lm/return.hh +++ b/klm/lm/return.hh @@ -1,7 +1,7 @@ #ifndef LM_RETURN__ #define LM_RETURN__ -#include <inttypes.h> +#include <stdint.h> namespace lm { /* Structure returned by scoring routines. */ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 247832b0..1d6fb5be 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -30,7 +30,7 @@ template <class Middle> class ActivateLowerMiddle { // TODO: somehow get text of n-gram for this error message. if (!modify_.UnsafeMutableFind(hash, i)) UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram"); - SetExtension(i->MutableValue().backoff); + SetExtension(i->value.backoff); } private: @@ -65,7 +65,7 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign blank.prob -= unigrams[vocab_ids[1]].backoff; SetExtension(unigrams[vocab_ids[1]].backoff); // Bigram including a unigram's backoff - middle[0].Insert(Middle::Packing::Make(keys[0], blank)); + middle[0].Insert(detail::ProbBackoffEntry::Make(keys[0], blank)); fix = 1; } else { for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); @@ -74,22 +74,24 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign for (; fix <= n - 3; ++fix) { typename Middle::MutableIterator gotit; if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) { - float &backoff = gotit->MutableValue().backoff; + float &backoff = gotit->value.backoff; SetExtension(backoff); blank.prob -= backoff; } - middle[fix].Insert(Middle::Packing::Make(keys[fix], blank)); + middle[fix].Insert(detail::ProbBackoffEntry::Make(keys[fix], blank)); backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]); } } template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) { + assert(n >= 2); ReadNGramHeader(f, n); - // vocab ids of words in reverse order + // Both vocab_ids and keys are non-empty because n >= 2. + // vocab ids of words in reverse order. std::vector<WordIndex> vocab_ids(n); std::vector<uint64_t> keys(n-1); - typename Store::Packing::Value value; + typename Store::Entry::Value value; typename Middle::MutableIterator found; for (size_t i = 0; i < count; ++i) { ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn); @@ -100,7 +102,7 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams( } // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. util::SetSign(value.prob); - store.Insert(Store::Packing::Make(keys[n-2], value)); + store.Insert(Store::Entry::Make(keys[n-2], value)); // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. int lower; util::FloatEnc fix_prob; @@ -113,9 +115,9 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams( } if (middle[lower].UnsafeMutableFind(keys[lower], found)) { // Turn off sign bit to indicate that it extends left. - fix_prob.f = found->MutableValue().prob; + fix_prob.f = found->value.prob; fix_prob.i &= ~util::kSignBit; - found->MutableValue().prob = fix_prob.f; + found->value.prob = fix_prob.f; // We don't need to recurse further down because this entry already set the bits for lower entries. break; } @@ -147,7 +149,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT, template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. - SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config); + SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config); PositiveProbWarn warn(config.positive_log_probability); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index e289fd11..4352c72d 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -8,7 +8,6 @@ #include "lm/weights.hh" #include "util/bit_packing.hh" -#include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include <algorithm> @@ -92,8 +91,10 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing); - const Middle *MiddleBegin() const { return &*middle_.begin(); } - const Middle *MiddleEnd() const { return &*middle_.end(); } + typedef typename std::vector<Middle>::const_iterator MiddleIter; + + MiddleIter MiddleBegin() const { return middle_.begin(); } + MiddleIter MiddleEnd() const { return middle_.end(); } Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { util::FloatEnc val; @@ -105,7 +106,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; abort(); } - val.f = found->GetValue().prob; + val.f = found->value.prob; } val.i |= util::kSignBit; prob = val.f; @@ -117,12 +118,12 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; util::FloatEnc enc; - enc.f = found->GetValue().prob; + enc.f = found->value.prob; ret.independent_left = (enc.i & util::kSignBit); ret.extend_left = node; enc.i |= util::kSignBit; ret.prob = enc.f; - backoff = found->GetValue().backoff; + backoff = found->value.backoff; return true; } @@ -132,7 +133,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; - backoff = found->GetValue().backoff; + backoff = found->value.backoff; return true; } @@ -141,7 +142,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has node = CombineWordHash(node, word); typename Longest::ConstIterator found; if (!longest.Find(node, found)) return false; - prob = found->GetValue().prob; + prob = found->value.prob; return true; } @@ -160,14 +161,50 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has std::vector<Middle> middle_; }; -// std::identity is an SGI extension :-( -struct IdentityHash : public std::unary_function<uint64_t, size_t> { - size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); } +/* These look like perfect candidates for a template, right? Ancient gcc (4.1 + * on RedHat stale linux) doesn't pack templates correctly. ProbBackoffEntry + * is a multiple of 8 bytes anyway. ProbEntry is 12 bytes so it's set to pack. + */ +struct ProbBackoffEntry { + uint64_t key; + ProbBackoff value; + typedef uint64_t Key; + typedef ProbBackoff Value; + uint64_t GetKey() const { + return key; + } + static ProbBackoffEntry Make(uint64_t key, ProbBackoff value) { + ProbBackoffEntry ret; + ret.key = key; + ret.value = value; + return ret; + } }; +#pragma pack(push) +#pragma pack(4) +struct ProbEntry { + uint64_t key; + Prob value; + typedef uint64_t Key; + typedef Prob Value; + uint64_t GetKey() const { + return key; + } + static ProbEntry Make(uint64_t key, Prob value) { + ProbEntry ret; + ret.key = key; + ret.value = value; + return ret; + } +}; + +#pragma pack(pop) + + struct ProbingHashedSearch : public TemplateHashedSearch< - util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>, - util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > { + util::ProbingHashTable<ProbBackoffEntry, util::IdentityHash>, + util::ProbingHashTable<ProbEntry, util::IdentityHash> > { static const ModelType kModelType = HASH_PROBING; }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 4bd3f4ee..ffadfa94 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -13,6 +13,7 @@ #include "lm/weights.hh" #include "lm/word_index.hh" #include "util/ersatz_progress.hh" +#include "util/mmap.hh" #include "util/proxy_iterator.hh" #include "util/scoped.hh" #include "util/sized_iterator.hh" @@ -20,14 +21,15 @@ #include <algorithm> #include <cstring> #include <cstdio> +#include <cstdlib> #include <queue> #include <limits> #include <numeric> #include <vector> -#include <sys/mman.h> -#include <sys/types.h> -#include <sys/stat.h> +#if defined(_WIN32) || defined(_WIN64) +#include <windows.h> +#endif namespace lm { namespace ngram { @@ -195,7 +197,7 @@ class SRISucks { void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) { for (unsigned char i = 0; i < kMaxOrder - 1; ++i) { - it_[i] = &*values_[i].begin(); + it_[i] = values_[i].empty() ? NULL : &*values_[i].begin(); } messages_[0].Apply(it_, unigram_file); BackoffMessages *messages = messages_ + 1; @@ -227,8 +229,8 @@ class SRISucks { class FindBlanks { public: - FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) - : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {} + FindBlanks(unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) + : counts_(order), unigrams_(unigrams), sri_(messages) {} float UnigramProb(WordIndex index) const { return unigrams_[index].prob; @@ -248,7 +250,7 @@ class FindBlanks { } void Longest(const void * /*data*/) { - ++*longest_counts_; + ++counts_.back(); } // Unigrams wrote one past. @@ -256,8 +258,12 @@ class FindBlanks { --counts_[0]; } + const std::vector<uint64_t> &Counts() const { + return counts_; + } + private: - uint64_t *const counts_, *const longest_counts_; + std::vector<uint64_t> counts_; const ProbBackoff *unigrams_; @@ -375,7 +381,7 @@ template <class Doing> class BlankManager { template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { util::ErsatzProgress progress(progress_out, message, unigram_count + 1); - unsigned int unigram = 0; + WordIndex unigram = 0; std::priority_queue<Gram> grams; grams.push(Gram(&unigram, 1)); for (unsigned char i = 2; i <= total_order; ++i) { @@ -461,42 +467,33 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c } // namespace -template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { RecordReader inputs[kMaxOrder - 1]; RecordReader contexts[kMaxOrder - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(i) << "_merged"; - inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); - util::RemoveOrThrow(assembled.str().c_str()); - assembled << kContextSuffix; - contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex)); - util::RemoveOrThrow(assembled.str().c_str()); + inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); + contexts[i-2].Init(files.Context(i), (i-1) * sizeof(WordIndex)); } SRISucks sri; - std::vector<uint64_t> fixed_counts(counts.size()); + std::vector<uint64_t> fixed_counts; + util::scoped_FILE unigram_file; + util::scoped_fd unigram_fd(files.StealUnigram()); { - std::string temp(file_prefix); temp += "unigrams"; - util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str())); util::scoped_memory unigrams; - MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); - FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); + MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); + FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); + fixed_counts = finder.Counts(); } + unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) { if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - util::scoped_FILE unigram_file; - { - std::string name(file_prefix + "unigrams"); - unigram_file.reset(OpenOrThrow(name.c_str(), "r+")); - util::RemoveOrThrow(name.c_str()); - } sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); @@ -587,42 +584,19 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBin longest.LoadedBinary(); } -namespace { -bool IsDirectory(const char *path) { - struct stat info; - if (0 != stat(path, &info)) return false; - return S_ISDIR(info.st_mode); -} -} // namespace - template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { - std::string temporary_directory; + std::string temporary_prefix; if (config.temporary_directory_prefix) { - temporary_directory = config.temporary_directory_prefix; - if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str())) - temporary_directory += '/'; + temporary_prefix = config.temporary_directory_prefix; } else if (config.write_mmap) { - temporary_directory = config.write_mmap; + temporary_prefix = config.write_mmap; } else { - temporary_directory = file; - } - // Null on end is kludge to ensure null termination. - temporary_directory += "_trie_tmp_XXXXXX"; - temporary_directory += '\0'; - if (!mkdtemp(&temporary_directory[0])) { - UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str()); + temporary_prefix = file; } - // Chop off null kludge. - temporary_directory.resize(strlen(temporary_directory.c_str())); - // Add directory delimiter. Assumes a real operating system. - temporary_directory += '/'; // At least 1MB sorting memory. - ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); + SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab); - BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing); - if (rmdir(temporary_directory.c_str()) && config.messages) { - *config.messages << "Failed to delete " << temporary_directory << std::endl; - } + BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); } template class TrieSearch<DontQuantize, DontBhiksha>; diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 33ae8cff..5155ca02 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -7,6 +7,7 @@ #include "lm/trie.hh" #include "lm/weights.hh" +#include "util/file.hh" #include "util/file_piece.hh" #include <vector> @@ -20,7 +21,8 @@ class SortedVocabulary; namespace trie { template <class Quant, class Bhiksha> class TrieSearch; -template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); +class SortedFiles; +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); template <class Quant, class Bhiksha> class TrieSearch { public: @@ -40,7 +42,7 @@ template <class Quant, class Bhiksha> class TrieSearch { static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); - AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); + util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); Bhiksha::UpdateConfigFromBinary(fd, config); } @@ -60,6 +62,8 @@ template <class Quant, class Bhiksha> class TrieSearch { void LoadedBinary(); + typedef const Middle *MiddleIter; + const Middle *MiddleBegin() const { return middle_begin_; } const Middle *MiddleEnd() const { return middle_end_; } @@ -108,7 +112,7 @@ template <class Quant, class Bhiksha> class TrieSearch { } private: - friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); + friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); // Middles are managed manually so we can delay construction and they don't have to be copyable. void FreeMiddles() { diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 06cc96ac..ebe9910f 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -1,7 +1,7 @@ #ifndef LM_TRIE__ #define LM_TRIE__ -#include <inttypes.h> +#include <stdint.h> #include <cstddef> diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index bb126f18..b80fed02 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -14,6 +14,7 @@ #include <algorithm> #include <cstring> #include <cstdio> +#include <cstdlib> #include <deque> #include <limits> #include <vector> @@ -22,14 +23,6 @@ namespace lm { namespace ngram { namespace trie { -const char *kContextSuffix = "_contexts"; - -FILE *OpenOrThrow(const char *name, const char *mode) { - FILE *ret = fopen(name, mode); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); - return ret; -} - void WriteOrThrow(FILE *to, const void *data, size_t size) { assert(size); if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); @@ -78,28 +71,29 @@ class PartialViewProxy { typedef util::ProxyIterator<PartialViewProxy> PartialIter; -std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch; - std::string ret(assembled.str()); - util::scoped_fd out(util::CreateOrThrow(ret.c_str())); - util::WriteOrThrow(out.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); - return ret; +FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) { + util::scoped_fd file(maker.Make()); + util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); + return util::FDOpenOrThrow(file); } -void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { +FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) { const size_t context_size = sizeof(WordIndex) * (order - 1); // Sort just the contexts using the same memory. PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); - std::sort(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1))); +#if defined(_WIN32) || defined(_WIN64) + std::stable_sort +#else + std::sort +#endif + (context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1))); - std::string name(ngram_file_name + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); + util::scoped_FILE out(maker.MakeFile()); // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. - if (context_begin == context_end) return; + if (context_begin == context_end) return out.release(); PartialIter i(context_begin); WriteOrThrow(out.get(), i->Data(), context_size); const void *previous = i->Data(); @@ -110,6 +104,7 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil previous = i->Data(); } } + return out.release(); } struct ThrowCombine { @@ -125,14 +120,12 @@ struct FirstCombine { } }; -template <class Combine> void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order, const Combine &combine = ThrowCombine()) { +template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) { std::size_t entry_size = sizeof(WordIndex) * order + weights_size; RecordReader first, second; - first.Init(first_name.c_str(), entry_size); - util::RemoveOrThrow(first_name.c_str()); - second.Init(second_name.c_str(), entry_size); - util::RemoveOrThrow(second_name.c_str()); - util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); + first.Init(first_file, entry_size); + second.Init(second_file, entry_size); + util::scoped_FILE out_file(maker.MakeFile()); EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { @@ -149,67 +142,14 @@ template <class Combine> void MergeSortedFiles(const std::string &first_name, co for (RecordReader &remains = (first ? first : second); remains; ++remains) { WriteOrThrow(out_file.get(), remains.Data(), entry_size); } -} - -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { - ReadNGramHeader(f, order); - const size_t count = counts[order - 1]; - // Size of weights. Does it include backoff? - const size_t words_size = sizeof(WordIndex) * order; - const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); - const size_t entry_size = words_size + weights_size; - const size_t batch_size = std::min(count, mem.size() / entry_size); - uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get()); - std::deque<std::string> files; - for (std::size_t batch = 0, done = 0; done < count; ++batch) { - uint8_t *out = begin; - uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; - if (order == counts.size()) { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); - } - } else { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); - } - } - // Sort full records by full n-gram. - util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - // parallel_sort uses too much RAM - std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order))); - files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order)); - WriteContextFile(begin, out_end, files.back(), entry_size, order); - - done += (out_end - begin) / entry_size; - } - - // All individual files created. Merge them. - - std::size_t merge_count = 0; - while (files.size() > 1) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++); - files.push_back(assembled.str()); - MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); - MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine()); - files.pop_front(); - files.pop_front(); - } - if (!files.empty()) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(order) << "_merged"; - std::string merged_name(assembled.str()); - if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); - std::string context_name = files[0] + kContextSuffix; - merged_name += kContextSuffix; - if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); - } + return out_file.release(); } } // namespace -void RecordReader::Init(const std::string &name, std::size_t entry_size) { - file_.reset(OpenOrThrow(name.c_str(), "r+")); +void RecordReader::Init(FILE *file, std::size_t entry_size) { + rewind(file); + file_ = file; data_.reset(malloc(entry_size)); UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer"); remains_ = true; @@ -219,20 +159,29 @@ void RecordReader::Init(const std::string &name, std::size_t entry_size) { void RecordReader::Overwrite(const void *start, std::size_t amount) { long internal = (uint8_t*)start - (uint8_t*)data_.get(); - UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); - WriteOrThrow(file_.get(), start, amount); + UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); + WriteOrThrow(file_, start, amount); long forward = entry_size_ - internal - amount; - if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); +#if !defined(_WIN32) && !defined(_WIN64) + if (forward) +#endif + UTIL_THROW_IF(fseek(file_, forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); } -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { +void RecordReader::Rewind() { + rewind(file_); + remains_ = true; + ++*this; +} + +SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + util::TempMaker maker(file_prefix); PositiveProbWarn warn(config.positive_log_probability); + unigram_.reset(maker.Make()); { - std::string unigram_name = file_prefix + "unigrams"; - util::scoped_fd unigram_file; // In case <unk> appears. - size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); + size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_.get(), size_out), size_out); Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; @@ -246,16 +195,96 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); buffer = std::min<size_t>(buffer, buffer_use); - util::scoped_memory mem; - mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); + util::scoped_malloc mem; + mem.reset(malloc(buffer)); if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); + ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer); } ReadEnd(f); } +namespace { +class Closer { + public: + explicit Closer(std::deque<FILE*> &files) : files_(files) {} + + ~Closer() { + for (std::deque<FILE*>::iterator i = files_.begin(); i != files_.end(); ++i) { + util::scoped_FILE deleter(*i); + } + } + + void PopFront() { + util::scoped_FILE deleter(files_.front()); + files_.pop_front(); + } + private: + std::deque<FILE*> &files_; +}; +} // namespace + +void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { + ReadNGramHeader(f, order); + const size_t count = counts[order - 1]; + // Size of weights. Does it include backoff? + const size_t words_size = sizeof(WordIndex) * order; + const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); + const size_t entry_size = words_size + weights_size; + const size_t batch_size = std::min(count, mem_size / entry_size); + uint8_t *const begin = reinterpret_cast<uint8_t*>(mem); + + std::deque<FILE*> files, contexts; + Closer files_closer(files), contexts_closer(contexts); + + for (std::size_t batch = 0, done = 0; done < count; ++batch) { + uint8_t *out = begin; + uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; + if (order == counts.size()) { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); + } + } else { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); + } + } + // Sort full records by full n-gram. + util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); + // parallel_sort uses too much RAM. TODO: figure out why windows sort doesn't like my proxies. +#if defined(_WIN32) || defined(_WIN64) + std::stable_sort +#else + std::sort +#endif + (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order))); + files.push_back(DiskFlush(begin, out_end, maker)); + contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order)); + + done += (out_end - begin) / entry_size; + } + + // All individual files created. Merge them. + + while (files.size() > 1) { + files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine())); + files_closer.PopFront(); + files_closer.PopFront(); + contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine())); + contexts_closer.PopFront(); + contexts_closer.PopFront(); + } + + if (!files.empty()) { + // Steal from closers. + full_[order - 2].reset(files.front()); + files.pop_front(); + context_[order - 2].reset(contexts.front()); + contexts.pop_front(); + } +} + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index a6916483..3036319d 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -1,6 +1,9 @@ +// Step of trie builder: create sorted files. + #ifndef LM_TRIE_SORT__ #define LM_TRIE_SORT__ +#include "lm/max_order.hh" #include "lm/word_index.hh" #include "util/file.hh" @@ -11,20 +14,21 @@ #include <string> #include <vector> -#include <inttypes.h> +#include <stdint.h> -namespace util { class FilePiece; } +namespace util { +class FilePiece; +class TempMaker; +} // namespace util -// Step of trie builder: create sorted files. namespace lm { +class PositiveProbWarn; namespace ngram { class SortedVocabulary; class Config; namespace trie { -extern const char *kContextSuffix; -FILE *OpenOrThrow(const char *name, const char *mode); void WriteOrThrow(FILE *to, const void *data, size_t size); class EntryCompare : public std::binary_function<const void*, const void*, bool> { @@ -49,15 +53,15 @@ class RecordReader { public: RecordReader() : remains_(true) {} - void Init(const std::string &name, std::size_t entry_size); + void Init(FILE *file, std::size_t entry_size); void *Data() { return data_.get(); } const void *Data() const { return data_.get(); } RecordReader &operator++() { - std::size_t ret = fread(data_.get(), entry_size_, 1, file_.get()); + std::size_t ret = fread(data_.get(), entry_size_, 1, file_); if (!ret) { - UTIL_THROW_IF(!feof(file_.get()), util::ErrnoException, "Error reading temporary file"); + UTIL_THROW_IF(!feof(file_), util::ErrnoException, "Error reading temporary file"); remains_ = false; } return *this; @@ -65,27 +69,46 @@ class RecordReader { operator bool() const { return remains_; } - void Rewind() { - rewind(file_.get()); - remains_ = true; - ++*this; - } + void Rewind(); std::size_t EntrySize() const { return entry_size_; } void Overwrite(const void *start, std::size_t amount); private: + FILE *file_; + util::scoped_malloc data_; bool remains_; std::size_t entry_size_; - - util::scoped_FILE file_; }; -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab); +class SortedFiles { + public: + // Build from ARPA + SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab); + + int StealUnigram() { + return unigram_.release(); + } + + FILE *Full(unsigned char order) { + return full_[order - 2].get(); + } + + FILE *Context(unsigned char of_order) { + return context_[of_order - 2].get(); + } + + private: + void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size); + + util::scoped_fd unigram_; + + util::scoped_FILE full_[kMaxOrder - 1], context_[kMaxOrder - 1]; +}; } // namespace trie } // namespace ngram diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index ffec41ca..9fd698bb 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -6,12 +6,15 @@ #include "lm/config.hh" #include "lm/weights.hh" #include "util/exception.hh" +#include "util/file.hh" #include "util/joint_sort.hh" #include "util/murmur_hash.hh" #include "util/probing_hash_table.hh" #include <string> +#include <string.h> + namespace lm { namespace ngram { @@ -29,23 +32,30 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); // Sadly some LMs have <UNK>. const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); -WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { - if (!enumerate) return std::numeric_limits<WordIndex>::max(); +void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) { + // Check that we're at the right place by reading <unk> which is always first. + char check_unk[6]; + util::ReadOrThrow(fd, check_unk, 6); + UTIL_THROW_IF( + memcmp(check_unk, "<unk>", 6), + FormatLoadException, + "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure."); + if (!enumerate) return; + enumerate->Add(0, "<unk>"); + + // Read all the words after unk. const std::size_t kInitialRead = 16384; std::string buf; buf.reserve(kInitialRead + 100); buf.resize(kInitialRead); - WordIndex index = 0; + WordIndex index = 1; // Read <unk> already. while (true) { - ssize_t got = read(fd, &buf[0], kInitialRead); - UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words"); - if (got == 0) return index; + std::size_t got = util::ReadOrEOF(fd, &buf[0], kInitialRead); + if (got == 0) break; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; - ssize_t ret = read(fd, &next_char, 1); - UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words"); - UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word."); + util::ReadOrThrow(fd, &next_char, 1); buf.push_back(next_char); } // Ok now we have null terminated strings. @@ -55,6 +65,8 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { i += length + 1 /* null byte */; } } + + UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file."); } } // namespace @@ -69,8 +81,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { } void WriteWordsWrapper::Write(int fd) { - if ((off_t)-1 == lseek(fd, 0, SEEK_END)) - UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words"); + util::SeekEnd(fd); util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); } @@ -114,8 +125,10 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { if (enumerate_) { - util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); - util::JointSort(begin_, end_, values); + if (!strings_to_enumerate_.empty()) { + util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::JointSort(begin_, end_, values); + } for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { // <unk> strikes again: +1 here. enumerate_->Add(i + 1, strings_to_enumerate_[i]); @@ -131,11 +144,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { bound_ = end_ - begin_ + 1; } -void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { +void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); - ReadWords(fd, to); SetSpecial(Index("<s>"), Index("</s>"), 0); bound_ = end_ - begin_ + 1; + if (have_words) ReadWords(fd, to, bound_); } namespace { @@ -153,12 +166,12 @@ struct ProbingVocabularyHeader { ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { - return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); + return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { header_ = static_cast<detail::ProbingVocabularyHeader*>(start); - lookup_ = Lookup(static_cast<uint8_t*>(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated); + lookup_ = Lookup(static_cast<uint8_t*>(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated); bound_ = 1; saw_unk_ = false; } @@ -178,7 +191,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { return 0; } else { if (enumerate_) enumerate_->Add(bound_, str); - lookup_.Insert(Lookup::Packing::Make(hashed, bound_)); + lookup_.Insert(ProbingVocabuaryEntry::Make(hashed, bound_)); return bound_++; } } @@ -190,12 +203,12 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { SetSpecial(Index("<s>"), Index("</s>"), 0); } -void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { +void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); lookup_.LoadedBinary(); - ReadWords(fd, to); bound_ = header_->bound; SetSpecial(Index("<s>"), Index("</s>"), 0); + if (have_words) ReadWords(fd, to, bound_); } void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 3c3414fb..06fdefe4 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,7 +4,6 @@ #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" -#include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" #include "util/string_piece.hh" @@ -83,7 +82,7 @@ class SortedVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } - void LoadedBinary(int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); private: uint64_t *begin_, *end_; @@ -100,6 +99,26 @@ class SortedVocabulary : public base::Vocabulary { std::vector<std::string> strings_to_enumerate_; }; +#pragma pack(push) +#pragma pack(4) +struct ProbingVocabuaryEntry { + uint64_t key; + WordIndex value; + + typedef uint64_t Key; + uint64_t GetKey() const { + return key; + } + + static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { + ProbingVocabuaryEntry ret; + ret.key = key; + ret.value = value; + return ret; + } +}; +#pragma pack(pop) + // Vocabulary storing a map from uint64_t to WordIndex. class ProbingVocabulary : public base::Vocabulary { public: @@ -107,7 +126,7 @@ class ProbingVocabulary : public base::Vocabulary { WordIndex Index(const StringPiece &str) const { Lookup::ConstIterator i; - return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0; + return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } static size_t Size(std::size_t entries, const Config &config); @@ -124,17 +143,14 @@ class ProbingVocabulary : public base::Vocabulary { void FinishedLoading(ProbBackoff *reorder_vocab); + std::size_t UnkCountChangePadding() const { return 0; } + bool SawUnk() const { return saw_unk_; } - void LoadedBinary(int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); private: - // std::identity is an SGI extension :-( - struct IdentityHash : public std::unary_function<uint64_t, std::size_t> { - std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); } - }; - - typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup; + typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup; Lookup lookup_; diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 33266b94..73a5cb22 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -1,33 +1,37 @@ #ifndef UTIL_BIT_PACKING__ #define UTIL_BIT_PACKING__ -/* Bit-level packing routines */ +/* Bit-level packing routines + * + * WARNING WARNING WARNING: + * The write functions assume that memory is zero initially. This makes them + * faster and is the appropriate case for mmapped language model construction. + * These routines assume that unaligned access to uint64_t is fast. This is + * the case on x86_64. I'm not sure how fast unaligned 64-bit access is on + * x86 but my target audience is large language models for which 64-bit is + * necessary. + * + * Call the BitPackingSanity function to sanity check. Calling once suffices, + * but it may be called multiple times when that's inconvenient. + * + * ARM and MinGW ports contributed by Hideo Okuma and Tomoyuki Yoshimura at + * NICT. + */ #include <assert.h> #ifdef __APPLE__ #include <architecture/byte_order.h> #elif __linux__ #include <endian.h> -#else +#elif !defined(_WIN32) && !defined(_WIN64) #include <arpa/nameser_compat.h> #endif -#include <inttypes.h> - -namespace util { +#include <stdint.h> -/* WARNING WARNING WARNING: - * The write functions assume that memory is zero initially. This makes them - * faster and is the appropriate case for mmapped language model construction. - * These routines assume that unaligned access to uint64_t is fast and that - * storage is little endian. This is the case on x86_64. I'm not sure how - * fast unaligned 64-bit access is on x86 but my target audience is large - * language models for which 64-bit is necessary. - * - * Call the BitPackingSanity function to sanity check. Calling once suffices, - * but it may be called multiple times when that's inconvenient. - */ +#include <string.h> +namespace util { // Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct. #if BYTE_ORDER == LITTLE_ENDIAN @@ -43,7 +47,14 @@ inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { #endif inline uint64_t ReadOff(const void *base, uint64_t bit_off) { +#if defined(__arm) || defined(__arm__) + const uint8_t *base_off = reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3); + uint64_t value64; + memcpy(&value64, base_off, sizeof(value64)); + return value64; +#else return *reinterpret_cast<const uint64_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)); +#endif } /* Pack integers up to 57 bits using their least significant digits. @@ -57,18 +68,41 @@ inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, ui * Assumes the memory is zero initially. */ inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) { +#if defined(__arm) || defined(__arm__) + uint8_t *base_off = reinterpret_cast<uint8_t*>(base) + (bit_off >> 3); + uint64_t value64; + memcpy(&value64, base_off, sizeof(value64)); + value64 |= (value << BitPackShift(bit_off & 7, length)); + memcpy(base_off, &value64, sizeof(value64)); +#else *reinterpret_cast<uint64_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |= (value << BitPackShift(bit_off & 7, length)); +#endif } /* Same caveats as above, but for a 25 bit limit. */ inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) { +#if defined(__arm) || defined(__arm__) + const uint8_t *base_off = reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3); + uint32_t value32; + memcpy(&value32, base_off, sizeof(value32)); + return (value32 >> BitPackShift(bit_off & 7, length)) & mask; +#else return (*reinterpret_cast<const uint32_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask; +#endif } inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) { +#if defined(__arm) || defined(__arm__) + uint8_t *base_off = reinterpret_cast<uint8_t*>(base) + (bit_off >> 3); + uint32_t value32; + memcpy(&value32, base_off, sizeof(value32)); + value32 |= (value << BitPackShift(bit_off & 7, length)); + memcpy(base_off, &value32, sizeof(value32)); +#else *reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |= (value << BitPackShift(bit_off & 7, length)); +#endif } typedef union { float f; uint32_t i; } FloatEnc; diff --git a/klm/util/exception.cc b/klm/util/exception.cc index 96951495..c4f8c04c 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -66,7 +66,7 @@ const char *HandleStrerror(const char *ret, const char * /*buf*/) { ErrnoException::ErrnoException() throw() : errno_(errno) { char buf[200]; buf[0] = 0; -#ifdef sun +#if defined(sun) || defined(_WIN32) || defined(_WIN64) const char *add = strerror(errno); #else const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf); diff --git a/klm/util/file.cc b/klm/util/file.cc index d707568e..176737fa 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -9,8 +9,12 @@ #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> -#include <unistd.h> -#include <inttypes.h> +#include <stdint.h> + +#if defined(_WIN32) || defined(_WIN64) +#include <windows.h> +#include <io.h> +#endif namespace util { @@ -30,33 +34,71 @@ scoped_FILE::~scoped_FILE() { int OpenReadOrThrow(const char *name) { int ret; +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW_IF(-1 == (ret = _open(name, _O_BINARY | _O_RDONLY)), ErrnoException, "while opening " << name); +#else UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name); +#endif return ret; } int CreateOrThrow(const char *name) { int ret; - UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR)), ErrnoException, "while creating " << name); +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); +#else + UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name); +#endif return ret; } -off_t SizeFile(int fd) { +uint64_t SizeFile(int fd) { +#if defined(_WIN32) || defined(_WIN64) + __int64 ret = _filelengthi64(fd); + return (ret == -1) ? kBadSize : ret; +#else struct stat sb; if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; return sb.st_size; +#endif +} + +void ResizeOrThrow(int fd, uint64_t to) { +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW_IF(_chsize_s(fd, to), ErrnoException, "Resizing to " << to << " bytes failed"); +#else + UTIL_THROW_IF(ftruncate(fd, to), ErrnoException, "Resizing to " << to << " bytes failed"); +#endif } +#ifdef WIN32 +typedef int ssize_t; +#endif + void ReadOrThrow(int fd, void *to_void, std::size_t amount) { uint8_t *to = static_cast<uint8_t*>(to_void); while (amount) { ssize_t ret = read(fd, to, amount); - if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); - if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); + UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); + UTIL_THROW_IF(ret == 0, EndOfFileException, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); amount -= ret; to += ret; } } +std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { + uint8_t *to = static_cast<uint8_t*>(to_void); + std::size_t remaining = amount; + while (remaining) { + ssize_t ret = read(fd, to, remaining); + UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << remaining << " from fd " << fd << " failed."); + if (!ret) return amount - remaining; + remaining -= ret; + to += ret; + } + return amount; +} + void WriteOrThrow(int fd, const void *data_void, std::size_t size) { const uint8_t *data = static_cast<const uint8_t*>(data_void); while (size) { @@ -67,8 +109,172 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) { } } -void RemoveOrThrow(const char *name) { - UTIL_THROW_IF(std::remove(name), util::ErrnoException, "Could not remove " << name); +void FSyncOrThrow(int fd) { +// Apparently windows doesn't have fsync? +#if !defined(_WIN32) && !defined(_WIN64) + UTIL_THROW_IF(-1 == fsync(fd), ErrnoException, "Sync of " << fd << " failed."); +#endif +} + +namespace { +void InternalSeek(int fd, off_t off, int whence) { + UTIL_THROW_IF((off_t)-1 == lseek(fd, off, whence), ErrnoException, "Seek failed"); +} +} // namespace + +void SeekOrThrow(int fd, uint64_t off) { + InternalSeek(fd, off, SEEK_SET); +} + +void AdvanceOrThrow(int fd, int64_t off) { + InternalSeek(fd, off, SEEK_CUR); +} + +void SeekEnd(int fd) { + InternalSeek(fd, 0, SEEK_END); +} + +std::FILE *FDOpenOrThrow(scoped_fd &file) { + std::FILE *ret = fdopen(file.get(), "r+b"); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen"); + file.release(); + return ret; +} + +TempMaker::TempMaker(const std::string &prefix) : base_(prefix) { + base_ += "XXXXXX"; +} + +// Sigh. Windows temporary file creation is full of race conditions. +#if defined(_WIN32) || defined(_WIN64) +/* mkstemp extracted from libc/sysdeps/posix/tempname.c. Copyright + (C) 1991-1999, 2000, 2001, 2006 Free Software Foundation, Inc. + + The GNU C Library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. */ + +/* This has been modified from the original version to rename the function and + * set the Windows temporary flag. */ + +static const char letters[] = +"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + +/* Generate a temporary file name based on TMPL. TMPL must match the + rules for mk[s]temp (i.e. end in "XXXXXX"). The name constructed + does not exist at the time of the call to mkstemp. TMPL is + overwritten with the result. */ +int +mkstemp_and_unlink(char *tmpl) +{ + int len; + char *XXXXXX; + static unsigned long long value; + unsigned long long random_time_bits; + unsigned int count; + int fd = -1; + int save_errno = errno; + + /* A lower bound on the number of temporary files to attempt to + generate. The maximum total number of temporary file names that + can exist for a given template is 62**6. It should never be + necessary to try all these combinations. Instead if a reasonable + number of names is tried (we define reasonable as 62**3) fail to + give the system administrator the chance to remove the problems. */ +#define ATTEMPTS_MIN (62 * 62 * 62) + + /* The number of times to attempt to generate a temporary file. To + conform to POSIX, this must be no smaller than TMP_MAX. */ +#if ATTEMPTS_MIN < TMP_MAX + unsigned int attempts = TMP_MAX; +#else + unsigned int attempts = ATTEMPTS_MIN; +#endif + + len = strlen (tmpl); + if (len < 6 || strcmp (&tmpl[len - 6], "XXXXXX")) + { + errno = EINVAL; + return -1; + } + +/* This is where the Xs start. */ + XXXXXX = &tmpl[len - 6]; + + /* Get some more or less random data. */ + { + SYSTEMTIME stNow; + FILETIME ftNow; + + // get system time + GetSystemTime(&stNow); + stNow.wMilliseconds = 500; + if (!SystemTimeToFileTime(&stNow, &ftNow)) + { + errno = -1; + return -1; + } + + random_time_bits = (((unsigned long long)ftNow.dwHighDateTime << 32) + | (unsigned long long)ftNow.dwLowDateTime); + } + value += random_time_bits ^ (unsigned long long)GetCurrentThreadId (); + + for (count = 0; count < attempts; value += 7777, ++count) + { + unsigned long long v = value; + + /* Fill in the random bits. */ + XXXXXX[0] = letters[v % 62]; + v /= 62; + XXXXXX[1] = letters[v % 62]; + v /= 62; + XXXXXX[2] = letters[v % 62]; + v /= 62; + XXXXXX[3] = letters[v % 62]; + v /= 62; + XXXXXX[4] = letters[v % 62]; + v /= 62; + XXXXXX[5] = letters[v % 62]; + + /* Modified for windows and to unlink */ + // fd = open (tmpl, O_RDWR | O_CREAT | O_EXCL, _S_IREAD | _S_IWRITE); + fd = _open (tmpl, _O_RDWR | _O_CREAT | _O_TEMPORARY | _O_EXCL | _O_BINARY, _S_IREAD | _S_IWRITE); + if (fd >= 0) + { + errno = save_errno; + return fd; + } + else if (errno != EEXIST) + return -1; + } + + /* We got out of the loop because we ran out of combinations to try. */ + errno = EEXIST; + return -1; +} +#else +int +mkstemp_and_unlink(char *tmpl) { + int ret = mkstemp(tmpl); + if (ret == -1) return -1; + UTIL_THROW_IF(unlink(tmpl), util::ErrnoException, "Failed to delete " << tmpl); + return ret; +} +#endif + +int TempMaker::Make() const { + std::string copy(base_); + copy.push_back(0); + int ret; + UTIL_THROW_IF(-1 == (ret = mkstemp_and_unlink(©[0])), util::ErrnoException, "Failed to make a temporary based on " << base_); + return ret; +} + +std::FILE *TempMaker::MakeFile() const { + util::scoped_fd file(Make()); + return FDOpenOrThrow(file); } } // namespace util diff --git a/klm/util/file.hh b/klm/util/file.hh index d6cca41d..72c8ea76 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -1,8 +1,11 @@ #ifndef UTIL_FILE__ #define UTIL_FILE__ +#include <cstddef> #include <cstdio> -#include <unistd.h> +#include <string> + +#include <stdint.h> namespace util { @@ -52,22 +55,52 @@ class scoped_FILE { file_ = to; } + std::FILE *release() { + std::FILE *ret = file_; + file_ = NULL; + return ret; + } + private: std::FILE *file_; }; +// Open for read only. int OpenReadOrThrow(const char *name); - +// Create file if it doesn't exist, truncate if it does. Opened for write. int CreateOrThrow(const char *name); // Return value for SizeFile when it can't size properly. -const off_t kBadSize = -1; -off_t SizeFile(int fd); +const uint64_t kBadSize = (uint64_t)-1; +uint64_t SizeFile(int fd); + +void ResizeOrThrow(int fd, uint64_t to); void ReadOrThrow(int fd, void *to, std::size_t size); +std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount); + void WriteOrThrow(int fd, const void *data_void, std::size_t size); -void RemoveOrThrow(const char *name); +void FSyncOrThrow(int fd); + +// Seeking +void SeekOrThrow(int fd, uint64_t off); +void AdvanceOrThrow(int fd, int64_t off); +void SeekEnd(int fd); + +std::FILE *FDOpenOrThrow(scoped_fd &file); + +class TempMaker { + public: + explicit TempMaker(const std::string &prefix); + + int Make() const; + + std::FILE *MakeFile() const; + + private: + std::string base_; +}; } // namespace util diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index b57582a0..081e662b 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -2,6 +2,10 @@ #include "util/exception.hh" #include "util/file.hh" +#include "util/mmap.hh" +#ifdef WIN32 +#include <io.h> +#endif // WIN32 #include <iostream> #include <string> @@ -11,14 +15,8 @@ #include <ctype.h> #include <fcntl.h> #include <stdlib.h> -#include <sys/mman.h> #include <sys/types.h> #include <sys/stat.h> -#include <unistd.h> - -#ifdef HAVE_ZLIB -#include <zlib.h> -#endif namespace util { @@ -26,24 +24,24 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() { *this << "Could not parse \"" << value << "\" into a number"; } -GZException::GZException(void *file) { #ifdef HAVE_ZLIB +GZException::GZException(gzFile file) { int num; - *this << gzerror(file, &num) << " from zlib"; -#endif // HAVE_ZLIB + *this << gzerror( file, &num) << " from zlib"; } +#endif // HAVE_ZLIB // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; -FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) : - file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), +FilePiece::FilePiece(const char *name, std::ostream *show_progress, std::size_t min_buffer) : + file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(SizePage()), progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { Initialize(name, show_progress, min_buffer); } -FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer) : - file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), +FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std::size_t min_buffer) : + file_(fd), total_size_(SizeFile(file_.get())), page_(SizePage()), progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { Initialize(name, show_progress, min_buffer); } @@ -63,7 +61,7 @@ FilePiece::~FilePiece() { } StringPiece FilePiece::ReadLine(char delim) { - size_t skip = 0; + std::size_t skip = 0; while (true) { for (const char *i = position_ + skip; i < position_end_; ++i) { if (*i == delim) { @@ -94,13 +92,13 @@ unsigned long int FilePiece::ReadULong() { return ReadNumber<unsigned long int>(); } -void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) { +void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer) { #ifdef HAVE_ZLIB gz_file_ = NULL; #endif file_name_ = name; - default_map_size_ = page_ * std::max<off_t>((min_buffer / page_ + 1), 2); + default_map_size_ = page_ * std::max<std::size_t>((min_buffer / page_ + 1), 2); position_ = NULL; position_end_ = NULL; mapped_offset_ = 0; @@ -130,7 +128,7 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t namespace { void ParseNumber(const char *begin, char *&end, float &out) { -#ifdef sun +#if defined(sun) || defined(WIN32) out = static_cast<float>(strtod(begin, &end)); #else out = strtof(begin, &end); @@ -171,7 +169,7 @@ template <class T> T FilePiece::ReadNumber() { } const char *FilePiece::FindDelimiterOrEOF(const bool *delim) { - size_t skip = 0; + std::size_t skip = 0; while (true) { for (const char *i = position_ + skip; i < position_end_; ++i) { if (delim[static_cast<unsigned char>(*i)]) return i; @@ -190,7 +188,7 @@ void FilePiece::Shift() { progress_.Finished(); throw EndOfFileException(); } - off_t desired_begin = position_ - data_.begin() + mapped_offset_; + uint64_t desired_begin = position_ - data_.begin() + mapped_offset_; if (!fallback_to_read_) MMapShift(desired_begin); // Notice an mmap failure might set the fallback. @@ -201,18 +199,18 @@ void FilePiece::Shift() { } } -void FilePiece::MMapShift(off_t desired_begin) { +void FilePiece::MMapShift(uint64_t desired_begin) { // Use mmap. - off_t ignore = desired_begin % page_; + uint64_t ignore = desired_begin % page_; // Duplicate request for Shift means give more data. if (position_ == data_.begin() + ignore) { default_map_size_ *= 2; } // Local version so that in case of failure it doesn't overwrite the class variable. - off_t mapped_offset = desired_begin - ignore; + uint64_t mapped_offset = desired_begin - ignore; - off_t mapped_size; - if (default_map_size_ >= static_cast<size_t>(total_size_ - mapped_offset)) { + uint64_t mapped_size; + if (default_map_size_ >= static_cast<std::size_t>(total_size_ - mapped_offset)) { at_end_ = true; mapped_size = total_size_ - mapped_offset; } else { @@ -221,15 +219,11 @@ void FilePiece::MMapShift(off_t desired_begin) { // Forcibly clear the existing mmap first. data_.reset(); - data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_SHARED - // Populate where available on linux -#ifdef MAP_POPULATE - | MAP_POPULATE -#endif - , *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED); - if (data_.get() == MAP_FAILED) { + try { + MapRead(POPULATE_OR_LAZY, *file_, mapped_offset, mapped_size, data_); + } catch (const util::ErrnoException &e) { if (desired_begin) { - if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either."); + SeekOrThrow(*file_, desired_begin); } // The mmap was scheduled to end the file, but now we're going to read it. at_end_ = false; @@ -259,6 +253,10 @@ void FilePiece::TransitionToRead() { #endif } +#ifdef WIN32 +typedef int ssize_t; +#endif + void FilePiece::ReadShift() { assert(fallback_to_read_); // Bytes [data_.begin(), position_) have been consumed. diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index a627f38c..af93d8aa 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -8,9 +8,14 @@ #include "util/mmap.hh" #include "util/string_piece.hh" +#include <cstddef> #include <string> -#include <cstddef> +#include <stdint.h> + +#ifdef HAVE_ZLIB +#include <zlib.h> +#endif namespace util { @@ -22,7 +27,9 @@ class ParseNumberException : public Exception { class GZException : public Exception { public: - explicit GZException(void *file); +#ifdef HAVE_ZLIB + explicit GZException(gzFile file); +#endif GZException() throw() {} ~GZException() throw() {} }; @@ -33,9 +40,9 @@ extern const bool kSpaces[256]; class FilePiece { public: // 32 MB default. - explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); + explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); // Takes ownership of fd. name is used for messages. - explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); + explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); ~FilePiece(); @@ -70,14 +77,14 @@ class FilePiece { } } - off_t Offset() const { + uint64_t Offset() const { return position_ - data_.begin() + mapped_offset_; } const std::string &FileName() const { return file_name_; } private: - void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer); + void Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer); template <class T> T ReadNumber(); @@ -91,7 +98,7 @@ class FilePiece { void Shift(); // Backends to Shift(). - void MMapShift(off_t desired_begin); + void MMapShift(uint64_t desired_begin); void TransitionToRead(); void ReadShift(); @@ -99,11 +106,11 @@ class FilePiece { const char *position_, *last_space_, *position_end_; scoped_fd file_; - const off_t total_size_; - const off_t page_; + const uint64_t total_size_; + const uint64_t page_; - size_t default_map_size_; - off_t mapped_offset_; + std::size_t default_map_size_; + uint64_t mapped_offset_; // Order matters: file_ should always be destroyed after this. scoped_memory data_; @@ -116,7 +123,7 @@ class FilePiece { std::string file_name_; #ifdef HAVE_ZLIB - void *gz_file_; + gzFile gz_file_; #endif // HAVE_ZLIB }; diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index dc9ec7e7..f912e18a 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -1,3 +1,4 @@ +// Tests might fail if you have creative characters in your path. Sue me. #include "util/file_piece.hh" #include "util/scoped.hh" @@ -14,10 +15,18 @@ namespace util { namespace { +std::string FileLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 2) { + return "file_piece.cc"; + } + std::string ret(boost::unit_test::framework::master_test_suite().argv[1]); + return ret; +} + /* mmap implementation */ BOOST_AUTO_TEST_CASE(MMapReadLine) { - std::fstream ref("file_piece.cc", std::ios::in); - FilePiece test("file_piece.cc", NULL, 1); + std::fstream ref(FileLocation().c_str(), std::ios::in); + FilePiece test(FileLocation().c_str(), NULL, 1); std::string ref_line; while (getline(ref, ref_line)) { StringPiece test_line(test.ReadLine()); @@ -35,9 +44,13 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) { */ /* read() implementation */ BOOST_AUTO_TEST_CASE(StreamReadLine) { - std::fstream ref("file_piece.cc", std::ios::in); + std::fstream ref(FileLocation().c_str(), std::ios::in); + + std::string popen_args = "cat \""; + popen_args += FileLocation(); + popen_args += '"'; - FILE *catter = popen("cat file_piece.cc", "r"); + FILE *catter = popen(popen_args.c_str(), "r"); BOOST_REQUIRE(catter); FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1); @@ -58,10 +71,15 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) { // gzip file BOOST_AUTO_TEST_CASE(PlainZipReadLine) { - std::fstream ref("file_piece.cc", std::ios::in); + std::string location(FileLocation()); + std::fstream ref(location.c_str(), std::ios::in); - BOOST_REQUIRE_EQUAL(0, system("gzip <file_piece.cc >file_piece.cc.gz")); - FilePiece test("file_piece.cc.gz", NULL, 1); + std::string command("gzip <\""); + command += location + "\" >\"" + location + "\".gz"; + + BOOST_REQUIRE_EQUAL(0, system(command.c_str())); + FilePiece test((location + ".gz").c_str(), NULL, 1); + unlink((location + ".gz").c_str()); std::string ref_line; while (getline(ref, ref_line)) { StringPiece test_line(test.ReadLine()); @@ -77,12 +95,15 @@ BOOST_AUTO_TEST_CASE(PlainZipReadLine) { // the test. #ifndef __APPLE__ BOOST_AUTO_TEST_CASE(StreamZipReadLine) { - std::fstream ref("file_piece.cc", std::ios::in); + std::fstream ref(FileLocation().c_str(), std::ios::in); + + std::string command("gzip <\""); + command += FileLocation() + "\""; - FILE * catter = popen("gzip <file_piece.cc", "r"); + FILE * catter = popen(command.c_str(), "r"); BOOST_REQUIRE(catter); - FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1); + FilePiece test(dup(fileno(catter)), "file_piece.cc.gz", NULL, 1); std::string ref_line; while (getline(ref, ref_line)) { StringPiece test_line(test.ReadLine()); diff --git a/klm/util/getopt.c b/klm/util/getopt.c new file mode 100644 index 00000000..992c96b0 --- /dev/null +++ b/klm/util/getopt.c @@ -0,0 +1,78 @@ +/* +POSIX getopt for Windows + +AT&T Public License + +Code given out at the 1985 UNIFORUM conference in Dallas. +*/ + +#ifndef __GNUC__ + +#include "getopt.hh" +#include <stdio.h> +#include <string.h> + +#define NULL 0 +#define EOF (-1) +#define ERR(s, c) if(opterr){\ + char errbuf[2];\ + errbuf[0] = c; errbuf[1] = '\n';\ + fputs(argv[0], stderr);\ + fputs(s, stderr);\ + fputc(c, stderr);} + //(void) write(2, argv[0], (unsigned)strlen(argv[0]));\ + //(void) write(2, s, (unsigned)strlen(s));\ + //(void) write(2, errbuf, 2);} + +int opterr = 1; +int optind = 1; +int optopt; +char *optarg; + +int +getopt(argc, argv, opts) +int argc; +char **argv, *opts; +{ + static int sp = 1; + register int c; + register char *cp; + + if(sp == 1) + if(optind >= argc || + argv[optind][0] != '-' || argv[optind][1] == '\0') + return(EOF); + else if(strcmp(argv[optind], "--") == NULL) { + optind++; + return(EOF); + } + optopt = c = argv[optind][sp]; + if(c == ':' || (cp=strchr(opts, c)) == NULL) { + ERR(": illegal option -- ", c); + if(argv[optind][++sp] == '\0') { + optind++; + sp = 1; + } + return('?'); + } + if(*++cp == ':') { + if(argv[optind][sp+1] != '\0') + optarg = &argv[optind++][sp+1]; + else if(++optind >= argc) { + ERR(": option requires an argument -- ", c); + sp = 1; + return('?'); + } else + optarg = argv[optind++]; + sp = 1; + } else { + if(argv[optind][++sp] == '\0') { + sp = 1; + optind++; + } + optarg = NULL; + } + return(c); +} + +#endif /* __GNUC__ */ diff --git a/klm/util/getopt.hh b/klm/util/getopt.hh new file mode 100644 index 00000000..6ad97732 --- /dev/null +++ b/klm/util/getopt.hh @@ -0,0 +1,33 @@ +/* +POSIX getopt for Windows + +AT&T Public License + +Code given out at the 1985 UNIFORUM conference in Dallas. +*/ + +#ifdef __GNUC__ +#include <getopt.h> +#endif +#ifndef __GNUC__ + +#ifndef _WINGETOPT_H_ +#define _WINGETOPT_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +extern int opterr; +extern int optind; +extern int optopt; +extern char *optarg; +extern int getopt(int argc, char **argv, char *opts); + +#ifdef __cplusplus +} +#endif + +#endif /* _GETOPT_H_ */ +#endif /* __GNUC__ */ + diff --git a/klm/util/key_value_packing.hh b/klm/util/key_value_packing.hh deleted file mode 100644 index b84a5aad..00000000 --- a/klm/util/key_value_packing.hh +++ /dev/null @@ -1,126 +0,0 @@ -#ifndef UTIL_KEY_VALUE_PACKING__ -#define UTIL_KEY_VALUE_PACKING__ - -/* Why such a general interface? I'm planning on doing bit-level packing. */ - -#include <algorithm> -#include <cstddef> -#include <cstring> - -#include <inttypes.h> - -namespace util { - -template <class Key, class Value> struct Entry { - Key key; - Value value; - - const Key &GetKey() const { return key; } - const Value &GetValue() const { return value; } - - Value &MutableValue() { return value; } - - void Set(const Key &key_in, const Value &value_in) { - SetKey(key_in); - SetValue(value_in); - } - void SetKey(const Key &key_in) { key = key_in; } - void SetValue(const Value &value_in) { value = value_in; } - - bool operator<(const Entry<Key, Value> &other) const { return GetKey() < other.GetKey(); } -}; - -// And now for a brief interlude to specialize std::swap. -} // namespace util -namespace std { -template <class Key, class Value> void swap(util::Entry<Key, Value> &first, util::Entry<Key, Value> &second) { - swap(first.key, second.key); - swap(first.value, second.value); -} -}// namespace std -namespace util { - -template <class KeyT, class ValueT> class AlignedPacking { - public: - typedef KeyT Key; - typedef ValueT Value; - - public: - static const std::size_t kBytes = sizeof(Entry<Key, Value>); - static const std::size_t kBits = kBytes * 8; - - typedef Entry<Key, Value> * MutableIterator; - typedef const Entry<Key, Value> * ConstIterator; - typedef const Entry<Key, Value> & ConstReference; - - static MutableIterator FromVoid(void *start) { - return reinterpret_cast<MutableIterator>(start); - } - - static Entry<Key, Value> Make(const Key &key, const Value &value) { - Entry<Key, Value> ret; - ret.Set(key, value); - return ret; - } -}; - -template <class KeyT, class ValueT> class ByteAlignedPacking { - public: - typedef KeyT Key; - typedef ValueT Value; - - private: -#pragma pack(push) -#pragma pack(1) - struct RawEntry { - Key key; - Value value; - - const Key &GetKey() const { return key; } - const Value &GetValue() const { return value; } - - Value &MutableValue() { return value; } - - void Set(const Key &key_in, const Value &value_in) { - SetKey(key_in); - SetValue(value_in); - } - void SetKey(const Key &key_in) { key = key_in; } - void SetValue(const Value &value_in) { value = value_in; } - - bool operator<(const RawEntry &other) const { return GetKey() < other.GetKey(); } - }; -#pragma pack(pop) - - friend void std::swap<>(RawEntry&, RawEntry&); - - public: - typedef RawEntry *MutableIterator; - typedef const RawEntry *ConstIterator; - typedef RawEntry &ConstReference; - - static const std::size_t kBytes = sizeof(RawEntry); - static const std::size_t kBits = kBytes * 8; - - static MutableIterator FromVoid(void *start) { - return MutableIterator(reinterpret_cast<RawEntry*>(start)); - } - - static RawEntry Make(const Key &key, const Value &value) { - RawEntry ret; - ret.Set(key, value); - return ret; - } -}; - -} // namespace util -namespace std { -template <class Key, class Value> void swap( - typename util::ByteAlignedPacking<Key, Value>::RawEntry &first, - typename util::ByteAlignedPacking<Key, Value>::RawEntry &second) { - swap(first.key, second.key); - swap(first.value, second.value); -} -}// namespace std - -#endif // UTIL_KEY_VALUE_PACKING__ diff --git a/klm/util/key_value_packing_test.cc b/klm/util/key_value_packing_test.cc deleted file mode 100644 index a0d33fd7..00000000 --- a/klm/util/key_value_packing_test.cc +++ /dev/null @@ -1,75 +0,0 @@ -#include "util/key_value_packing.hh" - -#include <boost/random/mersenne_twister.hpp> -#include <boost/random/uniform_int.hpp> -#include <boost/random/variate_generator.hpp> -#include <boost/scoped_array.hpp> -#define BOOST_TEST_MODULE KeyValueStoreTest -#include <boost/test/unit_test.hpp> - -#include <limits> -#include <stdlib.h> - -namespace util { -namespace { - -BOOST_AUTO_TEST_CASE(basic_in_out) { - typedef ByteAlignedPacking<uint64_t, unsigned char> Packing; - void *backing = malloc(Packing::kBytes * 2); - Packing::MutableIterator i(Packing::FromVoid(backing)); - i->SetKey(10); - BOOST_CHECK_EQUAL(10, i->GetKey()); - i->SetValue(3); - BOOST_CHECK_EQUAL(3, i->GetValue()); - ++i; - i->SetKey(5); - BOOST_CHECK_EQUAL(5, i->GetKey()); - i->SetValue(42); - BOOST_CHECK_EQUAL(42, i->GetValue()); - - Packing::ConstIterator c(i); - BOOST_CHECK_EQUAL(5, c->GetKey()); - --c; - BOOST_CHECK_EQUAL(10, c->GetKey()); - BOOST_CHECK_EQUAL(42, i->GetValue()); - - BOOST_CHECK_EQUAL(5, i->GetKey()); - free(backing); -} - -BOOST_AUTO_TEST_CASE(simple_sort) { - typedef ByteAlignedPacking<uint64_t, unsigned char> Packing; - char foo[Packing::kBytes * 4]; - Packing::MutableIterator begin(Packing::FromVoid(foo)); - Packing::MutableIterator i = begin; - i->SetKey(0); ++i; - i->SetKey(2); ++i; - i->SetKey(3); ++i; - i->SetKey(1); ++i; - std::sort(begin, i); - BOOST_CHECK_EQUAL(0, begin[0].GetKey()); - BOOST_CHECK_EQUAL(1, begin[1].GetKey()); - BOOST_CHECK_EQUAL(2, begin[2].GetKey()); - BOOST_CHECK_EQUAL(3, begin[3].GetKey()); -} - -BOOST_AUTO_TEST_CASE(big_sort) { - typedef ByteAlignedPacking<uint64_t, unsigned char> Packing; - boost::scoped_array<char> memory(new char[Packing::kBytes * 1000]); - Packing::MutableIterator begin(Packing::FromVoid(memory.get())); - - boost::mt19937 rng; - boost::uniform_int<uint64_t> range(0, std::numeric_limits<uint64_t>::max()); - boost::variate_generator<boost::mt19937&, boost::uniform_int<uint64_t> > gen(rng, range); - - for (size_t i = 0; i < 1000; ++i) { - (begin + i)->SetKey(gen()); - } - std::sort(begin, begin + 1000); - for (size_t i = 0; i < 999; ++i) { - BOOST_CHECK(begin[i] < begin[i+1]); - } -} - -} // namespace -} // namespace util diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index 279bafa8..3b1c58b8 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -1,23 +1,63 @@ +/* Memory mapping wrappers. + * ARM and MinGW ports contributed by Hideo Okuma and Tomoyuki Yoshimura at + * NICT. + */ +#include "util/mmap.hh" + #include "util/exception.hh" #include "util/file.hh" -#include "util/mmap.hh" #include <iostream> #include <assert.h> #include <fcntl.h> #include <sys/types.h> -#include <sys/mman.h> +#include <sys/stat.h> #include <stdlib.h> -#include <unistd.h> + +#if defined(_WIN32) || defined(_WIN64) +#include <windows.h> +#include <io.h> +#else +#include <sys/mman.h> +#endif namespace util { +long SizePage() { +#if defined(_WIN32) || defined(_WIN64) + SYSTEM_INFO si; + GetSystemInfo(&si); + return si.dwAllocationGranularity; +#else + return sysconf(_SC_PAGE_SIZE); +#endif +} + +void SyncOrThrow(void *start, size_t length) { +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW_IF(!::FlushViewOfFile(start, length), ErrnoException, "Failed to sync mmap"); +#else + UTIL_THROW_IF(msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap"); +#endif +} + +void UnmapOrThrow(void *start, size_t length) { +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW_IF(!::UnmapViewOfFile(start), ErrnoException, "Failed to unmap a file"); +#else + UTIL_THROW_IF(munmap(start, length), ErrnoException, "munmap failed"); +#endif +} + scoped_mmap::~scoped_mmap() { if (data_ != (void*)-1) { - // Thanks Denis Filimonov for pointing out NFS likes msync first. - if (msync(data_, size_, MS_SYNC) || munmap(data_, size_)) { - std::cerr << "msync or mmap failed for " << size_ << " bytes." << std::endl; + try { + // Thanks Denis Filimonov for pointing out NFS likes msync first. + SyncOrThrow(data_, size_); + UnmapOrThrow(data_, size_); + } catch (const util::ErrnoException &e) { + std::cerr << e.what(); abort(); } } @@ -52,29 +92,40 @@ void scoped_memory::call_realloc(std::size_t size) { } } -void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset) { +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, uint64_t offset) { #ifdef MAP_POPULATE // Linux specific if (prefault) { flags |= MAP_POPULATE; } #endif +#if defined(_WIN32) || defined(_WIN64) + int protectC = for_write ? PAGE_READWRITE : PAGE_READONLY; + int protectM = for_write ? FILE_MAP_WRITE : FILE_MAP_READ; + uint64_t total_size = size + offset; + HANDLE hMapping = CreateFileMapping((HANDLE)_get_osfhandle(fd), NULL, protectC, total_size >> 32, static_cast<DWORD>(total_size), NULL); + UTIL_THROW_IF(!hMapping, ErrnoException, "CreateFileMapping failed"); + LPVOID ret = MapViewOfFile(hMapping, protectM, offset >> 32, offset, size); + CloseHandle(hMapping); + UTIL_THROW_IF(!ret, ErrnoException, "MapViewOfFile failed"); +#else int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; void *ret = mmap(NULL, size, protect, flags, fd, offset); - if (ret == MAP_FAILED) { - UTIL_THROW(ErrnoException, "mmap failed for size " << size << " at offset " << offset); - } + UTIL_THROW_IF(ret == MAP_FAILED, ErrnoException, "mmap failed for size " << size << " at offset " << offset); +#endif return ret; } const int kFileFlags = -#ifdef MAP_FILE +#if defined(_WIN32) || defined(_WIN64) + 0 // MapOrThrow ignores flags on windows +#elif defined(MAP_FILE) MAP_FILE | MAP_SHARED #else MAP_SHARED #endif ; -void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out) { +void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scoped_memory &out) { switch (method) { case LAZY: out.reset(MapOrThrow(size, false, kFileFlags, false, fd, offset), size, scoped_memory::MMAP_ALLOCATED); @@ -91,30 +142,38 @@ void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_m case READ: out.reset(malloc(size), size, scoped_memory::MALLOC_ALLOCATED); if (!out.get()) UTIL_THROW(util::ErrnoException, "Allocating " << size << " bytes with malloc"); - if (-1 == lseek(fd, offset, SEEK_SET)) UTIL_THROW(ErrnoException, "lseek to " << offset << " in fd " << fd << " failed."); + SeekOrThrow(fd, offset); ReadOrThrow(fd, out.get(), size); break; } } -void *MapAnonymous(std::size_t size) { - return MapOrThrow(size, true, -#ifdef MAP_ANONYMOUS - MAP_ANONYMOUS // Linux +// Allocates zeroed memory in to. +void MapAnonymous(std::size_t size, util::scoped_memory &to) { + to.reset(); +#if defined(_WIN32) || defined(_WIN64) + to.reset(calloc(1, size), size, scoped_memory::MALLOC_ALLOCATED); #else - MAP_ANON // BSD + to.reset(MapOrThrow(size, true, +# if defined(MAP_ANONYMOUS) + MAP_ANONYMOUS | MAP_PRIVATE // Linux +# else + MAP_ANON | MAP_PRIVATE // BSD +# endif + , false, -1, 0), size, scoped_memory::MMAP_ALLOCATED); #endif - | MAP_PRIVATE, false, -1, 0); +} + +void *MapZeroedWrite(int fd, std::size_t size) { + ResizeOrThrow(fd, 0); + ResizeOrThrow(fd, size); + return MapOrThrow(size, true, kFileFlags, false, fd, 0); } void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) { - file.reset(open(name, O_CREAT | O_RDWR | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)); - if (-1 == file.get()) - UTIL_THROW(ErrnoException, "Failed to open " << name << " for writing"); - if (-1 == ftruncate(file.get(), size)) - UTIL_THROW(ErrnoException, "ftruncate on " << name << " to " << size << " failed"); + file.reset(CreateOrThrow(name)); try { - return MapOrThrow(size, true, kFileFlags, false, file.get(), 0); + return MapZeroedWrite(file.get(), size); } catch (ErrnoException &e) { e << " in file " << name; throw; diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh index b0eb6672..b218c4d1 100644 --- a/klm/util/mmap.hh +++ b/klm/util/mmap.hh @@ -4,13 +4,15 @@ #include <cstddef> -#include <inttypes.h> +#include <stdint.h> #include <sys/types.h> namespace util { class scoped_fd; +long SizePage(); + // (void*)-1 is MAP_FAILED; this is done to avoid including the mmap header here. class scoped_mmap { public: @@ -94,15 +96,19 @@ typedef enum { extern const int kFileFlags; // Wrapper around mmap to check it worked and hide some platform macros. -void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset = 0); +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, uint64_t offset = 0); -void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out); +void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scoped_memory &out); -void *MapAnonymous(std::size_t size); +void MapAnonymous(std::size_t size, scoped_memory &to); // Open file name with mmap of size bytes, all of which are initially zero. +void *MapZeroedWrite(int fd, std::size_t size); void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file); +// msync wrapper +void SyncOrThrow(void *start, size_t length); + } // namespace util #endif // UTIL_MMAP__ diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index ef5783fe..6accc21a 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -7,9 +7,11 @@ * placed in namespace util * add MurmurHashNative * default option = 0 for seed + * ARM port from NICT */ #include "util/murmur_hash.hh" +#include <string.h> namespace util { @@ -28,12 +30,24 @@ uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) uint64_t h = seed ^ (len * m); +#if defined(__arm) || defined(__arm__) + const size_t ksize = sizeof(uint64_t); + const unsigned char * data = (const unsigned char *)key; + const unsigned char * end = data + (std::size_t)(len/8) * ksize; +#else const uint64_t * data = (const uint64_t *)key; const uint64_t * end = data + (len/8); +#endif while(data != end) { +#if defined(__arm) || defined(__arm__) + uint64_t k; + memcpy(&k, data, ksize); + data += ksize; +#else uint64_t k = *data++; +#endif k *= m; k ^= k >> r; @@ -75,16 +89,30 @@ uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) unsigned int h1 = seed ^ len; unsigned int h2 = 0; +#if defined(__arm) || defined(__arm__) + size_t ksize = sizeof(unsigned int); + const unsigned char * data = (const unsigned char *)key; +#else const unsigned int * data = (const unsigned int *)key; +#endif + unsigned int k1, k2; while(len >= 8) { - unsigned int k1 = *data++; +#if defined(__arm) || defined(__arm__) + memcpy(&k1, data, ksize); + data += ksize; + memcpy(&k2, data, ksize); + data += ksize; +#else + k1 = *data++; + k2 = *data++; +#endif + k1 *= m; k1 ^= k1 >> r; k1 *= m; h1 *= m; h1 ^= k1; len -= 4; - unsigned int k2 = *data++; k2 *= m; k2 ^= k2 >> r; k2 *= m; h2 *= m; h2 ^= k2; len -= 4; @@ -92,7 +120,12 @@ uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) if(len >= 4) { - unsigned int k1 = *data++; +#if defined(__arm) || defined(__arm__) + memcpy(&k1, data, ksize); + data += ksize; +#else + k1 = *data++; +#endif k1 *= m; k1 ^= k1 >> r; k1 *= m; h1 *= m; h1 ^= k1; len -= 4; diff --git a/klm/util/murmur_hash.hh b/klm/util/murmur_hash.hh index 78fe583f..638aaeb2 100644 --- a/klm/util/murmur_hash.hh +++ b/klm/util/murmur_hash.hh @@ -1,7 +1,7 @@ #ifndef UTIL_MURMUR_HASH__ #define UTIL_MURMUR_HASH__ #include <cstddef> -#include <inttypes.h> +#include <stdint.h> namespace util { diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 8122d69c..f466cebc 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -18,27 +18,33 @@ class ProbingSizeException : public Exception { ~ProbingSizeException() throw() {} }; +// std::identity is an SGI extension :-( +struct IdentityHash { + template <class T> T operator()(T arg) const { return arg; } +}; + /* Non-standard hash table * Buckets must be set at the beginning and must be greater than maximum number - * of elements, else an infinite loop happens. + * of elements, else it throws ProbingSizeException. * Memory management and initialization is externalized to make it easier to * serialize these to disk and load them quickly. * Uses linear probing to find value. * Only insert and lookup operations. */ -template <class PackingT, class HashT, class EqualT = std::equal_to<typename PackingT::Key> > class ProbingHashTable { +template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class ProbingHashTable { public: - typedef PackingT Packing; - typedef typename Packing::Key Key; - typedef typename Packing::MutableIterator MutableIterator; - typedef typename Packing::ConstIterator ConstIterator; - + typedef EntryT Entry; + typedef typename Entry::Key Key; + typedef const Entry *ConstIterator; + typedef Entry *MutableIterator; typedef HashT Hash; typedef EqualT Equal; + public: static std::size_t Size(std::size_t entries, float multiplier) { - return std::max(entries + 1, static_cast<std::size_t>(multiplier * static_cast<float>(entries))) * Packing::kBytes; + std::size_t buckets = std::max(entries + 1, static_cast<std::size_t>(multiplier * static_cast<float>(entries))); + return buckets * sizeof(Entry); } // Must be assigned to later. @@ -49,9 +55,9 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac {} ProbingHashTable(void *start, std::size_t allocated, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) - : begin_(Packing::FromVoid(start)), - buckets_(allocated / Packing::kBytes), - end_(begin_ + (allocated / Packing::kBytes)), + : begin_(reinterpret_cast<MutableIterator>(start)), + buckets_(allocated / sizeof(Entry)), + end_(begin_ + buckets_), invalid_(invalid), hash_(hash_func), equal_(equal_func), @@ -62,11 +68,10 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac {} template <class T> MutableIterator Insert(const T &t) { - if (++entries_ >= buckets_) - UTIL_THROW(ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); #ifdef DEBUG assert(initialized_); #endif + UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) { if (equal_(i->GetKey(), invalid_)) { *i = t; return i; } if (++i == end_) { i = begin_; } @@ -84,7 +89,7 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac if (equal_(got, key)) { out = i; return true; } if (equal_(got, invalid_)) return false; if (++i == end_) i = begin_; - } + } } template <class Key> bool Find(const Key key, ConstIterator &out) const { diff --git a/klm/util/probing_hash_table_test.cc b/klm/util/probing_hash_table_test.cc index ff2f5af3..ef68e5f2 100644 --- a/klm/util/probing_hash_table_test.cc +++ b/klm/util/probing_hash_table_test.cc @@ -1,6 +1,6 @@ #include "util/probing_hash_table.hh" -#include "util/key_value_packing.hh" +#include <stdint.h> #define BOOST_TEST_MODULE ProbingHashTableTest #include <boost/test/unit_test.hpp> @@ -9,17 +9,34 @@ namespace util { namespace { -typedef AlignedPacking<char, uint64_t> Packing; -typedef ProbingHashTable<Packing, boost::hash<char> > Table; +struct Entry { + unsigned char key; + typedef unsigned char Key; + + unsigned char GetKey() const { + return key; + } + + uint64_t GetValue() const { + return value; + } + + uint64_t value; +}; + +typedef ProbingHashTable<Entry, boost::hash<unsigned char> > Table; BOOST_AUTO_TEST_CASE(simple) { char mem[Table::Size(10, 1.2)]; memset(mem, 0, sizeof(mem)); Table table(mem, sizeof(mem)); - Packing::ConstIterator i = Packing::ConstIterator(); + const Entry *i = NULL; BOOST_CHECK(!table.Find(2, i)); - table.Insert(Packing::Make(3, 328920)); + Entry to_ins; + to_ins.key = 3; + to_ins.value = 328920; + table.Insert(to_ins); BOOST_REQUIRE(table.Find(3, i)); BOOST_CHECK_EQUAL(3, i->GetKey()); BOOST_CHECK_EQUAL(static_cast<uint64_t>(328920), i->GetValue()); diff --git a/klm/util/sized_iterator.hh b/klm/util/sized_iterator.hh index 47dfc245..aabcc531 100644 --- a/klm/util/sized_iterator.hh +++ b/klm/util/sized_iterator.hh @@ -6,7 +6,7 @@ #include <functional> #include <string> -#include <inttypes.h> +#include <stdint.h> #include <string.h> namespace util { diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 0d6ecbbd..7700d9e6 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -5,7 +5,7 @@ #include <cstddef> #include <assert.h> -#include <inttypes.h> +#include <stdint.h> namespace util { @@ -122,99 +122,6 @@ template <class Iterator, class Accessor> Iterator BinaryBelow( return begin - 1; } -// To use this template, you need to define a Pivot function to match Key. -template <class PackingT> class SortedUniformMap { - public: - typedef PackingT Packing; - typedef typename Packing::ConstIterator ConstIterator; - typedef typename Packing::MutableIterator MutableIterator; - - struct Accessor { - public: - typedef typename Packing::Key Key; - const Key &operator()(const ConstIterator &i) const { return i->GetKey(); } - Key &operator()(const MutableIterator &i) const { return i->GetKey(); } - }; - - // Offer consistent API with probing hash. - static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) { - return sizeof(uint64_t) + entries * Packing::kBytes; - } - - SortedUniformMap() -#ifdef DEBUG - : initialized_(false), loaded_(false) -#endif - {} - - SortedUniformMap(void *start, std::size_t /*allocated*/) : - begin_(Packing::FromVoid(reinterpret_cast<uint64_t*>(start) + 1)), - end_(begin_), size_ptr_(reinterpret_cast<uint64_t*>(start)) -#ifdef DEBUG - , initialized_(true), loaded_(false) -#endif - {} - - void LoadedBinary() { -#ifdef DEBUG - assert(initialized_); - assert(!loaded_); - loaded_ = true; -#endif - // Restore the size. - end_ = begin_ + *size_ptr_; - } - - // Caller responsible for not exceeding specified size. Do not call after FinishedInserting. - template <class T> void Insert(const T &t) { -#ifdef DEBUG - assert(initialized_); - assert(!loaded_); -#endif - *end_ = t; - ++end_; - } - - void FinishedInserting() { -#ifdef DEBUG - assert(initialized_); - assert(!loaded_); - loaded_ = true; -#endif - std::sort(begin_, end_); - *size_ptr_ = (end_ - begin_); - } - - // Don't use this to change the key. - template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) { -#ifdef DEBUG - assert(initialized_); - assert(loaded_); -#endif - return SortedUniformFind<MutableIterator, Accessor, Pivot64>(begin_, end_, key, out); - } - - // Do not call before FinishedInserting. - template <class Key> bool Find(const Key key, ConstIterator &out) const { -#ifdef DEBUG - assert(initialized_); - assert(loaded_); -#endif - return SortedUniformFind<ConstIterator, Accessor, Pivot64>(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out); - } - - ConstIterator begin() const { return begin_; } - ConstIterator end() const { return end_; } - - private: - typename Packing::MutableIterator begin_, end_; - uint64_t *size_ptr_; -#ifdef DEBUG - bool initialized_; - bool loaded_; -#endif -}; - } // namespace util #endif // UTIL_SORTED_UNIFORM__ diff --git a/klm/util/sorted_uniform_test.cc b/klm/util/sorted_uniform_test.cc index 4aa4c8aa..d9f6fad1 100644 --- a/klm/util/sorted_uniform_test.cc +++ b/klm/util/sorted_uniform_test.cc @@ -1,12 +1,11 @@ #include "util/sorted_uniform.hh" -#include "util/key_value_packing.hh" - #include <boost/random/mersenne_twister.hpp> #include <boost/random/uniform_int.hpp> #include <boost/random/variate_generator.hpp> #include <boost/scoped_array.hpp> #include <boost/unordered_map.hpp> + #define BOOST_TEST_MODULE SortedUniformTest #include <boost/test/unit_test.hpp> @@ -17,74 +16,86 @@ namespace util { namespace { -template <class Map, class Key, class Value> void Check(const Map &map, const boost::unordered_map<Key, Value> &reference, const Key key) { +template <class KeyT, class ValueT> struct Entry { + typedef KeyT Key; + typedef ValueT Value; + + Key key; + Value value; + + Key GetKey() const { + return key; + } + + Value GetValue() const { + return value; + } + + bool operator<(const Entry<Key,Value> &other) const { + return key < other.key; + } +}; + +template <class KeyT> struct Accessor { + typedef KeyT Key; + template <class Value> Key operator()(const Entry<Key, Value> *entry) const { + return entry->GetKey(); + } +}; + +template <class Key, class Value> void Check(const Entry<Key, Value> *begin, const Entry<Key, Value> *end, const boost::unordered_map<Key, Value> &reference, const Key key) { typename boost::unordered_map<Key, Value>::const_iterator ref = reference.find(key); - typename Map::ConstIterator i = typename Map::ConstIterator(); + typedef const Entry<Key, Value> *It; + // g++ can't tell that require will crash and burn. + It i = NULL; + bool ret = SortedUniformFind<It, Accessor<Key>, Pivot64>(Accessor<Key>(), begin, end, key, i); if (ref == reference.end()) { - BOOST_CHECK(!map.Find(key, i)); + BOOST_CHECK(!ret); } else { - // g++ can't tell that require will crash and burn. - BOOST_REQUIRE(map.Find(key, i)); + BOOST_REQUIRE(ret); BOOST_CHECK_EQUAL(ref->second, i->GetValue()); } } -typedef SortedUniformMap<AlignedPacking<uint64_t, uint32_t> > TestMap; - BOOST_AUTO_TEST_CASE(empty) { - char buf[TestMap::Size(0)]; - TestMap map(buf, TestMap::Size(0)); - map.FinishedInserting(); - TestMap::ConstIterator i; - BOOST_CHECK(!map.Find(42, i)); -} - -BOOST_AUTO_TEST_CASE(one) { - char buf[TestMap::Size(1)]; - TestMap map(buf, sizeof(buf)); - Entry<uint64_t, uint32_t> e; - e.Set(42,2); - map.Insert(e); - map.FinishedInserting(); - TestMap::ConstIterator i = TestMap::ConstIterator(); - BOOST_REQUIRE(map.Find(42, i)); - BOOST_CHECK(i == map.begin()); - BOOST_CHECK(!map.Find(43, i)); - BOOST_CHECK(!map.Find(41, i)); + typedef const Entry<uint64_t, float> T; + const T *i; + bool ret = SortedUniformFind<const T*, Accessor<uint64_t>, Pivot64>(Accessor<uint64_t>(), (const T*)NULL, (const T*)NULL, (uint64_t)10, i); + BOOST_CHECK(!ret); } template <class Key> void RandomTest(Key upper, size_t entries, size_t queries) { typedef unsigned char Value; - typedef SortedUniformMap<AlignedPacking<Key, unsigned char> > Map; - boost::scoped_array<char> buffer(new char[Map::Size(entries)]); - Map map(buffer.get(), entries); boost::mt19937 rng; boost::uniform_int<Key> range_key(0, upper); boost::uniform_int<Value> range_value(0, 255); boost::variate_generator<boost::mt19937&, boost::uniform_int<Key> > gen_key(rng, range_key); boost::variate_generator<boost::mt19937&, boost::uniform_int<unsigned char> > gen_value(rng, range_value); + typedef Entry<Key, Value> Ent; + std::vector<Ent> backing; boost::unordered_map<Key, unsigned char> reference; - Entry<Key, unsigned char> ent; + Ent ent; for (size_t i = 0; i < entries; ++i) { Key key = gen_key(); unsigned char value = gen_value(); if (reference.insert(std::make_pair(key, value)).second) { - ent.Set(key, value); - map.Insert(Entry<Key, unsigned char>(ent)); + ent.key = key; + ent.value = value; + backing.push_back(ent); } } - map.FinishedInserting(); + std::sort(backing.begin(), backing.end()); // Random queries. for (size_t i = 0; i < queries; ++i) { const Key key = gen_key(); - Check<Map, Key, unsigned char>(map, reference, key); + Check<Key, unsigned char>(&*backing.begin(), &*backing.end(), reference, key); } typename boost::unordered_map<Key, unsigned char>::const_iterator it = reference.begin(); for (size_t i = 0; (i < queries) && (it != reference.end()); ++i, ++it) { - Check<Map, Key, unsigned char>(map, reference, it->second); + Check<Key, unsigned char>(&*backing.begin(), &*backing.end(), reference, it->second); } } diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index 413bda0b..c7e1c863 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -1,6 +1,7 @@ #ifndef UTIL_TOKENIZE_PIECE__ #define UTIL_TOKENIZE_PIECE__ +#include "util/exception.hh" #include "util/string_piece.hh" #include <boost/iterator/iterator_facade.hpp> @@ -8,63 +9,25 @@ #include <algorithm> #include <iostream> -/* Usage: - * - * for (PieceIterator<' '> i(" foo \r\n bar "); i; ++i) { - * std::cout << *i << "\n"; - * } - * - */ - namespace util { -// Tokenize a StringPiece using an iterator interface. boost::tokenizer doesn't work with StringPiece. -template <char d> class PieceIterator : public boost::iterator_facade<PieceIterator<d>, const StringPiece, boost::forward_traversal_tag> { +// Thrown on dereference when out of tokens to parse +class OutOfTokens : public Exception { public: - // Default construct is end, which is also accessed by kEndPieceIterator; - PieceIterator() {} - - explicit PieceIterator(const StringPiece &str) - : after_(str) { - increment(); - } + OutOfTokens() throw() {} + ~OutOfTokens() throw() {} +}; - bool operator!() const { - return after_.data() == 0; - } - operator bool() const { - return after_.data() != 0; - } +class SingleCharacter { + public: + explicit SingleCharacter(char delim) : delim_(delim) {} - static PieceIterator<d> end() { - return PieceIterator<d>(); + StringPiece Find(const StringPiece &in) const { + return StringPiece(std::find(in.data(), in.data() + in.size(), delim_), 1); } private: - friend class boost::iterator_core_access; - - void increment() { - const char *start = after_.data(); - for (; (start != after_.data() + after_.size()) && (d == *start); ++start) {} - if (start == after_.data() + after_.size()) { - // End condition. - after_.clear(); - return; - } - const char *finish = start; - for (; (finish != after_.data() + after_.size()) && (d != *finish); ++finish) {} - current_ = StringPiece(start, finish - start); - after_ = StringPiece(finish, after_.data() + after_.size() - finish); - } - - bool equal(const PieceIterator &other) const { - return after_.data() == other.after_.data(); - } - - const StringPiece &dereference() const { return current_; } - - StringPiece current_; - StringPiece after_; + char delim_; }; class MultiCharacter { @@ -95,7 +58,7 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it public: TokenIter() {} - TokenIter(const StringPiece &str, const Find &finder) : after_(str), finder_(finder) { + template <class Construct> TokenIter(const StringPiece &str, const Construct &construct) : after_(str), finder_(construct) { increment(); } @@ -130,6 +93,7 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it } const StringPiece &dereference() const { + UTIL_THROW_IF(!current_.data(), OutOfTokens, "Ran out of tokens"); return current_; } diff --git a/klm/util/tokenize_piece_test.cc b/klm/util/tokenize_piece_test.cc index e07ebcf5..d856018f 100644 --- a/klm/util/tokenize_piece_test.cc +++ b/klm/util/tokenize_piece_test.cc @@ -9,53 +9,7 @@ namespace util { namespace { -BOOST_AUTO_TEST_CASE(simple) { - PieceIterator<' '> it("single spaced words."); - BOOST_REQUIRE(it); - BOOST_CHECK_EQUAL(StringPiece("single"), *it); - ++it; - BOOST_REQUIRE(it); - BOOST_CHECK_EQUAL(StringPiece("spaced"), *it); - ++it; - BOOST_REQUIRE(it); - BOOST_CHECK_EQUAL(StringPiece("words."), *it); - ++it; - BOOST_CHECK(!it); -} - -BOOST_AUTO_TEST_CASE(null_delimiter) { - const char str[] = "\0first\0\0second\0\0\0third\0fourth\0\0\0"; - PieceIterator<'\0'> it(StringPiece(str, sizeof(str) - 1)); - BOOST_REQUIRE(it); - BOOST_CHECK_EQUAL(StringPiece("first"), *it); - ++it; - BOOST_REQUIRE(it); - BOOST_CHECK_EQUAL(StringPiece("second"), *it); - ++it; - BOOST_REQUIRE(it); - BOOST_CHECK_EQUAL(StringPiece("third"), *it); - ++it; - BOOST_REQUIRE(it); - BOOST_CHECK_EQUAL(StringPiece("fourth"), *it); - ++it; - BOOST_CHECK(!it); -} - -BOOST_AUTO_TEST_CASE(null_entries) { - const char str[] = "\0split\0\0 \0me\0 "; - PieceIterator<' '> it(StringPiece(str, sizeof(str) - 1)); - BOOST_REQUIRE(it); - const char first[] = "\0split\0\0"; - BOOST_CHECK_EQUAL(StringPiece(first, sizeof(first) - 1), *it); - ++it; - BOOST_REQUIRE(it); - const char second[] = "\0me\0"; - BOOST_CHECK_EQUAL(StringPiece(second, sizeof(second) - 1), *it); - ++it; - BOOST_CHECK(!it); -} - -/*BOOST_AUTO_TEST_CASE(pipe_pipe_none) { +BOOST_AUTO_TEST_CASE(pipe_pipe_none) { const char str[] = "nodelimit at all"; TokenIter<MultiCharacter> it(str, MultiCharacter("|||")); BOOST_REQUIRE(it); @@ -79,7 +33,7 @@ BOOST_AUTO_TEST_CASE(remove_empty) { const char str[] = "|||"; TokenIter<MultiCharacter, true> it(str, MultiCharacter("|||")); BOOST_CHECK(!it); -}*/ +} BOOST_AUTO_TEST_CASE(remove_empty_keep) { const char str[] = " |||"; |