From d3e2ec203a5cf550320caa8023ac3dd103b0be7d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 13 Oct 2014 00:42:37 -0400 Subject: new kenlm --- klm/util/read_compressed.cc | 417 +++++++++++++++++++++----------------------- 1 file changed, 195 insertions(+), 222 deletions(-) (limited to 'klm/util/read_compressed.cc') diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc index b62a6e83..cee98040 100644 --- a/klm/util/read_compressed.cc +++ b/klm/util/read_compressed.cc @@ -49,6 +49,8 @@ class ReadBase { thunk.internal_.reset(with); } + ReadBase *Current(ReadCompressed &thunk) { return thunk.internal_.get(); } + static uint64_t &ReadCount(ReadCompressed &thunk) { return thunk.raw_amount_; } @@ -56,6 +58,8 @@ class ReadBase { namespace { +ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, std::size_t already_size, bool require_compressed); + // Completed file that other classes can thunk to. class Complete : public ReadBase { public: @@ -80,7 +84,7 @@ class Uncompressed : public ReadBase { class UncompressedWithHeader : public ReadBase { public: - UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) { + UncompressedWithHeader(int fd, const void *already_data, std::size_t already_size) : fd_(fd) { assert(already_size); buf_.reset(malloc(already_size)); if (!buf_.get()) throw std::bad_alloc(); @@ -91,6 +95,7 @@ class UncompressedWithHeader : public ReadBase { std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { assert(buf_.get()); + assert(remain_ != end_); std::size_t sending = std::min(amount, end_ - remain_); memcpy(to, remain_, sending); remain_ += sending; @@ -108,23 +113,51 @@ class UncompressedWithHeader : public ReadBase { scoped_fd fd_; }; -#ifdef HAVE_ZLIB -class GZip : public ReadBase { +static const std::size_t kInputBuffer = 16384; + +template class StreamCompressed : public ReadBase { + public: + StreamCompressed(int fd, const void *already_data, std::size_t already_size) + : file_(fd), + in_buffer_(MallocOrThrow(kInputBuffer)), + back_(memcpy(in_buffer_.get(), already_data, already_size), already_size) {} + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (amount == 0) return 0; + back_.SetOutput(to, amount); + do { + if (!back_.Stream().avail_in) ReadInput(thunk); + if (!back_.Process()) { + // reached end, at least for the compressed portion. + std::size_t ret = static_cast(static_cast(back_.Stream().next_out)) - static_cast(to); + ReplaceThis(ReadFactory(file_.release(), ReadCount(thunk), back_.Stream().next_in, back_.Stream().avail_in, true), thunk); + if (ret) return ret; + // We did not read anything this round, so clients might think EOF. Transfer responsibility to the next reader. + return Current(thunk)->Read(to, amount, thunk); + } + } while (back_.Stream().next_out == to); + return static_cast(static_cast(back_.Stream().next_out)) - static_cast(to); + } + private: - static const std::size_t kInputBuffer = 16384; + void ReadInput(ReadCompressed &thunk) { + assert(!back_.Stream().avail_in); + std::size_t got = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); + back_.SetInput(in_buffer_.get(), got); + ReadCount(thunk) += got; + } + + scoped_fd file_; + scoped_malloc in_buffer_; + + Compression back_; +}; + +#ifdef HAVE_ZLIB +class GZip { public: - GZip(int fd, void *already_data, std::size_t already_size) - : file_(fd), in_buffer_(malloc(kInputBuffer)) { - if (!in_buffer_.get()) throw std::bad_alloc(); - assert(already_size < kInputBuffer); - if (already_size) { - memcpy(in_buffer_.get(), already_data, already_size); - stream_.next_in = static_cast(in_buffer_.get()); - stream_.avail_in = already_size; - stream_.avail_in += ReadOrEOF(file_.get(), static_cast(in_buffer_.get()) + already_size, kInputBuffer - already_size); - } else { - stream_.avail_in = 0; - } + GZip(const void *base, std::size_t amount) { + SetInput(base, amount); stream_.zalloc = Z_NULL; stream_.zfree = Z_NULL; stream_.opaque = Z_NULL; @@ -141,227 +174,154 @@ class GZip : public ReadBase { } } - std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { - if (amount == 0) return 0; + void SetOutput(void *to, std::size_t amount) { stream_.next_out = static_cast(to); stream_.avail_out = std::min(std::numeric_limits::max(), amount); - do { - if (!stream_.avail_in) ReadInput(thunk); - int result = inflate(&stream_, 0); - switch (result) { - case Z_OK: - break; - case Z_STREAM_END: - { - std::size_t ret = static_cast(stream_.next_out) - static_cast(to); - ReplaceThis(new Complete(), thunk); - return ret; - } - case Z_ERRNO: - UTIL_THROW(ErrnoException, "zlib error"); - default: - UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result); - } - } while (stream_.next_out == to); - return static_cast(stream_.next_out) - static_cast(to); } - private: - void ReadInput(ReadCompressed &thunk) { - assert(!stream_.avail_in); - stream_.next_in = static_cast(in_buffer_.get()); - stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); - ReadCount(thunk) += stream_.avail_in; + void SetInput(const void *base, std::size_t amount) { + assert(amount < static_cast(std::numeric_limits::max())); + stream_.next_in = const_cast(static_cast(base)); + stream_.avail_in = amount; } - scoped_fd file_; - scoped_malloc in_buffer_; + const z_stream &Stream() const { return stream_; } + + bool Process() { + int result = inflate(&stream_, 0); + switch (result) { + case Z_OK: + return true; + case Z_STREAM_END: + return false; + case Z_ERRNO: + UTIL_THROW(ErrnoException, "zlib error"); + default: + UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result); + } + } + + private: z_stream stream_; }; #endif // HAVE_ZLIB -const uint8_t kBZMagic[3] = {'B', 'Z', 'h'}; - #ifdef HAVE_BZLIB -class BZip : public ReadBase { +class BZip { public: - BZip(int fd, void *already_data, std::size_t already_size) { - scoped_fd hold(fd); - closer_.reset(FDOpenReadOrThrow(hold)); - file_ = NULL; - Open(already_data, already_size); + BZip(const void *base, std::size_t amount) { + memset(&stream_, 0, sizeof(stream_)); + SetInput(base, amount); + HandleError(BZ2_bzDecompressInit(&stream_, 0, 0)); } - BZip(FILE *file, void *already_data, std::size_t already_size) { - closer_.reset(file); - file_ = NULL; - Open(already_data, already_size); + ~BZip() { + try { + HandleError(BZ2_bzDecompressEnd(&stream_)); + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } } - ~BZip() { - Close(file_); + bool Process() { + int ret = BZ2_bzDecompress(&stream_); + if (ret == BZ_STREAM_END) return false; + HandleError(ret); + return true; } - std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { - assert(file_); - int bzerror = BZ_OK; - int ret = BZ2_bzRead(&bzerror, file_, to, std::min(static_cast(INT_MAX), amount)); - long pos = ftell(closer_.get()); - if (pos != -1) ReadCount(thunk) = pos; - switch (bzerror) { - case BZ_STREAM_END: - /* bzip2 files can be concatenated by e.g. pbzip2. Annoyingly, the - * library doesn't handle this internally. This gets the trailing - * data, grows it up to magic as needed, validates the magic, and - * reopens. - */ - { - bzerror = BZ_OK; - void *trailing_data; - int trailing_size; - BZ2_bzReadGetUnused(&bzerror, file_, &trailing_data, &trailing_size); - UTIL_THROW_IF(bzerror != BZ_OK, BZException, "bzip2 error in BZ2_bzReadGetUnused " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror); - std::string trailing(static_cast(trailing_data), trailing_size); - Close(file_); - - if (trailing_size < (int)sizeof(kBZMagic)) { - trailing.resize(sizeof(kBZMagic)); - if (1 != fread(&trailing[trailing_size], sizeof(kBZMagic) - trailing_size, 1, closer_.get())) { - UTIL_THROW_IF(trailing_size, BZException, "File has trailing cruft"); - // Legitimate end of file. - ReplaceThis(new Complete(), thunk); - return ret; - } - } - UTIL_THROW_IF(memcmp(trailing.data(), kBZMagic, sizeof(kBZMagic)), BZException, "Trailing cruft is not another bzip2 stream"); - Open(&trailing[0], trailing.size()); - } - return ret; - case BZ_OK: - return ret; - default: - UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror); - } + void SetOutput(void *base, std::size_t amount) { + stream_.next_out = static_cast(base); + stream_.avail_out = std::min(std::numeric_limits::max(), amount); } + void SetInput(const void *base, std::size_t amount) { + stream_.next_in = const_cast(static_cast(base)); + stream_.avail_in = amount; + } + + const bz_stream &Stream() const { return stream_; } + private: - void Open(void *already_data, std::size_t already_size) { - assert(!file_); - int bzerror = BZ_OK; - file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size); - switch (bzerror) { + void HandleError(int value) { + switch(value) { case BZ_OK: return; case BZ_CONFIG_ERROR: - UTIL_THROW(BZException, "Looks like bzip2 was miscompiled."); + UTIL_THROW(BZException, "bzip2 seems to be miscompiled."); case BZ_PARAM_ERROR: - UTIL_THROW(BZException, "Parameter error"); - case BZ_IO_ERROR: - UTIL_THROW(BZException, "IO error reading file"); + UTIL_THROW(BZException, "bzip2 Parameter error"); + case BZ_DATA_ERROR: + UTIL_THROW(BZException, "bzip2 detected a corrupt file"); + case BZ_DATA_ERROR_MAGIC: + UTIL_THROW(BZException, "bzip2 detected bad magic bytes. Perhaps this was not a bzip2 file after all?"); case BZ_MEM_ERROR: throw std::bad_alloc(); default: - UTIL_THROW(BZException, "Unknown bzip2 error code " << bzerror); + UTIL_THROW(BZException, "Unknown bzip2 error code " << value); } - assert(file_); } - static void Close(BZFILE *&file) { - if (file == NULL) return; - int bzerror = BZ_OK; - BZ2_bzReadClose(&bzerror, file); - if (bzerror != BZ_OK) { - std::cerr << "bz2 readclose error number " << bzerror << std::endl; - abort(); - } - file = NULL; - } - - scoped_FILE closer_; - BZFILE *file_; + bz_stream stream_; }; #endif // HAVE_BZLIB #ifdef HAVE_XZLIB -class XZip : public ReadBase { - private: - static const std::size_t kInputBuffer = 16384; +class XZip { public: - XZip(int fd, void *already_data, std::size_t already_size) - : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) { - if (!in_buffer_.get()) throw std::bad_alloc(); - assert(already_size < kInputBuffer); - if (already_size) { - memcpy(in_buffer_.get(), already_data, already_size); - stream_.next_in = static_cast(in_buffer_.get()); - stream_.avail_in = already_size; - stream_.avail_in += ReadOrEOF(file_.get(), static_cast(in_buffer_.get()) + already_size, kInputBuffer - already_size); - } else { - stream_.avail_in = 0; - } - stream_.allocator = NULL; - lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED); - switch (ret) { - case LZMA_OK: - break; - case LZMA_MEM_ERROR: - UTIL_THROW(ErrnoException, "xz open error"); - default: - UTIL_THROW(XZException, "xz error code " << ret); - } + XZip(const void *base, std::size_t amount) + : stream_(), action_(LZMA_RUN) { + memset(&stream_, 0, sizeof(stream_)); + SetInput(base, amount); + HandleError(lzma_stream_decoder(&stream_, UINT64_MAX, 0)); } ~XZip() { lzma_end(&stream_); } - std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { - if (amount == 0) return 0; - stream_.next_out = static_cast(to); + void SetOutput(void *base, std::size_t amount) { + stream_.next_out = static_cast(base); stream_.avail_out = amount; - do { - if (!stream_.avail_in) ReadInput(thunk); - lzma_ret status = lzma_code(&stream_, action_); - switch (status) { - case LZMA_OK: - break; - case LZMA_STREAM_END: - UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet."); - { - std::size_t ret = static_cast(stream_.next_out) - static_cast(to); - ReplaceThis(new Complete(), thunk); - return ret; - } - case LZMA_MEM_ERROR: - throw std::bad_alloc(); - case LZMA_FORMAT_ERROR: - UTIL_THROW(XZException, "xzlib says file format not recognized"); - case LZMA_OPTIONS_ERROR: - UTIL_THROW(XZException, "xzlib says unsupported compression options"); - case LZMA_DATA_ERROR: - UTIL_THROW(XZException, "xzlib says this file is corrupt"); - case LZMA_BUF_ERROR: - UTIL_THROW(XZException, "xzlib says unexpected end of input"); - default: - UTIL_THROW(XZException, "unrecognized xzlib error " << status); - } - } while (stream_.next_out == to); - return static_cast(stream_.next_out) - static_cast(to); + } + + void SetInput(const void *base, std::size_t amount) { + stream_.next_in = static_cast(base); + stream_.avail_in = amount; + if (!amount) action_ = LZMA_FINISH; + } + + const lzma_stream &Stream() const { return stream_; } + + bool Process() { + lzma_ret status = lzma_code(&stream_, action_); + if (status == LZMA_STREAM_END) return false; + HandleError(status); + return true; } private: - void ReadInput(ReadCompressed &thunk) { - assert(!stream_.avail_in); - stream_.next_in = static_cast(in_buffer_.get()); - stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); - if (!stream_.avail_in) action_ = LZMA_FINISH; - ReadCount(thunk) += stream_.avail_in; + void HandleError(lzma_ret value) { + switch (value) { + case LZMA_OK: + return; + case LZMA_MEM_ERROR: + throw std::bad_alloc(); + case LZMA_FORMAT_ERROR: + UTIL_THROW(XZException, "xzlib says file format not recognized"); + case LZMA_OPTIONS_ERROR: + UTIL_THROW(XZException, "xzlib says unsupported compression options"); + case LZMA_DATA_ERROR: + UTIL_THROW(XZException, "xzlib says this file is corrupt"); + case LZMA_BUF_ERROR: + UTIL_THROW(XZException, "xzlib says unexpected end of input"); + default: + UTIL_THROW(XZException, "unrecognized xzlib error " << value); + } } - scoped_fd file_; - scoped_malloc in_buffer_; lzma_stream stream_; - lzma_action action_; }; #endif // HAVE_XZLIB @@ -384,66 +344,67 @@ class IStreamReader : public ReadBase { }; enum MagicResult { - UNKNOWN, GZIP, BZIP, XZIP + UTIL_UNKNOWN, UTIL_GZIP, UTIL_BZIP, UTIL_XZIP }; -MagicResult DetectMagic(const void *from_void) { +MagicResult DetectMagic(const void *from_void, std::size_t length) { const uint8_t *header = static_cast(from_void); - if (header[0] == 0x1f && header[1] == 0x8b) { - return GZIP; + if (length >= 2 && header[0] == 0x1f && header[1] == 0x8b) { + return UTIL_GZIP; } - if (!memcmp(header, kBZMagic, sizeof(kBZMagic))) { - return BZIP; + const uint8_t kBZMagic[3] = {'B', 'Z', 'h'}; + if (length >= sizeof(kBZMagic) && !memcmp(header, kBZMagic, sizeof(kBZMagic))) { + return UTIL_BZIP; } const uint8_t kXZMagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 }; - if (!memcmp(header, kXZMagic, sizeof(kXZMagic))) { - return XZIP; + if (length >= sizeof(kXZMagic) && !memcmp(header, kXZMagic, sizeof(kXZMagic))) { + return UTIL_XZIP; } - return UNKNOWN; + return UTIL_UNKNOWN; } -ReadBase *ReadFactory(int fd, uint64_t &raw_amount) { +ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, const std::size_t already_size, bool require_compressed) { scoped_fd hold(fd); - unsigned char header[ReadCompressed::kMagicSize]; - raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize); - if (!raw_amount) - return new Uncompressed(hold.release()); - if (raw_amount != ReadCompressed::kMagicSize) - return new UncompressedWithHeader(hold.release(), header, raw_amount); - switch (DetectMagic(header)) { - case GZIP: + std::string header(reinterpret_cast(already_data), already_size); + if (header.size() < ReadCompressed::kMagicSize) { + std::size_t original = header.size(); + header.resize(ReadCompressed::kMagicSize); + std::size_t got = ReadOrEOF(fd, &header[original], ReadCompressed::kMagicSize - original); + raw_amount += got; + header.resize(original + got); + } + if (header.empty()) { + return new Complete(); + } + switch (DetectMagic(&header[0], header.size())) { + case UTIL_GZIP: #ifdef HAVE_ZLIB - return new GZip(hold.release(), header, ReadCompressed::kMagicSize); + return new StreamCompressed(hold.release(), header.data(), header.size()); #else UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in."); #endif - case BZIP: + case UTIL_BZIP: #ifdef HAVE_BZLIB - return new BZip(hold.release(), header, ReadCompressed::kMagicSize); + return new StreamCompressed(hold.release(), &header[0], header.size()); #else - UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in."); + UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZh), but bzip support was not compiled in."); #endif - case XZIP: + case UTIL_XZIP: #ifdef HAVE_XZLIB - return new XZip(hold.release(), header, ReadCompressed::kMagicSize); + return new StreamCompressed(hold.release(), header.data(), header.size()); #else UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in."); #endif - case UNKNOWN: - break; - } - try { - SeekOrThrow(fd, 0); - } catch (const util::ErrnoException &e) { - return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize); + default: + UTIL_THROW_IF(require_compressed, CompressedException, "Uncompressed data detected after a compresssed file. This could be supported but usually indicates an error."); + return new UncompressedWithHeader(hold.release(), header.data(), header.size()); } - return new Uncompressed(hold.release()); } } // namespace bool ReadCompressed::DetectCompressedMagic(const void *from_void) { - return DetectMagic(from_void) != UNKNOWN; + return DetectMagic(from_void, kMagicSize) != UTIL_UNKNOWN; } ReadCompressed::ReadCompressed(int fd) { @@ -459,8 +420,9 @@ ReadCompressed::ReadCompressed() {} ReadCompressed::~ReadCompressed() {} void ReadCompressed::Reset(int fd) { + raw_amount_ = 0; internal_.reset(); - internal_.reset(ReadFactory(fd, raw_amount_)); + internal_.reset(ReadFactory(fd, raw_amount_, NULL, 0, false)); } void ReadCompressed::Reset(std::istream &in) { @@ -472,4 +434,15 @@ std::size_t ReadCompressed::Read(void *to, std::size_t amount) { return internal_->Read(to, amount, *this); } +std::size_t ReadCompressed::ReadOrEOF(void *const to_in, std::size_t amount) { + uint8_t *to = reinterpret_cast(to_in); + while (amount) { + std::size_t got = Read(to, amount); + if (!got) break; + to += got; + amount -= got; + } + return to - reinterpret_cast(to_in); +} + } // namespace util -- cgit v1.2.3