summaryrefslogtreecommitdiff
path: root/klm/util/read_compressed.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-10-13 00:42:37 -0400
committerChris Dyer <redpony@gmail.com>2014-10-13 00:42:37 -0400
commitb1ed81ef3216b212295afa76c5d20a56fb647204 (patch)
tree9633cdc1b8a341dfa58b0b7fec0e2cae44d28835 /klm/util/read_compressed.cc
parent1b17f61d359be6e1c3cea29f8c100db3bcdd73a0 (diff)
new kenlm
Diffstat (limited to 'klm/util/read_compressed.cc')
-rw-r--r--klm/util/read_compressed.cc417
1 files changed, 195 insertions, 222 deletions
diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc
index b62a6e83..cee98040 100644
--- a/klm/util/read_compressed.cc
+++ b/klm/util/read_compressed.cc
@@ -49,6 +49,8 @@ class ReadBase {
thunk.internal_.reset(with);
}
+ ReadBase *Current(ReadCompressed &thunk) { return thunk.internal_.get(); }
+
static uint64_t &ReadCount(ReadCompressed &thunk) {
return thunk.raw_amount_;
}
@@ -56,6 +58,8 @@ class ReadBase {
namespace {
+ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, std::size_t already_size, bool require_compressed);
+
// Completed file that other classes can thunk to.
class Complete : public ReadBase {
public:
@@ -80,7 +84,7 @@ class Uncompressed : public ReadBase {
class UncompressedWithHeader : public ReadBase {
public:
- UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) {
+ UncompressedWithHeader(int fd, const void *already_data, std::size_t already_size) : fd_(fd) {
assert(already_size);
buf_.reset(malloc(already_size));
if (!buf_.get()) throw std::bad_alloc();
@@ -91,6 +95,7 @@ class UncompressedWithHeader : public ReadBase {
std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
assert(buf_.get());
+ assert(remain_ != end_);
std::size_t sending = std::min<std::size_t>(amount, end_ - remain_);
memcpy(to, remain_, sending);
remain_ += sending;
@@ -108,23 +113,51 @@ class UncompressedWithHeader : public ReadBase {
scoped_fd fd_;
};
-#ifdef HAVE_ZLIB
-class GZip : public ReadBase {
+static const std::size_t kInputBuffer = 16384;
+
+template <class Compression> class StreamCompressed : public ReadBase {
+ public:
+ StreamCompressed(int fd, const void *already_data, std::size_t already_size)
+ : file_(fd),
+ in_buffer_(MallocOrThrow(kInputBuffer)),
+ back_(memcpy(in_buffer_.get(), already_data, already_size), already_size) {}
+
+ std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
+ if (amount == 0) return 0;
+ back_.SetOutput(to, amount);
+ do {
+ if (!back_.Stream().avail_in) ReadInput(thunk);
+ if (!back_.Process()) {
+ // reached end, at least for the compressed portion.
+ std::size_t ret = static_cast<const uint8_t *>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to);
+ ReplaceThis(ReadFactory(file_.release(), ReadCount(thunk), back_.Stream().next_in, back_.Stream().avail_in, true), thunk);
+ if (ret) return ret;
+ // We did not read anything this round, so clients might think EOF. Transfer responsibility to the next reader.
+ return Current(thunk)->Read(to, amount, thunk);
+ }
+ } while (back_.Stream().next_out == to);
+ return static_cast<const uint8_t*>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to);
+ }
+
private:
- static const std::size_t kInputBuffer = 16384;
+ void ReadInput(ReadCompressed &thunk) {
+ assert(!back_.Stream().avail_in);
+ std::size_t got = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
+ back_.SetInput(in_buffer_.get(), got);
+ ReadCount(thunk) += got;
+ }
+
+ scoped_fd file_;
+ scoped_malloc in_buffer_;
+
+ Compression back_;
+};
+
+#ifdef HAVE_ZLIB
+class GZip {
public:
- GZip(int fd, void *already_data, std::size_t already_size)
- : file_(fd), in_buffer_(malloc(kInputBuffer)) {
- if (!in_buffer_.get()) throw std::bad_alloc();
- assert(already_size < kInputBuffer);
- if (already_size) {
- memcpy(in_buffer_.get(), already_data, already_size);
- stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
- stream_.avail_in = already_size;
- stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size);
- } else {
- stream_.avail_in = 0;
- }
+ GZip(const void *base, std::size_t amount) {
+ SetInput(base, amount);
stream_.zalloc = Z_NULL;
stream_.zfree = Z_NULL;
stream_.opaque = Z_NULL;
@@ -141,227 +174,154 @@ class GZip : public ReadBase {
}
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- if (amount == 0) return 0;
+ void SetOutput(void *to, std::size_t amount) {
stream_.next_out = static_cast<Bytef*>(to);
stream_.avail_out = std::min<std::size_t>(std::numeric_limits<uInt>::max(), amount);
- do {
- if (!stream_.avail_in) ReadInput(thunk);
- int result = inflate(&stream_, 0);
- switch (result) {
- case Z_OK:
- break;
- case Z_STREAM_END:
- {
- std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
- ReplaceThis(new Complete(), thunk);
- return ret;
- }
- case Z_ERRNO:
- UTIL_THROW(ErrnoException, "zlib error");
- default:
- UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result);
- }
- } while (stream_.next_out == to);
- return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
}
- private:
- void ReadInput(ReadCompressed &thunk) {
- assert(!stream_.avail_in);
- stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
- stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
- ReadCount(thunk) += stream_.avail_in;
+ void SetInput(const void *base, std::size_t amount) {
+ assert(amount < static_cast<std::size_t>(std::numeric_limits<uInt>::max()));
+ stream_.next_in = const_cast<Bytef*>(static_cast<const Bytef*>(base));
+ stream_.avail_in = amount;
}
- scoped_fd file_;
- scoped_malloc in_buffer_;
+ const z_stream &Stream() const { return stream_; }
+
+ bool Process() {
+ int result = inflate(&stream_, 0);
+ switch (result) {
+ case Z_OK:
+ return true;
+ case Z_STREAM_END:
+ return false;
+ case Z_ERRNO:
+ UTIL_THROW(ErrnoException, "zlib error");
+ default:
+ UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result);
+ }
+ }
+
+ private:
z_stream stream_;
};
#endif // HAVE_ZLIB
-const uint8_t kBZMagic[3] = {'B', 'Z', 'h'};
-
#ifdef HAVE_BZLIB
-class BZip : public ReadBase {
+class BZip {
public:
- BZip(int fd, void *already_data, std::size_t already_size) {
- scoped_fd hold(fd);
- closer_.reset(FDOpenReadOrThrow(hold));
- file_ = NULL;
- Open(already_data, already_size);
+ BZip(const void *base, std::size_t amount) {
+ memset(&stream_, 0, sizeof(stream_));
+ SetInput(base, amount);
+ HandleError(BZ2_bzDecompressInit(&stream_, 0, 0));
}
- BZip(FILE *file, void *already_data, std::size_t already_size) {
- closer_.reset(file);
- file_ = NULL;
- Open(already_data, already_size);
+ ~BZip() {
+ try {
+ HandleError(BZ2_bzDecompressEnd(&stream_));
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ abort();
+ }
}
- ~BZip() {
- Close(file_);
+ bool Process() {
+ int ret = BZ2_bzDecompress(&stream_);
+ if (ret == BZ_STREAM_END) return false;
+ HandleError(ret);
+ return true;
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- assert(file_);
- int bzerror = BZ_OK;
- int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount));
- long pos = ftell(closer_.get());
- if (pos != -1) ReadCount(thunk) = pos;
- switch (bzerror) {
- case BZ_STREAM_END:
- /* bzip2 files can be concatenated by e.g. pbzip2. Annoyingly, the
- * library doesn't handle this internally. This gets the trailing
- * data, grows it up to magic as needed, validates the magic, and
- * reopens.
- */
- {
- bzerror = BZ_OK;
- void *trailing_data;
- int trailing_size;
- BZ2_bzReadGetUnused(&bzerror, file_, &trailing_data, &trailing_size);
- UTIL_THROW_IF(bzerror != BZ_OK, BZException, "bzip2 error in BZ2_bzReadGetUnused " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
- std::string trailing(static_cast<const char*>(trailing_data), trailing_size);
- Close(file_);
-
- if (trailing_size < (int)sizeof(kBZMagic)) {
- trailing.resize(sizeof(kBZMagic));
- if (1 != fread(&trailing[trailing_size], sizeof(kBZMagic) - trailing_size, 1, closer_.get())) {
- UTIL_THROW_IF(trailing_size, BZException, "File has trailing cruft");
- // Legitimate end of file.
- ReplaceThis(new Complete(), thunk);
- return ret;
- }
- }
- UTIL_THROW_IF(memcmp(trailing.data(), kBZMagic, sizeof(kBZMagic)), BZException, "Trailing cruft is not another bzip2 stream");
- Open(&trailing[0], trailing.size());
- }
- return ret;
- case BZ_OK:
- return ret;
- default:
- UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
- }
+ void SetOutput(void *base, std::size_t amount) {
+ stream_.next_out = static_cast<char*>(base);
+ stream_.avail_out = std::min<std::size_t>(std::numeric_limits<unsigned int>::max(), amount);
}
+ void SetInput(const void *base, std::size_t amount) {
+ stream_.next_in = const_cast<char*>(static_cast<const char*>(base));
+ stream_.avail_in = amount;
+ }
+
+ const bz_stream &Stream() const { return stream_; }
+
private:
- void Open(void *already_data, std::size_t already_size) {
- assert(!file_);
- int bzerror = BZ_OK;
- file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size);
- switch (bzerror) {
+ void HandleError(int value) {
+ switch(value) {
case BZ_OK:
return;
case BZ_CONFIG_ERROR:
- UTIL_THROW(BZException, "Looks like bzip2 was miscompiled.");
+ UTIL_THROW(BZException, "bzip2 seems to be miscompiled.");
case BZ_PARAM_ERROR:
- UTIL_THROW(BZException, "Parameter error");
- case BZ_IO_ERROR:
- UTIL_THROW(BZException, "IO error reading file");
+ UTIL_THROW(BZException, "bzip2 Parameter error");
+ case BZ_DATA_ERROR:
+ UTIL_THROW(BZException, "bzip2 detected a corrupt file");
+ case BZ_DATA_ERROR_MAGIC:
+ UTIL_THROW(BZException, "bzip2 detected bad magic bytes. Perhaps this was not a bzip2 file after all?");
case BZ_MEM_ERROR:
throw std::bad_alloc();
default:
- UTIL_THROW(BZException, "Unknown bzip2 error code " << bzerror);
+ UTIL_THROW(BZException, "Unknown bzip2 error code " << value);
}
- assert(file_);
}
- static void Close(BZFILE *&file) {
- if (file == NULL) return;
- int bzerror = BZ_OK;
- BZ2_bzReadClose(&bzerror, file);
- if (bzerror != BZ_OK) {
- std::cerr << "bz2 readclose error number " << bzerror << std::endl;
- abort();
- }
- file = NULL;
- }
-
- scoped_FILE closer_;
- BZFILE *file_;
+ bz_stream stream_;
};
#endif // HAVE_BZLIB
#ifdef HAVE_XZLIB
-class XZip : public ReadBase {
- private:
- static const std::size_t kInputBuffer = 16384;
+class XZip {
public:
- XZip(int fd, void *already_data, std::size_t already_size)
- : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) {
- if (!in_buffer_.get()) throw std::bad_alloc();
- assert(already_size < kInputBuffer);
- if (already_size) {
- memcpy(in_buffer_.get(), already_data, already_size);
- stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
- stream_.avail_in = already_size;
- stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size);
- } else {
- stream_.avail_in = 0;
- }
- stream_.allocator = NULL;
- lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED);
- switch (ret) {
- case LZMA_OK:
- break;
- case LZMA_MEM_ERROR:
- UTIL_THROW(ErrnoException, "xz open error");
- default:
- UTIL_THROW(XZException, "xz error code " << ret);
- }
+ XZip(const void *base, std::size_t amount)
+ : stream_(), action_(LZMA_RUN) {
+ memset(&stream_, 0, sizeof(stream_));
+ SetInput(base, amount);
+ HandleError(lzma_stream_decoder(&stream_, UINT64_MAX, 0));
}
~XZip() {
lzma_end(&stream_);
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- if (amount == 0) return 0;
- stream_.next_out = static_cast<uint8_t*>(to);
+ void SetOutput(void *base, std::size_t amount) {
+ stream_.next_out = static_cast<uint8_t*>(base);
stream_.avail_out = amount;
- do {
- if (!stream_.avail_in) ReadInput(thunk);
- lzma_ret status = lzma_code(&stream_, action_);
- switch (status) {
- case LZMA_OK:
- break;
- case LZMA_STREAM_END:
- UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet.");
- {
- std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
- ReplaceThis(new Complete(), thunk);
- return ret;
- }
- case LZMA_MEM_ERROR:
- throw std::bad_alloc();
- case LZMA_FORMAT_ERROR:
- UTIL_THROW(XZException, "xzlib says file format not recognized");
- case LZMA_OPTIONS_ERROR:
- UTIL_THROW(XZException, "xzlib says unsupported compression options");
- case LZMA_DATA_ERROR:
- UTIL_THROW(XZException, "xzlib says this file is corrupt");
- case LZMA_BUF_ERROR:
- UTIL_THROW(XZException, "xzlib says unexpected end of input");
- default:
- UTIL_THROW(XZException, "unrecognized xzlib error " << status);
- }
- } while (stream_.next_out == to);
- return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
+ }
+
+ void SetInput(const void *base, std::size_t amount) {
+ stream_.next_in = static_cast<const uint8_t*>(base);
+ stream_.avail_in = amount;
+ if (!amount) action_ = LZMA_FINISH;
+ }
+
+ const lzma_stream &Stream() const { return stream_; }
+
+ bool Process() {
+ lzma_ret status = lzma_code(&stream_, action_);
+ if (status == LZMA_STREAM_END) return false;
+ HandleError(status);
+ return true;
}
private:
- void ReadInput(ReadCompressed &thunk) {
- assert(!stream_.avail_in);
- stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
- stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
- if (!stream_.avail_in) action_ = LZMA_FINISH;
- ReadCount(thunk) += stream_.avail_in;
+ void HandleError(lzma_ret value) {
+ switch (value) {
+ case LZMA_OK:
+ return;
+ case LZMA_MEM_ERROR:
+ throw std::bad_alloc();
+ case LZMA_FORMAT_ERROR:
+ UTIL_THROW(XZException, "xzlib says file format not recognized");
+ case LZMA_OPTIONS_ERROR:
+ UTIL_THROW(XZException, "xzlib says unsupported compression options");
+ case LZMA_DATA_ERROR:
+ UTIL_THROW(XZException, "xzlib says this file is corrupt");
+ case LZMA_BUF_ERROR:
+ UTIL_THROW(XZException, "xzlib says unexpected end of input");
+ default:
+ UTIL_THROW(XZException, "unrecognized xzlib error " << value);
+ }
}
- scoped_fd file_;
- scoped_malloc in_buffer_;
lzma_stream stream_;
-
lzma_action action_;
};
#endif // HAVE_XZLIB
@@ -384,66 +344,67 @@ class IStreamReader : public ReadBase {
};
enum MagicResult {
- UNKNOWN, GZIP, BZIP, XZIP
+ UTIL_UNKNOWN, UTIL_GZIP, UTIL_BZIP, UTIL_XZIP
};
-MagicResult DetectMagic(const void *from_void) {
+MagicResult DetectMagic(const void *from_void, std::size_t length) {
const uint8_t *header = static_cast<const uint8_t*>(from_void);
- if (header[0] == 0x1f && header[1] == 0x8b) {
- return GZIP;
+ if (length >= 2 && header[0] == 0x1f && header[1] == 0x8b) {
+ return UTIL_GZIP;
}
- if (!memcmp(header, kBZMagic, sizeof(kBZMagic))) {
- return BZIP;
+ const uint8_t kBZMagic[3] = {'B', 'Z', 'h'};
+ if (length >= sizeof(kBZMagic) && !memcmp(header, kBZMagic, sizeof(kBZMagic))) {
+ return UTIL_BZIP;
}
const uint8_t kXZMagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 };
- if (!memcmp(header, kXZMagic, sizeof(kXZMagic))) {
- return XZIP;
+ if (length >= sizeof(kXZMagic) && !memcmp(header, kXZMagic, sizeof(kXZMagic))) {
+ return UTIL_XZIP;
}
- return UNKNOWN;
+ return UTIL_UNKNOWN;
}
-ReadBase *ReadFactory(int fd, uint64_t &raw_amount) {
+ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, const std::size_t already_size, bool require_compressed) {
scoped_fd hold(fd);
- unsigned char header[ReadCompressed::kMagicSize];
- raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize);
- if (!raw_amount)
- return new Uncompressed(hold.release());
- if (raw_amount != ReadCompressed::kMagicSize)
- return new UncompressedWithHeader(hold.release(), header, raw_amount);
- switch (DetectMagic(header)) {
- case GZIP:
+ std::string header(reinterpret_cast<const char*>(already_data), already_size);
+ if (header.size() < ReadCompressed::kMagicSize) {
+ std::size_t original = header.size();
+ header.resize(ReadCompressed::kMagicSize);
+ std::size_t got = ReadOrEOF(fd, &header[original], ReadCompressed::kMagicSize - original);
+ raw_amount += got;
+ header.resize(original + got);
+ }
+ if (header.empty()) {
+ return new Complete();
+ }
+ switch (DetectMagic(&header[0], header.size())) {
+ case UTIL_GZIP:
#ifdef HAVE_ZLIB
- return new GZip(hold.release(), header, ReadCompressed::kMagicSize);
+ return new StreamCompressed<GZip>(hold.release(), header.data(), header.size());
#else
UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in.");
#endif
- case BZIP:
+ case UTIL_BZIP:
#ifdef HAVE_BZLIB
- return new BZip(hold.release(), header, ReadCompressed::kMagicSize);
+ return new StreamCompressed<BZip>(hold.release(), &header[0], header.size());
#else
- UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in.");
+ UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZh), but bzip support was not compiled in.");
#endif
- case XZIP:
+ case UTIL_XZIP:
#ifdef HAVE_XZLIB
- return new XZip(hold.release(), header, ReadCompressed::kMagicSize);
+ return new StreamCompressed<XZip>(hold.release(), header.data(), header.size());
#else
UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in.");
#endif
- case UNKNOWN:
- break;
- }
- try {
- SeekOrThrow(fd, 0);
- } catch (const util::ErrnoException &e) {
- return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize);
+ default:
+ UTIL_THROW_IF(require_compressed, CompressedException, "Uncompressed data detected after a compresssed file. This could be supported but usually indicates an error.");
+ return new UncompressedWithHeader(hold.release(), header.data(), header.size());
}
- return new Uncompressed(hold.release());
}
} // namespace
bool ReadCompressed::DetectCompressedMagic(const void *from_void) {
- return DetectMagic(from_void) != UNKNOWN;
+ return DetectMagic(from_void, kMagicSize) != UTIL_UNKNOWN;
}
ReadCompressed::ReadCompressed(int fd) {
@@ -459,8 +420,9 @@ ReadCompressed::ReadCompressed() {}
ReadCompressed::~ReadCompressed() {}
void ReadCompressed::Reset(int fd) {
+ raw_amount_ = 0;
internal_.reset();
- internal_.reset(ReadFactory(fd, raw_amount_));
+ internal_.reset(ReadFactory(fd, raw_amount_, NULL, 0, false));
}
void ReadCompressed::Reset(std::istream &in) {
@@ -472,4 +434,15 @@ std::size_t ReadCompressed::Read(void *to, std::size_t amount) {
return internal_->Read(to, amount, *this);
}
+std::size_t ReadCompressed::ReadOrEOF(void *const to_in, std::size_t amount) {
+ uint8_t *to = reinterpret_cast<uint8_t*>(to_in);
+ while (amount) {
+ std::size_t got = Read(to, amount);
+ if (!got) break;
+ to += got;
+ amount -= got;
+ }
+ return to - reinterpret_cast<uint8_t*>(to_in);
+}
+
} // namespace util