summaryrefslogtreecommitdiff
path: root/klm/util/stream/sort.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/util/stream/sort.hh')
-rw-r--r--klm/util/stream/sort.hh542
1 files changed, 542 insertions, 0 deletions
diff --git a/klm/util/stream/sort.hh b/klm/util/stream/sort.hh
new file mode 100644
index 00000000..be6c11ea
--- /dev/null
+++ b/klm/util/stream/sort.hh
@@ -0,0 +1,542 @@
+/* Usage:
+ * Sort<Compare> sorter(temp, compare);
+ * Chain(config) >> Read(file) >> sorter.Unsorted();
+ * Stream stream;
+ * Chain chain(config) >> sorter.Sorted(internal_config, lazy_config) >> stream;
+ *
+ * Note that sorter must outlive any threads that use Unsorted or Sorted.
+ *
+ * Combiners take the form:
+ * bool operator()(void *into, const void *option, const Compare &compare) const
+ * which returns true iff a combination happened. The sorting algorithm
+ * guarantees compare(into, option). But it does not guarantee
+ * compare(option, into).
+ * Currently, combining is only done in merge steps, not during on-the-fly
+ * sort. Use a hash table for that.
+ */
+
+#ifndef UTIL_STREAM_SORT__
+#define UTIL_STREAM_SORT__
+
+#include "util/stream/chain.hh"
+#include "util/stream/config.hh"
+#include "util/stream/io.hh"
+#include "util/stream/stream.hh"
+#include "util/stream/timer.hh"
+
+#include "util/file.hh"
+#include "util/scoped.hh"
+#include "util/sized_iterator.hh"
+
+#include <algorithm>
+#include <iostream>
+#include <queue>
+#include <string>
+
+namespace util {
+namespace stream {
+
+struct NeverCombine {
+ template <class Compare> bool operator()(const void *, const void *, const Compare &) const {
+ return false;
+ }
+};
+
+// Manage the offsets of sorted blocks in a file.
+class Offsets {
+ public:
+ explicit Offsets(int fd) : log_(fd) {
+ Reset();
+ }
+
+ int File() const { return log_; }
+
+ void Append(uint64_t length) {
+ if (!length) return;
+ ++block_count_;
+ if (length == cur_.length) {
+ ++cur_.run;
+ return;
+ }
+ WriteOrThrow(log_, &cur_, sizeof(Entry));
+ cur_.length = length;
+ cur_.run = 1;
+ }
+
+ void FinishedAppending() {
+ WriteOrThrow(log_, &cur_, sizeof(Entry));
+ SeekOrThrow(log_, sizeof(Entry)); // Skip 0,0 at beginning.
+ cur_.run = 0;
+ if (block_count_) {
+ ReadOrThrow(log_, &cur_, sizeof(Entry));
+ assert(cur_.length);
+ assert(cur_.run);
+ }
+ }
+
+ uint64_t RemainingBlocks() const { return block_count_; }
+
+ uint64_t TotalOffset() const { return output_sum_; }
+
+ uint64_t PeekSize() const {
+ return cur_.length;
+ }
+
+ uint64_t NextSize() {
+ assert(block_count_);
+ uint64_t ret = cur_.length;
+ output_sum_ += ret;
+
+ --cur_.run;
+ --block_count_;
+ if (!cur_.run && block_count_) {
+ ReadOrThrow(log_, &cur_, sizeof(Entry));
+ assert(cur_.length);
+ assert(cur_.run);
+ }
+ return ret;
+ }
+
+ void Reset() {
+ SeekOrThrow(log_, 0);
+ ResizeOrThrow(log_, 0);
+ cur_.length = 0;
+ cur_.run = 0;
+ block_count_ = 0;
+ output_sum_ = 0;
+ }
+
+ private:
+ int log_;
+
+ struct Entry {
+ uint64_t length;
+ uint64_t run;
+ };
+ Entry cur_;
+
+ uint64_t block_count_;
+
+ uint64_t output_sum_;
+};
+
+// A priority queue of entries backed by file buffers
+template <class Compare> class MergeQueue {
+ public:
+ MergeQueue(int fd, std::size_t buffer_size, std::size_t entry_size, const Compare &compare)
+ : queue_(Greater(compare)), in_(fd), buffer_size_(buffer_size), entry_size_(entry_size) {}
+
+ void Push(void *base, uint64_t offset, uint64_t amount) {
+ queue_.push(Entry(base, in_, offset, amount, buffer_size_));
+ }
+
+ const void *Top() const {
+ return queue_.top().Current();
+ }
+
+ void Pop() {
+ Entry top(queue_.top());
+ queue_.pop();
+ if (top.Increment(in_, buffer_size_, entry_size_))
+ queue_.push(top);
+ }
+
+ std::size_t Size() const {
+ return queue_.size();
+ }
+
+ bool Empty() const {
+ return queue_.empty();
+ }
+
+ private:
+ // Priority queue contains these entries.
+ class Entry {
+ public:
+ Entry() {}
+
+ Entry(void *base, int fd, uint64_t offset, uint64_t amount, std::size_t buf_size) {
+ offset_ = offset;
+ remaining_ = amount;
+ buffer_end_ = static_cast<uint8_t*>(base) + buf_size;
+ Read(fd, buf_size);
+ }
+
+ bool Increment(int fd, std::size_t buf_size, std::size_t entry_size) {
+ current_ += entry_size;
+ if (current_ != buffer_end_) return true;
+ return Read(fd, buf_size);
+ }
+
+ const void *Current() const { return current_; }
+
+ private:
+ bool Read(int fd, std::size_t buf_size) {
+ current_ = buffer_end_ - buf_size;
+ std::size_t amount;
+ if (static_cast<uint64_t>(buf_size) < remaining_) {
+ amount = buf_size;
+ } else if (!remaining_) {
+ return false;
+ } else {
+ amount = remaining_;
+ buffer_end_ = current_ + remaining_;
+ }
+ PReadOrThrow(fd, current_, amount, offset_);
+ offset_ += amount;
+ assert(current_ <= buffer_end_);
+ remaining_ -= amount;
+ return true;
+ }
+
+ // Buffer
+ uint8_t *current_, *buffer_end_;
+ // File
+ uint64_t remaining_, offset_;
+ };
+
+ // Wrapper comparison function for queue entries.
+ class Greater : public std::binary_function<const Entry &, const Entry &, bool> {
+ public:
+ explicit Greater(const Compare &compare) : compare_(compare) {}
+
+ bool operator()(const Entry &first, const Entry &second) const {
+ return compare_(second.Current(), first.Current());
+ }
+
+ private:
+ const Compare compare_;
+ };
+
+ typedef std::priority_queue<Entry, std::vector<Entry>, Greater> Queue;
+ Queue queue_;
+
+ const int in_;
+ const std::size_t buffer_size_;
+ const std::size_t entry_size_;
+};
+
+/* A worker object that merges. If the number of pieces to merge exceeds the
+ * arity, it outputs multiple sorted blocks, recording to out_offsets.
+ * However, users will only every see a single sorted block out output because
+ * Sort::Sorted insures the arity is higher than the number of pieces before
+ * returning this.
+ */
+template <class Compare, class Combine> class MergingReader {
+ public:
+ MergingReader(int in, Offsets *in_offsets, Offsets *out_offsets, std::size_t buffer_size, std::size_t total_memory, const Compare &compare, const Combine &combine) :
+ compare_(compare), combine_(combine),
+ in_(in),
+ in_offsets_(in_offsets), out_offsets_(out_offsets),
+ buffer_size_(buffer_size), total_memory_(total_memory) {}
+
+ void Run(const ChainPosition &position) {
+ Run(position, false);
+ }
+
+ void Run(const ChainPosition &position, bool assert_one) {
+ // Special case: nothing to read.
+ if (!in_offsets_->RemainingBlocks()) {
+ Link l(position);
+ l.Poison();
+ return;
+ }
+ // If there's just one entry, just read.
+ if (in_offsets_->RemainingBlocks() == 1) {
+ // Sequencing is important.
+ uint64_t offset = in_offsets_->TotalOffset();
+ uint64_t amount = in_offsets_->NextSize();
+ ReadSingle(offset, amount, position);
+ if (out_offsets_) out_offsets_->Append(amount);
+ return;
+ }
+
+ Stream str(position);
+ scoped_malloc buffer(MallocOrThrow(total_memory_));
+ uint8_t *const buffer_end = static_cast<uint8_t*>(buffer.get()) + total_memory_;
+
+ const std::size_t entry_size = position.GetChain().EntrySize();
+
+ while (in_offsets_->RemainingBlocks()) {
+ // Use bigger buffers if there's less remaining.
+ uint64_t per_buffer = std::max(buffer_size_, total_memory_ / in_offsets_->RemainingBlocks());
+ per_buffer -= per_buffer % entry_size;
+ assert(per_buffer);
+
+ // Populate queue.
+ MergeQueue<Compare> queue(in_, per_buffer, entry_size, compare_);
+ for (uint8_t *buf = static_cast<uint8_t*>(buffer.get());
+ in_offsets_->RemainingBlocks() && (buf + std::min(per_buffer, in_offsets_->PeekSize()) <= buffer_end);) {
+ uint64_t offset = in_offsets_->TotalOffset();
+ uint64_t size = in_offsets_->NextSize();
+ queue.Push(buf, offset, size);
+ buf += static_cast<std::size_t>(std::min<uint64_t>(size, per_buffer));
+ }
+ // This shouldn't happen but it's probably better to die than loop indefinitely.
+ if (queue.Size() < 2 && in_offsets_->RemainingBlocks()) {
+ std::cerr << "Bug in sort implementation: not merging at least two stripes." << std::endl;
+ abort();
+ }
+ if (assert_one && in_offsets_->RemainingBlocks()) {
+ std::cerr << "Bug in sort implementation: should only be one merge group for lazy sort" << std::endl;
+ abort();
+ }
+
+ uint64_t written = 0;
+ // Merge including combiner support.
+ memcpy(str.Get(), queue.Top(), entry_size);
+ for (queue.Pop(); !queue.Empty(); queue.Pop()) {
+ if (!combine_(str.Get(), queue.Top(), compare_)) {
+ ++written; ++str;
+ memcpy(str.Get(), queue.Top(), entry_size);
+ }
+ }
+ ++written; ++str;
+ if (out_offsets_)
+ out_offsets_->Append(written * entry_size);
+ }
+ str.Poison();
+ }
+
+ private:
+ void ReadSingle(uint64_t offset, const uint64_t size, const ChainPosition &position) {
+ // Special case: only one to read.
+ const uint64_t end = offset + size;
+ const uint64_t block_size = position.GetChain().BlockSize();
+ Link l(position);
+ for (; offset + block_size < end; ++l, offset += block_size) {
+ PReadOrThrow(in_, l->Get(), block_size, offset);
+ l->SetValidSize(block_size);
+ }
+ PReadOrThrow(in_, l->Get(), end - offset, offset);
+ l->SetValidSize(end - offset);
+ (++l).Poison();
+ return;
+ }
+
+ Compare compare_;
+ Combine combine_;
+
+ int in_;
+
+ protected:
+ Offsets *in_offsets_;
+
+ private:
+ Offsets *out_offsets_;
+
+ std::size_t buffer_size_;
+ std::size_t total_memory_;
+};
+
+// The lazy step owns the remaining files. This keeps track of them.
+template <class Compare, class Combine> class OwningMergingReader : public MergingReader<Compare, Combine> {
+ private:
+ typedef MergingReader<Compare, Combine> P;
+ public:
+ OwningMergingReader(int data, const Offsets &offsets, std::size_t buffer, std::size_t lazy, const Compare &compare, const Combine &combine)
+ : P(data, NULL, NULL, buffer, lazy, compare, combine),
+ data_(data),
+ offsets_(offsets) {}
+
+ void Run(const ChainPosition &position) {
+ P::in_offsets_ = &offsets_;
+ scoped_fd data(data_);
+ scoped_fd offsets_file(offsets_.File());
+ P::Run(position, true);
+ }
+
+ private:
+ int data_;
+ Offsets offsets_;
+};
+
+// Don't use this directly. Worker that sorts blocks.
+template <class Compare> class BlockSorter {
+ public:
+ BlockSorter(Offsets &offsets, const Compare &compare) :
+ offsets_(&offsets), compare_(compare) {}
+
+ void Run(const ChainPosition &position) {
+ const std::size_t entry_size = position.GetChain().EntrySize();
+ for (Link link(position); link; ++link) {
+ // Record the size of each block in a separate file.
+ offsets_->Append(link->ValidSize());
+ void *end = static_cast<uint8_t*>(link->Get()) + link->ValidSize();
+ std::sort(
+ SizedIt(link->Get(), entry_size),
+ SizedIt(end, entry_size),
+ compare_);
+ }
+ offsets_->FinishedAppending();
+ }
+
+ private:
+ Offsets *offsets_;
+ SizedCompare<Compare> compare_;
+};
+
+class BadSortConfig : public Exception {
+ public:
+ BadSortConfig() throw() {}
+ ~BadSortConfig() throw() {}
+};
+
+template <class Compare, class Combine = NeverCombine> class Sort {
+ public:
+ Sort(Chain &in, const SortConfig &config, const Compare &compare = Compare(), const Combine &combine = Combine())
+ : config_(config),
+ data_(MakeTemp(config.temp_prefix)),
+ offsets_file_(MakeTemp(config.temp_prefix)), offsets_(offsets_file_.get()),
+ compare_(compare), combine_(combine),
+ entry_size_(in.EntrySize()) {
+ UTIL_THROW_IF(!entry_size_, BadSortConfig, "Sorting entries of size 0");
+ // Make buffer_size a multiple of the entry_size.
+ config_.buffer_size -= config_.buffer_size % entry_size_;
+ UTIL_THROW_IF(!config_.buffer_size, BadSortConfig, "Sort buffer too small");
+ UTIL_THROW_IF(config_.total_memory < config_.buffer_size * 4, BadSortConfig, "Sorting memory " << config_.total_memory << " is too small for four buffers (two read and two write).");
+ in >> BlockSorter<Compare>(offsets_, compare_) >> WriteAndRecycle(data_.get());
+ }
+
+ uint64_t Size() const {
+ return SizeOrThrow(data_.get());
+ }
+
+ // Do merge sort, terminating when lazy merge could be done with the
+ // specified memory. Return the minimum memory necessary to do lazy merge.
+ std::size_t Merge(std::size_t lazy_memory) {
+ if (offsets_.RemainingBlocks() <= 1) return 0;
+ const uint64_t lazy_arity = std::max<uint64_t>(1, lazy_memory / config_.buffer_size);
+ uint64_t size = Size();
+ /* No overflow because
+ * offsets_.RemainingBlocks() * config_.buffer_size <= lazy_memory ||
+ * size < lazy_memory
+ */
+ if (offsets_.RemainingBlocks() <= lazy_arity || size <= static_cast<uint64_t>(lazy_memory))
+ return std::min<std::size_t>(size, offsets_.RemainingBlocks() * config_.buffer_size);
+
+ scoped_fd data2(MakeTemp(config_.temp_prefix));
+ int fd_in = data_.get(), fd_out = data2.get();
+ scoped_fd offsets2_file(MakeTemp(config_.temp_prefix));
+ Offsets offsets2(offsets2_file.get());
+ Offsets *offsets_in = &offsets_, *offsets_out = &offsets2;
+
+ // Double buffered writing.
+ ChainConfig chain_config;
+ chain_config.entry_size = entry_size_;
+ chain_config.block_count = 2;
+ chain_config.total_memory = config_.buffer_size * 2;
+ Chain chain(chain_config);
+
+ while (offsets_in->RemainingBlocks() > lazy_arity) {
+ if (size <= static_cast<uint64_t>(lazy_memory)) break;
+ std::size_t reading_memory = config_.total_memory - 2 * config_.buffer_size;
+ if (size < static_cast<uint64_t>(reading_memory)) {
+ reading_memory = static_cast<std::size_t>(size);
+ }
+ SeekOrThrow(fd_in, 0);
+ chain >>
+ MergingReader<Compare, Combine>(
+ fd_in,
+ offsets_in, offsets_out,
+ config_.buffer_size,
+ reading_memory,
+ compare_, combine_) >>
+ WriteAndRecycle(fd_out);
+ chain.Wait();
+ offsets_out->FinishedAppending();
+ ResizeOrThrow(fd_in, 0);
+ offsets_in->Reset();
+ std::swap(fd_in, fd_out);
+ std::swap(offsets_in, offsets_out);
+ size = SizeOrThrow(fd_in);
+ }
+
+ SeekOrThrow(fd_in, 0);
+ if (fd_in == data2.get()) {
+ data_.reset(data2.release());
+ offsets_file_.reset(offsets2_file.release());
+ offsets_ = offsets2;
+ }
+ if (offsets_.RemainingBlocks() <= 1) return 0;
+ // No overflow because the while loop exited.
+ return std::min(size, offsets_.RemainingBlocks() * static_cast<uint64_t>(config_.buffer_size));
+ }
+
+ // Output to chain, using this amount of memory, maximum, for lazy merge
+ // sort.
+ void Output(Chain &out, std::size_t lazy_memory) {
+ Merge(lazy_memory);
+ out.SetProgressTarget(Size());
+ out >> OwningMergingReader<Compare, Combine>(data_.get(), offsets_, config_.buffer_size, lazy_memory, compare_, combine_);
+ data_.release();
+ offsets_file_.release();
+ }
+
+ /* If a pipeline step is reading sorted input and writing to a different
+ * sort order, then there's a trade-off between using RAM to read lazily
+ * (avoiding copying the file) and using RAM to increase block size and,
+ * therefore, decrease the number of merge sort passes in the next
+ * iteration.
+ *
+ * Merge sort takes log_{arity}(pieces) passes. Thus, each time the chain
+ * block size is multiplied by arity, the number of output passes decreases
+ * by one. Up to a constant, then, log_{arity}(chain) is the number of
+ * passes saved. Chain simply divides the memory evenly over all blocks.
+ *
+ * Lazy sort saves this many passes (up to a constant)
+ * log_{arity}((memory-lazy)/block_count) + 1
+ * Non-lazy sort saves this many passes (up to the same constant):
+ * log_{arity}(memory/block_count)
+ * Add log_{arity}(block_count) to both:
+ * log_{arity}(memory-lazy) + 1 versus log_{arity}(memory)
+ * Take arity to the power of both sizes (arity > 1)
+ * (memory - lazy)*arity versus memory
+ * Solve for lazy
+ * lazy = memory * (arity - 1) / arity
+ */
+ std::size_t DefaultLazy() {
+ float arity = static_cast<float>(config_.total_memory / config_.buffer_size);
+ return static_cast<std::size_t>(static_cast<float>(config_.total_memory) * (arity - 1.0) / arity);
+ }
+
+ // Same as Output with default lazy memory setting.
+ void Output(Chain &out) {
+ Output(out, DefaultLazy());
+ }
+
+ // Completely merge sort and transfer ownership to the caller.
+ int StealCompleted() {
+ // Merge all the way.
+ Merge(0);
+ SeekOrThrow(data_.get(), 0);
+ offsets_file_.reset();
+ return data_.release();
+ }
+
+ private:
+ SortConfig config_;
+
+ scoped_fd data_;
+
+ scoped_fd offsets_file_;
+ Offsets offsets_;
+
+ const Compare compare_;
+ const Combine combine_;
+ const std::size_t entry_size_;
+};
+
+// returns bytes to be read on demand.
+template <class Compare, class Combine> uint64_t BlockingSort(Chain &chain, const SortConfig &config, const Compare &compare = Compare(), const Combine &combine = NeverCombine()) {
+ Sort<Compare, Combine> sorter(chain, config, compare, combine);
+ chain.Wait(true);
+ uint64_t size = sorter.Size();
+ sorter.Output(chain);
+ return size;
+}
+
+} // namespace stream
+} // namespace util
+
+#endif // UTIL_STREAM_SORT__