summaryrefslogtreecommitdiff
path: root/klm/util/stream
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2013-01-21 12:29:43 +0100
committerPatrick Simianer <p@simianer.de>2013-01-21 12:29:43 +0100
commit50f22047eb1b7f2d60e85cdcf0fcd86342e50523 (patch)
tree730dabaf2fa57b1e4536d40f036b46795d37f289 /klm/util/stream
parent8b399cb09513cd79ed4182be9f75882c1e1b336a (diff)
parent608886384da40aedfabd629c882b8ea9b3f6348e (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/util/stream')
-rw-r--r--klm/util/stream/Makefile.am20
-rw-r--r--klm/util/stream/block.hh43
-rw-r--r--klm/util/stream/chain.cc155
-rw-r--r--klm/util/stream/chain.hh198
-rw-r--r--klm/util/stream/config.hh32
-rw-r--r--klm/util/stream/io.cc64
-rw-r--r--klm/util/stream/io.hh76
-rw-r--r--klm/util/stream/io_test.cc38
-rw-r--r--klm/util/stream/line_input.cc52
-rw-r--r--klm/util/stream/line_input.hh22
-rw-r--r--klm/util/stream/multi_progress.cc86
-rw-r--r--klm/util/stream/multi_progress.hh90
-rw-r--r--klm/util/stream/sort.hh544
-rw-r--r--klm/util/stream/sort_test.cc62
-rw-r--r--klm/util/stream/stream.hh74
-rw-r--r--klm/util/stream/stream_test.cc35
-rw-r--r--klm/util/stream/timer.hh16
17 files changed, 1607 insertions, 0 deletions
diff --git a/klm/util/stream/Makefile.am b/klm/util/stream/Makefile.am
new file mode 100644
index 00000000..f18cbedb
--- /dev/null
+++ b/klm/util/stream/Makefile.am
@@ -0,0 +1,20 @@
+noinst_LIBRARIES = libklm_util_stream.a
+
+libklm_util_stream_a_SOURCES = \
+ block.hh \
+ chain.cc \
+ chain.hh \
+ config.hh \
+ io.cc \
+ io.hh \
+ line_input.cc \
+ line_input.hh \
+ multi_progress.cc \
+ multi_progress.hh \
+ sort.hh \
+ stream.hh \
+ timer.hh
+
+AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm
+
+#-I$(top_srcdir)/klm/util/double-conversion
diff --git a/klm/util/stream/block.hh b/klm/util/stream/block.hh
new file mode 100644
index 00000000..11aa991e
--- /dev/null
+++ b/klm/util/stream/block.hh
@@ -0,0 +1,43 @@
+#ifndef UTIL_STREAM_BLOCK__
+#define UTIL_STREAM_BLOCK__
+
+#include <cstddef>
+#include <stdint.h>
+
+namespace util {
+namespace stream {
+
+class Block {
+ public:
+ Block() : mem_(NULL), valid_size_(0) {}
+
+ Block(void *mem, std::size_t size) : mem_(mem), valid_size_(size) {}
+
+ void SetValidSize(std::size_t to) { valid_size_ = to; }
+ // Read might fill in less than Allocated at EOF.
+ std::size_t ValidSize() const { return valid_size_; }
+
+ void *Get() { return mem_; }
+ const void *Get() const { return mem_; }
+
+ const void *ValidEnd() const {
+ return reinterpret_cast<const uint8_t*>(mem_) + valid_size_;
+ }
+
+ operator bool() const { return mem_ != NULL; }
+ bool operator!() const { return mem_ == NULL; }
+
+ private:
+ friend class Link;
+ void SetToPoison() {
+ mem_ = NULL;
+ }
+
+ void *mem_;
+ std::size_t valid_size_;
+};
+
+} // namespace stream
+} // namespace util
+
+#endif // UTIL_STREAM_BLOCK__
diff --git a/klm/util/stream/chain.cc b/klm/util/stream/chain.cc
new file mode 100644
index 00000000..46708c60
--- /dev/null
+++ b/klm/util/stream/chain.cc
@@ -0,0 +1,155 @@
+#include "util/stream/chain.hh"
+
+#include "util/stream/io.hh"
+
+#include "util/exception.hh"
+#include "util/pcqueue.hh"
+
+#include <cstdlib>
+#include <new>
+#include <iostream>
+
+#include <stdint.h>
+#include <stdlib.h>
+
+namespace util {
+namespace stream {
+
+ChainConfigException::ChainConfigException() throw() { *this << "Chain configured with "; }
+ChainConfigException::~ChainConfigException() throw() {}
+
+Thread::~Thread() {
+ thread_.join();
+}
+
+void Thread::UnhandledException(const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ abort();
+}
+
+void Recycler::Run(const ChainPosition &position) {
+ for (Link l(position); l; ++l) {
+ l->SetValidSize(position.GetChain().BlockSize());
+ }
+}
+
+const Recycler kRecycle = Recycler();
+
+Chain::Chain(const ChainConfig &config) : config_(config), complete_called_(false) {
+ UTIL_THROW_IF(!config.entry_size, ChainConfigException, "zero-size entries.");
+ UTIL_THROW_IF(!config.block_count, ChainConfigException, "block count zero");
+ UTIL_THROW_IF(config.total_memory < config.entry_size * config.block_count, ChainConfigException, config.total_memory << " total memory, too small for " << config.block_count << " blocks of containing entries of size " << config.entry_size);
+ // Round down block size to a multiple of entry size.
+ block_size_ = config.total_memory / (config.block_count * config.entry_size) * config.entry_size;
+}
+
+Chain::~Chain() {
+ Wait();
+}
+
+ChainPosition Chain::Add() {
+ if (!Running()) Start();
+ PCQueue<Block> &in = queues_.back();
+ queues_.push_back(new PCQueue<Block>(config_.block_count));
+ return ChainPosition(in, queues_.back(), this, progress_);
+}
+
+Chain &Chain::operator>>(const WriteAndRecycle &writer) {
+ threads_.push_back(new Thread(Complete(), writer));
+ return *this;
+}
+
+void Chain::Wait(bool release_memory) {
+ if (queues_.empty()) {
+ assert(threads_.empty());
+ return; // Nothing to wait for.
+ }
+ if (!complete_called_) CompleteLoop();
+ threads_.clear();
+ for (std::size_t i = 0; queues_.front().Consume(); ++i) {
+ if (i == config_.block_count) {
+ std::cerr << "Chain ending without poison." << std::endl;
+ abort();
+ }
+ }
+ queues_.clear();
+ progress_.Finished();
+ complete_called_ = false;
+ if (release_memory) memory_.reset();
+}
+
+void Chain::Start() {
+ Wait(false);
+ if (!memory_.get()) {
+ // Allocate memory.
+ assert(threads_.empty());
+ assert(queues_.empty());
+ std::size_t malloc_size = block_size_ * config_.block_count;
+ memory_.reset(MallocOrThrow(malloc_size));
+ }
+ // This queue can accomodate all blocks.
+ queues_.push_back(new PCQueue<Block>(config_.block_count));
+ // Populate the lead queue with blocks.
+ uint8_t *base = static_cast<uint8_t*>(memory_.get());
+ for (std::size_t i = 0; i < config_.block_count; ++i) {
+ queues_.front().Produce(Block(base, block_size_));
+ base += block_size_;
+ }
+}
+
+ChainPosition Chain::Complete() {
+ assert(Running());
+ UTIL_THROW_IF(complete_called_, util::Exception, "CompleteLoop() called twice");
+ complete_called_ = true;
+ return ChainPosition(queues_.back(), queues_.front(), this, progress_);
+}
+
+Link::Link() : in_(NULL), out_(NULL), poisoned_(true) {}
+
+void Link::Init(const ChainPosition &position) {
+ UTIL_THROW_IF(in_, util::Exception, "Link::Init twice");
+ in_ = position.in_;
+ out_ = position.out_;
+ poisoned_ = false;
+ progress_ = position.progress_;
+ in_->Consume(current_);
+}
+
+Link::Link(const ChainPosition &position) : in_(NULL) {
+ Init(position);
+}
+
+Link::~Link() {
+ if (current_) {
+ // Probably an exception unwinding.
+ std::cerr << "Last input should have been poison." << std::endl;
+// abort();
+ } else {
+ if (!poisoned_) {
+ // Pass the poison!
+ out_->Produce(current_);
+ }
+ }
+}
+
+Link &Link::operator++() {
+ assert(current_);
+ progress_ += current_.ValidSize();
+ out_->Produce(current_);
+ in_->Consume(current_);
+ if (!current_) {
+ poisoned_ = true;
+ out_->Produce(current_);
+ }
+ return *this;
+}
+
+void Link::Poison() {
+ assert(!poisoned_);
+ current_.SetToPoison();
+ out_->Produce(current_);
+ poisoned_ = true;
+}
+
+} // namespace stream
+} // namespace util
diff --git a/klm/util/stream/chain.hh b/klm/util/stream/chain.hh
new file mode 100644
index 00000000..154b9b33
--- /dev/null
+++ b/klm/util/stream/chain.hh
@@ -0,0 +1,198 @@
+#ifndef UTIL_STREAM_CHAIN__
+#define UTIL_STREAM_CHAIN__
+
+#include "util/stream/block.hh"
+#include "util/stream/config.hh"
+#include "util/stream/multi_progress.hh"
+#include "util/scoped.hh"
+
+#include <boost/ptr_container/ptr_vector.hpp>
+#include <boost/thread/thread.hpp>
+
+#include <cstddef>
+
+#include <assert.h>
+
+namespace util {
+template <class T> class PCQueue;
+namespace stream {
+
+class ChainConfigException : public Exception {
+ public:
+ ChainConfigException() throw();
+ ~ChainConfigException() throw();
+};
+
+class Chain;
+// Specifies position in chain for Link constructor.
+class ChainPosition {
+ public:
+ const Chain &GetChain() const { return *chain_; }
+ private:
+ friend class Chain;
+ friend class Link;
+ ChainPosition(PCQueue<Block> &in, PCQueue<Block> &out, Chain *chain, MultiProgress &progress)
+ : in_(&in), out_(&out), chain_(chain), progress_(progress.Add()) {}
+
+ PCQueue<Block> *in_, *out_;
+
+ Chain *chain_;
+
+ WorkerProgress progress_;
+};
+
+// Position is usually ChainPosition but if there are multiple streams involved, this can be ChainPositions.
+class Thread {
+ public:
+ template <class Position, class Worker> Thread(const Position &position, const Worker &worker)
+ : thread_(boost::ref(*this), position, worker) {}
+
+ ~Thread();
+
+ template <class Position, class Worker> void operator()(const Position &position, Worker &worker) {
+ try {
+ worker.Run(position);
+ } catch (const std::exception &e) {
+ UnhandledException(e);
+ }
+ }
+
+ private:
+ void UnhandledException(const std::exception &e);
+
+ boost::thread thread_;
+};
+
+class Recycler {
+ public:
+ void Run(const ChainPosition &position);
+};
+
+extern const Recycler kRecycle;
+class WriteAndRecycle;
+
+class Chain {
+ private:
+ template <class T, void (T::*ptr)(const ChainPosition &) = &T::Run> struct CheckForRun {
+ typedef Chain type;
+ };
+
+ public:
+ explicit Chain(const ChainConfig &config);
+
+ ~Chain();
+
+ void ActivateProgress() {
+ assert(!Running());
+ progress_.Activate();
+ }
+
+ void SetProgressTarget(uint64_t target) {
+ progress_.SetTarget(target);
+ }
+
+ std::size_t EntrySize() const {
+ return config_.entry_size;
+ }
+ std::size_t BlockSize() const {
+ return block_size_;
+ }
+
+ // Two ways to add to the chain: Add() or operator>>.
+ ChainPosition Add();
+
+ // This is for adding threaded workers with a Run method.
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
+ assert(!complete_called_);
+ threads_.push_back(new Thread(Add(), worker));
+ return *this;
+ }
+
+ // Avoid copying the worker.
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) {
+ assert(!complete_called_);
+ threads_.push_back(new Thread(Add(), worker));
+ return *this;
+ }
+
+ // Note that Link and Stream also define operator>> outside this class.
+
+ // To complete the loop, call CompleteLoop(), >> kRecycle, or the destructor.
+ void CompleteLoop() {
+ threads_.push_back(new Thread(Complete(), kRecycle));
+ }
+
+ Chain &operator>>(const Recycler &recycle) {
+ CompleteLoop();
+ return *this;
+ }
+
+ Chain &operator>>(const WriteAndRecycle &writer);
+
+ // Chains are reusable. Call Wait to wait for everything to finish and free memory.
+ void Wait(bool release_memory = true);
+
+ // Waits for the current chain to complete (if any) then starts again.
+ void Start();
+
+ bool Running() const { return !queues_.empty(); }
+
+ private:
+ ChainPosition Complete();
+
+ ChainConfig config_;
+
+ std::size_t block_size_;
+
+ scoped_malloc memory_;
+
+ boost::ptr_vector<PCQueue<Block> > queues_;
+
+ bool complete_called_;
+
+ boost::ptr_vector<Thread> threads_;
+
+ MultiProgress progress_;
+};
+
+// Create the link in the worker thread using the position token.
+class Link {
+ public:
+ // Either default construct and Init or just construct all at once.
+ Link();
+ void Init(const ChainPosition &position);
+
+ explicit Link(const ChainPosition &position);
+
+ ~Link();
+
+ Block &operator*() { return current_; }
+ const Block &operator*() const { return current_; }
+
+ Block *operator->() { return &current_; }
+ const Block *operator->() const { return &current_; }
+
+ Link &operator++();
+
+ operator bool() const { return current_; }
+
+ void Poison();
+
+ private:
+ Block current_;
+ PCQueue<Block> *in_, *out_;
+
+ bool poisoned_;
+
+ WorkerProgress progress_;
+};
+
+inline Chain &operator>>(Chain &chain, Link &link) {
+ link.Init(chain.Add());
+ return chain;
+}
+
+} // namespace stream
+} // namespace util
+
+#endif // UTIL_STREAM_CHAIN__
diff --git a/klm/util/stream/config.hh b/klm/util/stream/config.hh
new file mode 100644
index 00000000..1eeb3a8a
--- /dev/null
+++ b/klm/util/stream/config.hh
@@ -0,0 +1,32 @@
+#ifndef UTIL_STREAM_CONFIG__
+#define UTIL_STREAM_CONFIG__
+
+#include <cstddef>
+#include <string>
+
+namespace util { namespace stream {
+
+struct ChainConfig {
+ ChainConfig() {}
+
+ ChainConfig(std::size_t in_entry_size, std::size_t in_block_count, std::size_t in_total_memory)
+ : entry_size(in_entry_size), block_count(in_block_count), total_memory(in_total_memory) {}
+
+ std::size_t entry_size;
+ std::size_t block_count;
+ // Chain's constructor will make this a multiple of entry_size.
+ std::size_t total_memory;
+};
+
+struct SortConfig {
+ std::string temp_prefix;
+
+ // Size of each input/output buffer.
+ std::size_t buffer_size;
+
+ // Total memory to use when running alone.
+ std::size_t total_memory;
+};
+
+}} // namespaces
+#endif // UTIL_STREAM_CONFIG__
diff --git a/klm/util/stream/io.cc b/klm/util/stream/io.cc
new file mode 100644
index 00000000..c7ad2980
--- /dev/null
+++ b/klm/util/stream/io.cc
@@ -0,0 +1,64 @@
+#include "util/stream/io.hh"
+
+#include "util/file.hh"
+#include "util/stream/chain.hh"
+
+#include <cstddef>
+
+namespace util {
+namespace stream {
+
+ReadSizeException::ReadSizeException() throw() {}
+ReadSizeException::~ReadSizeException() throw() {}
+
+void Read::Run(const ChainPosition &position) {
+ const std::size_t block_size = position.GetChain().BlockSize();
+ const std::size_t entry_size = position.GetChain().EntrySize();
+ for (Link link(position); link; ++link) {
+ std::size_t got = util::ReadOrEOF(file_, link->Get(), block_size);
+ UTIL_THROW_IF(got % entry_size, ReadSizeException, "File ended with " << got << " bytes, not a multiple of " << entry_size << ".");
+ if (got == 0) {
+ link.Poison();
+ return;
+ } else {
+ link->SetValidSize(got);
+ }
+ }
+}
+
+void PRead::Run(const ChainPosition &position) {
+ scoped_fd owner;
+ if (own_) owner.reset(file_);
+ uint64_t size = SizeOrThrow(file_);
+ UTIL_THROW_IF(size % static_cast<uint64_t>(position.GetChain().EntrySize()), ReadSizeException, "File size " << file_ << " size is " << size << " not a multiple of " << position.GetChain().EntrySize());
+ std::size_t block_size = position.GetChain().BlockSize();
+ Link link(position);
+ uint64_t offset = 0;
+ for (; offset + block_size < size; offset += block_size, ++link) {
+ PReadOrThrow(file_, link->Get(), block_size, offset);
+ link->SetValidSize(block_size);
+ }
+ if (size - offset) {
+ PReadOrThrow(file_, link->Get(), size - offset, offset);
+ link->SetValidSize(size - offset);
+ ++link;
+ }
+ link.Poison();
+}
+
+void Write::Run(const ChainPosition &position) {
+ for (Link link(position); link; ++link) {
+ WriteOrThrow(file_, link->Get(), link->ValidSize());
+ }
+}
+
+void WriteAndRecycle::Run(const ChainPosition &position) {
+ const std::size_t block_size = position.GetChain().BlockSize();
+ for (Link link(position); link; ++link) {
+ WriteOrThrow(file_, link->Get(), link->ValidSize());
+ link->SetValidSize(block_size);
+ }
+}
+
+} // namespace stream
+} // namespace util
diff --git a/klm/util/stream/io.hh b/klm/util/stream/io.hh
new file mode 100644
index 00000000..934b6b3f
--- /dev/null
+++ b/klm/util/stream/io.hh
@@ -0,0 +1,76 @@
+#ifndef UTIL_STREAM_IO__
+#define UTIL_STREAM_IO__
+
+#include "util/exception.hh"
+#include "util/file.hh"
+
+namespace util {
+namespace stream {
+
+class ChainPosition;
+
+class ReadSizeException : public util::Exception {
+ public:
+ ReadSizeException() throw();
+ ~ReadSizeException() throw();
+};
+
+class Read {
+ public:
+ explicit Read(int fd) : file_(fd) {}
+ void Run(const ChainPosition &position);
+ private:
+ int file_;
+};
+
+// Like read but uses pread so that the file can be accessed from multiple threads.
+class PRead {
+ public:
+ explicit PRead(int fd, bool take_own = false) : file_(fd), own_(take_own) {}
+ void Run(const ChainPosition &position);
+ private:
+ int file_;
+ bool own_;
+};
+
+class Write {
+ public:
+ explicit Write(int fd) : file_(fd) {}
+ void Run(const ChainPosition &position);
+ private:
+ int file_;
+};
+
+class WriteAndRecycle {
+ public:
+ explicit WriteAndRecycle(int fd) : file_(fd) {}
+ void Run(const ChainPosition &position);
+ private:
+ int file_;
+};
+
+// Reuse the same file over and over again to buffer output.
+class FileBuffer {
+ public:
+ explicit FileBuffer(int fd) : file_(fd) {}
+
+ WriteAndRecycle Sink() const {
+ util::SeekOrThrow(file_.get(), 0);
+ return WriteAndRecycle(file_.get());
+ }
+
+ PRead Source() const {
+ return PRead(file_.get());
+ }
+
+ uint64_t Size() const {
+ return SizeOrThrow(file_.get());
+ }
+
+ private:
+ scoped_fd file_;
+};
+
+} // namespace stream
+} // namespace util
+#endif // UTIL_STREAM_IO__
diff --git a/klm/util/stream/io_test.cc b/klm/util/stream/io_test.cc
new file mode 100644
index 00000000..82108335
--- /dev/null
+++ b/klm/util/stream/io_test.cc
@@ -0,0 +1,38 @@
+#include "util/stream/io.hh"
+
+#include "util/stream/chain.hh"
+#include "util/file.hh"
+
+#define BOOST_TEST_MODULE IOTest
+#include <boost/test/unit_test.hpp>
+
+#include <unistd.h>
+
+namespace util { namespace stream { namespace {
+
+BOOST_AUTO_TEST_CASE(CopyFile) {
+ std::string temps("io_test_temp");
+
+ scoped_fd in(MakeTemp(temps));
+ for (uint64_t i = 0; i < 100000; ++i) {
+ WriteOrThrow(in.get(), &i, sizeof(uint64_t));
+ }
+ SeekOrThrow(in.get(), 0);
+ scoped_fd out(MakeTemp(temps));
+
+ ChainConfig config;
+ config.entry_size = 8;
+ config.total_memory = 1024;
+ config.block_count = 10;
+
+ Chain(config) >> PRead(in.get()) >> Write(out.get());
+
+ SeekOrThrow(out.get(), 0);
+ for (uint64_t i = 0; i < 100000; ++i) {
+ uint64_t got;
+ ReadOrThrow(out.get(), &got, sizeof(uint64_t));
+ BOOST_CHECK_EQUAL(i, got);
+ }
+}
+
+}}} // namespaces
diff --git a/klm/util/stream/line_input.cc b/klm/util/stream/line_input.cc
new file mode 100644
index 00000000..dafa5020
--- /dev/null
+++ b/klm/util/stream/line_input.cc
@@ -0,0 +1,52 @@
+#include "util/stream/line_input.hh"
+
+#include "util/exception.hh"
+#include "util/file.hh"
+#include "util/read_compressed.hh"
+#include "util/stream/chain.hh"
+
+#include <algorithm>
+#include <vector>
+
+namespace util { namespace stream {
+
+void LineInput::Run(const ChainPosition &position) {
+ ReadCompressed reader(fd_);
+ // Holding area for beginning of line to be placed in next block.
+ std::vector<char> carry;
+
+ for (Link block(position); ; ++block) {
+ char *to = static_cast<char*>(block->Get());
+ char *begin = to;
+ char *end = to + position.GetChain().BlockSize();
+ std::copy(carry.begin(), carry.end(), to);
+ to += carry.size();
+ while (to != end) {
+ std::size_t got = reader.Read(to, end - to);
+ if (!got) {
+ // EOF
+ block->SetValidSize(to - begin);
+ ++block;
+ block.Poison();
+ return;
+ }
+ to += got;
+ }
+
+ // Find the last newline.
+ char *newline;
+ for (newline = to - 1; ; --newline) {
+ UTIL_THROW_IF(newline < begin, Exception, "Did not find a newline in " << position.GetChain().BlockSize() << " bytes of input of " << NameFromFD(fd_) << ". Is this a text file?");
+ if (*newline == '\n') break;
+ }
+
+ // Copy everything after the last newline to the carry.
+ carry.clear();
+ carry.resize(to - (newline + 1));
+ std::copy(newline + 1, to, &*carry.begin());
+
+ block->SetValidSize(newline + 1 - begin);
+ }
+}
+
+}} // namespaces
diff --git a/klm/util/stream/line_input.hh b/klm/util/stream/line_input.hh
new file mode 100644
index 00000000..86db1dd0
--- /dev/null
+++ b/klm/util/stream/line_input.hh
@@ -0,0 +1,22 @@
+#ifndef UTIL_STREAM_LINE_INPUT__
+#define UTIL_STREAM_LINE_INPUT__
+namespace util {namespace stream {
+
+class ChainPosition;
+
+/* Worker that reads input into blocks, ensuring that blocks contain whole
+ * lines. Assumes that the maximum size of a line is less than the block size
+ */
+class LineInput {
+ public:
+ // Takes ownership upon thread execution.
+ explicit LineInput(int fd);
+
+ void Run(const ChainPosition &position);
+
+ private:
+ int fd_;
+};
+
+}} // namespaces
+#endif // UTIL_STREAM_LINE_INPUT__
diff --git a/klm/util/stream/multi_progress.cc b/klm/util/stream/multi_progress.cc
new file mode 100644
index 00000000..8ba10386
--- /dev/null
+++ b/klm/util/stream/multi_progress.cc
@@ -0,0 +1,86 @@
+#include "util/stream/multi_progress.hh"
+
+// TODO: merge some functionality with the simple progress bar?
+#include "util/ersatz_progress.hh"
+
+#include <iostream>
+#include <limits>
+
+#include <string.h>
+
+#if !defined(_WIN32) && !defined(_WIN64)
+#include <unistd.h>
+#endif
+
+namespace util { namespace stream {
+
+namespace {
+const char kDisplayCharacters[] = "-+*#0123456789";
+
+uint64_t Next(unsigned char stone, uint64_t complete) {
+ return (static_cast<uint64_t>(stone + 1) * complete + MultiProgress::kWidth - 1) / MultiProgress::kWidth;
+}
+
+} // namespace
+
+MultiProgress::MultiProgress() : active_(false), complete_(std::numeric_limits<uint64_t>::max()), character_handout_(0) {}
+
+MultiProgress::~MultiProgress() {
+ if (active_ && complete_ != std::numeric_limits<uint64_t>::max())
+ std::cerr << '\n';
+}
+
+void MultiProgress::Activate() {
+ active_ =
+#if !defined(_WIN32) && !defined(_WIN64)
+ // Is stderr a terminal?
+ (isatty(2) == 1)
+#else
+ true
+#endif
+ ;
+}
+
+void MultiProgress::SetTarget(uint64_t complete) {
+ if (!active_) return;
+ complete_ = complete;
+ if (!complete) complete_ = 1;
+ memset(display_, 0, sizeof(display_));
+ character_handout_ = 0;
+ std::cerr << kProgressBanner;
+}
+
+WorkerProgress MultiProgress::Add() {
+ if (!active_)
+ return WorkerProgress(std::numeric_limits<uint64_t>::max(), *this, '\0');
+ std::size_t character_index;
+ {
+ boost::unique_lock<boost::mutex> lock(mutex_);
+ character_index = character_handout_++;
+ if (character_handout_ == sizeof(kDisplayCharacters) - 1)
+ character_handout_ = 0;
+ }
+ return WorkerProgress(Next(0, complete_), *this, kDisplayCharacters[character_index]);
+}
+
+void MultiProgress::Finished() {
+ if (!active_ || complete_ == std::numeric_limits<uint64_t>::max()) return;
+ std::cerr << '\n';
+ complete_ = std::numeric_limits<uint64_t>::max();
+}
+
+void MultiProgress::Milestone(WorkerProgress &worker) {
+ if (!active_ || complete_ == std::numeric_limits<uint64_t>::max()) return;
+ unsigned char stone = std::min(static_cast<uint64_t>(kWidth), worker.current_ * kWidth / complete_);
+ for (char *i = &display_[worker.stone_]; i < &display_[stone]; ++i) {
+ *i = worker.character_;
+ }
+ worker.next_ = Next(stone, complete_);
+ worker.stone_ = stone;
+ {
+ boost::unique_lock<boost::mutex> lock(mutex_);
+ std::cerr << '\r' << display_ << std::flush;
+ }
+}
+
+}} // namespaces
diff --git a/klm/util/stream/multi_progress.hh b/klm/util/stream/multi_progress.hh
new file mode 100644
index 00000000..c4dd45a9
--- /dev/null
+++ b/klm/util/stream/multi_progress.hh
@@ -0,0 +1,90 @@
+/* Progress bar suitable for chains of workers */
+#ifndef UTIL_MULTI_PROGRESS__
+#define UTIL_MULTI_PROGRESS__
+
+#include <boost/thread/mutex.hpp>
+
+#include <cstddef>
+
+#include <stdint.h>
+
+namespace util { namespace stream {
+
+class WorkerProgress;
+
+class MultiProgress {
+ public:
+ static const unsigned char kWidth = 100;
+
+ MultiProgress();
+
+ ~MultiProgress();
+
+ // Turns on showing (requires SetTarget too).
+ void Activate();
+
+ void SetTarget(uint64_t complete);
+
+ WorkerProgress Add();
+
+ void Finished();
+
+ private:
+ friend class WorkerProgress;
+ void Milestone(WorkerProgress &worker);
+
+ bool active_;
+
+ uint64_t complete_;
+
+ boost::mutex mutex_;
+
+ // \0 at the end.
+ char display_[kWidth + 1];
+
+ std::size_t character_handout_;
+
+ MultiProgress(const MultiProgress &);
+ MultiProgress &operator=(const MultiProgress &);
+};
+
+class WorkerProgress {
+ public:
+ // Default contrutor must be initialized with operator= later.
+ WorkerProgress() : parent_(NULL) {}
+
+ // Not threadsafe for the same worker by default.
+ WorkerProgress &operator++() {
+ if (++current_ >= next_) {
+ parent_->Milestone(*this);
+ }
+ return *this;
+ }
+
+ WorkerProgress &operator+=(uint64_t amount) {
+ current_ += amount;
+ if (current_ >= next_) {
+ parent_->Milestone(*this);
+ }
+ return *this;
+ }
+
+ private:
+ friend class MultiProgress;
+ WorkerProgress(uint64_t next, MultiProgress &parent, char character)
+ : current_(0), next_(next), parent_(&parent), stone_(0), character_(character) {}
+
+ uint64_t current_, next_;
+
+ MultiProgress *parent_;
+
+ // Previous milestone reached.
+ unsigned char stone_;
+
+ // Character to display in bar.
+ char character_;
+};
+
+}} // namespaces
+
+#endif // UTIL_MULTI_PROGRESS__
diff --git a/klm/util/stream/sort.hh b/klm/util/stream/sort.hh
new file mode 100644
index 00000000..a86f160f
--- /dev/null
+++ b/klm/util/stream/sort.hh
@@ -0,0 +1,544 @@
+/* Usage:
+ * Sort<Compare> sorter(temp, compare);
+ * Chain(config) >> Read(file) >> sorter.Unsorted();
+ * Stream stream;
+ * Chain chain(config) >> sorter.Sorted(internal_config, lazy_config) >> stream;
+ *
+ * Note that sorter must outlive any threads that use Unsorted or Sorted.
+ *
+ * Combiners take the form:
+ * bool operator()(void *into, const void *option, const Compare &compare) const
+ * which returns true iff a combination happened. The sorting algorithm
+ * guarantees compare(into, option). But it does not guarantee
+ * compare(option, into).
+ * Currently, combining is only done in merge steps, not during on-the-fly
+ * sort. Use a hash table for that.
+ */
+
+#ifndef UTIL_STREAM_SORT__
+#define UTIL_STREAM_SORT__
+
+#include "util/stream/chain.hh"
+#include "util/stream/config.hh"
+#include "util/stream/io.hh"
+#include "util/stream/stream.hh"
+#include "util/stream/timer.hh"
+
+#include "util/file.hh"
+#include "util/scoped.hh"
+#include "util/sized_iterator.hh"
+
+#include <algorithm>
+#include <iostream>
+#include <queue>
+#include <string>
+
+namespace util {
+namespace stream {
+
+struct NeverCombine {
+ template <class Compare> bool operator()(const void *, const void *, const Compare &) const {
+ return false;
+ }
+};
+
+// Manage the offsets of sorted blocks in a file.
+class Offsets {
+ public:
+ explicit Offsets(int fd) : log_(fd) {
+ Reset();
+ }
+
+ int File() const { return log_; }
+
+ void Append(uint64_t length) {
+ if (!length) return;
+ ++block_count_;
+ if (length == cur_.length) {
+ ++cur_.run;
+ return;
+ }
+ WriteOrThrow(log_, &cur_, sizeof(Entry));
+ cur_.length = length;
+ cur_.run = 1;
+ }
+
+ void FinishedAppending() {
+ WriteOrThrow(log_, &cur_, sizeof(Entry));
+ SeekOrThrow(log_, sizeof(Entry)); // Skip 0,0 at beginning.
+ cur_.run = 0;
+ if (block_count_) {
+ ReadOrThrow(log_, &cur_, sizeof(Entry));
+ assert(cur_.length);
+ assert(cur_.run);
+ }
+ }
+
+ uint64_t RemainingBlocks() const { return block_count_; }
+
+ uint64_t TotalOffset() const { return output_sum_; }
+
+ uint64_t PeekSize() const {
+ return cur_.length;
+ }
+
+ uint64_t NextSize() {
+ assert(block_count_);
+ uint64_t ret = cur_.length;
+ output_sum_ += ret;
+
+ --cur_.run;
+ --block_count_;
+ if (!cur_.run && block_count_) {
+ ReadOrThrow(log_, &cur_, sizeof(Entry));
+ assert(cur_.length);
+ assert(cur_.run);
+ }
+ return ret;
+ }
+
+ void Reset() {
+ SeekOrThrow(log_, 0);
+ ResizeOrThrow(log_, 0);
+ cur_.length = 0;
+ cur_.run = 0;
+ block_count_ = 0;
+ output_sum_ = 0;
+ }
+
+ private:
+ int log_;
+
+ struct Entry {
+ uint64_t length;
+ uint64_t run;
+ };
+ Entry cur_;
+
+ uint64_t block_count_;
+
+ uint64_t output_sum_;
+};
+
+// A priority queue of entries backed by file buffers
+template <class Compare> class MergeQueue {
+ public:
+ MergeQueue(int fd, std::size_t buffer_size, std::size_t entry_size, const Compare &compare)
+ : queue_(Greater(compare)), in_(fd), buffer_size_(buffer_size), entry_size_(entry_size) {}
+
+ void Push(void *base, uint64_t offset, uint64_t amount) {
+ queue_.push(Entry(base, in_, offset, amount, buffer_size_));
+ }
+
+ const void *Top() const {
+ return queue_.top().Current();
+ }
+
+ void Pop() {
+ Entry top(queue_.top());
+ queue_.pop();
+ if (top.Increment(in_, buffer_size_, entry_size_))
+ queue_.push(top);
+ }
+
+ std::size_t Size() const {
+ return queue_.size();
+ }
+
+ bool Empty() const {
+ return queue_.empty();
+ }
+
+ private:
+ // Priority queue contains these entries.
+ class Entry {
+ public:
+ Entry() {}
+
+ Entry(void *base, int fd, uint64_t offset, uint64_t amount, std::size_t buf_size) {
+ offset_ = offset;
+ remaining_ = amount;
+ buffer_end_ = static_cast<uint8_t*>(base) + buf_size;
+ Read(fd, buf_size);
+ }
+
+ bool Increment(int fd, std::size_t buf_size, std::size_t entry_size) {
+ current_ += entry_size;
+ if (current_ != buffer_end_) return true;
+ return Read(fd, buf_size);
+ }
+
+ const void *Current() const { return current_; }
+
+ private:
+ bool Read(int fd, std::size_t buf_size) {
+ current_ = buffer_end_ - buf_size;
+ std::size_t amount;
+ if (static_cast<uint64_t>(buf_size) < remaining_) {
+ amount = buf_size;
+ } else if (!remaining_) {
+ return false;
+ } else {
+ amount = remaining_;
+ buffer_end_ = current_ + remaining_;
+ }
+ PReadOrThrow(fd, current_, amount, offset_);
+ offset_ += amount;
+ assert(current_ <= buffer_end_);
+ remaining_ -= amount;
+ return true;
+ }
+
+ // Buffer
+ uint8_t *current_, *buffer_end_;
+ // File
+ uint64_t remaining_, offset_;
+ };
+
+ // Wrapper comparison function for queue entries.
+ class Greater : public std::binary_function<const Entry &, const Entry &, bool> {
+ public:
+ explicit Greater(const Compare &compare) : compare_(compare) {}
+
+ bool operator()(const Entry &first, const Entry &second) const {
+ return compare_(second.Current(), first.Current());
+ }
+
+ private:
+ const Compare compare_;
+ };
+
+ typedef std::priority_queue<Entry, std::vector<Entry>, Greater> Queue;
+ Queue queue_;
+
+ const int in_;
+ const std::size_t buffer_size_;
+ const std::size_t entry_size_;
+};
+
+/* A worker object that merges. If the number of pieces to merge exceeds the
+ * arity, it outputs multiple sorted blocks, recording to out_offsets.
+ * However, users will only every see a single sorted block out output because
+ * Sort::Sorted insures the arity is higher than the number of pieces before
+ * returning this.
+ */
+template <class Compare, class Combine> class MergingReader {
+ public:
+ MergingReader(int in, Offsets *in_offsets, Offsets *out_offsets, std::size_t buffer_size, std::size_t total_memory, const Compare &compare, const Combine &combine) :
+ compare_(compare), combine_(combine),
+ in_(in),
+ in_offsets_(in_offsets), out_offsets_(out_offsets),
+ buffer_size_(buffer_size), total_memory_(total_memory) {}
+
+ void Run(const ChainPosition &position) {
+ Run(position, false);
+ }
+
+ void Run(const ChainPosition &position, bool assert_one) {
+ // Special case: nothing to read.
+ if (!in_offsets_->RemainingBlocks()) {
+ Link l(position);
+ l.Poison();
+ return;
+ }
+ // If there's just one entry, just read.
+ if (in_offsets_->RemainingBlocks() == 1) {
+ // Sequencing is important.
+ uint64_t offset = in_offsets_->TotalOffset();
+ uint64_t amount = in_offsets_->NextSize();
+ ReadSingle(offset, amount, position);
+ if (out_offsets_) out_offsets_->Append(amount);
+ return;
+ }
+
+ Stream str(position);
+ scoped_malloc buffer(MallocOrThrow(total_memory_));
+ uint8_t *const buffer_end = static_cast<uint8_t*>(buffer.get()) + total_memory_;
+
+ const std::size_t entry_size = position.GetChain().EntrySize();
+
+ while (in_offsets_->RemainingBlocks()) {
+ // Use bigger buffers if there's less remaining.
+ uint64_t per_buffer = static_cast<uint64_t>(std::max<std::size_t>(
+ buffer_size_,
+ static_cast<std::size_t>((static_cast<uint64_t>(total_memory_) / in_offsets_->RemainingBlocks()))));
+ per_buffer -= per_buffer % entry_size;
+ assert(per_buffer);
+
+ // Populate queue.
+ MergeQueue<Compare> queue(in_, per_buffer, entry_size, compare_);
+ for (uint8_t *buf = static_cast<uint8_t*>(buffer.get());
+ in_offsets_->RemainingBlocks() && (buf + std::min(per_buffer, in_offsets_->PeekSize()) <= buffer_end);) {
+ uint64_t offset = in_offsets_->TotalOffset();
+ uint64_t size = in_offsets_->NextSize();
+ queue.Push(buf, offset, size);
+ buf += static_cast<std::size_t>(std::min<uint64_t>(size, per_buffer));
+ }
+ // This shouldn't happen but it's probably better to die than loop indefinitely.
+ if (queue.Size() < 2 && in_offsets_->RemainingBlocks()) {
+ std::cerr << "Bug in sort implementation: not merging at least two stripes." << std::endl;
+ abort();
+ }
+ if (assert_one && in_offsets_->RemainingBlocks()) {
+ std::cerr << "Bug in sort implementation: should only be one merge group for lazy sort" << std::endl;
+ abort();
+ }
+
+ uint64_t written = 0;
+ // Merge including combiner support.
+ memcpy(str.Get(), queue.Top(), entry_size);
+ for (queue.Pop(); !queue.Empty(); queue.Pop()) {
+ if (!combine_(str.Get(), queue.Top(), compare_)) {
+ ++written; ++str;
+ memcpy(str.Get(), queue.Top(), entry_size);
+ }
+ }
+ ++written; ++str;
+ if (out_offsets_)
+ out_offsets_->Append(written * entry_size);
+ }
+ str.Poison();
+ }
+
+ private:
+ void ReadSingle(uint64_t offset, const uint64_t size, const ChainPosition &position) {
+ // Special case: only one to read.
+ const uint64_t end = offset + size;
+ const uint64_t block_size = position.GetChain().BlockSize();
+ Link l(position);
+ for (; offset + block_size < end; ++l, offset += block_size) {
+ PReadOrThrow(in_, l->Get(), block_size, offset);
+ l->SetValidSize(block_size);
+ }
+ PReadOrThrow(in_, l->Get(), end - offset, offset);
+ l->SetValidSize(end - offset);
+ (++l).Poison();
+ return;
+ }
+
+ Compare compare_;
+ Combine combine_;
+
+ int in_;
+
+ protected:
+ Offsets *in_offsets_;
+
+ private:
+ Offsets *out_offsets_;
+
+ std::size_t buffer_size_;
+ std::size_t total_memory_;
+};
+
+// The lazy step owns the remaining files. This keeps track of them.
+template <class Compare, class Combine> class OwningMergingReader : public MergingReader<Compare, Combine> {
+ private:
+ typedef MergingReader<Compare, Combine> P;
+ public:
+ OwningMergingReader(int data, const Offsets &offsets, std::size_t buffer, std::size_t lazy, const Compare &compare, const Combine &combine)
+ : P(data, NULL, NULL, buffer, lazy, compare, combine),
+ data_(data),
+ offsets_(offsets) {}
+
+ void Run(const ChainPosition &position) {
+ P::in_offsets_ = &offsets_;
+ scoped_fd data(data_);
+ scoped_fd offsets_file(offsets_.File());
+ P::Run(position, true);
+ }
+
+ private:
+ int data_;
+ Offsets offsets_;
+};
+
+// Don't use this directly. Worker that sorts blocks.
+template <class Compare> class BlockSorter {
+ public:
+ BlockSorter(Offsets &offsets, const Compare &compare) :
+ offsets_(&offsets), compare_(compare) {}
+
+ void Run(const ChainPosition &position) {
+ const std::size_t entry_size = position.GetChain().EntrySize();
+ for (Link link(position); link; ++link) {
+ // Record the size of each block in a separate file.
+ offsets_->Append(link->ValidSize());
+ void *end = static_cast<uint8_t*>(link->Get()) + link->ValidSize();
+ std::sort(
+ SizedIt(link->Get(), entry_size),
+ SizedIt(end, entry_size),
+ compare_);
+ }
+ offsets_->FinishedAppending();
+ }
+
+ private:
+ Offsets *offsets_;
+ SizedCompare<Compare> compare_;
+};
+
+class BadSortConfig : public Exception {
+ public:
+ BadSortConfig() throw() {}
+ ~BadSortConfig() throw() {}
+};
+
+template <class Compare, class Combine = NeverCombine> class Sort {
+ public:
+ Sort(Chain &in, const SortConfig &config, const Compare &compare = Compare(), const Combine &combine = Combine())
+ : config_(config),
+ data_(MakeTemp(config.temp_prefix)),
+ offsets_file_(MakeTemp(config.temp_prefix)), offsets_(offsets_file_.get()),
+ compare_(compare), combine_(combine),
+ entry_size_(in.EntrySize()) {
+ UTIL_THROW_IF(!entry_size_, BadSortConfig, "Sorting entries of size 0");
+ // Make buffer_size a multiple of the entry_size.
+ config_.buffer_size -= config_.buffer_size % entry_size_;
+ UTIL_THROW_IF(!config_.buffer_size, BadSortConfig, "Sort buffer too small");
+ UTIL_THROW_IF(config_.total_memory < config_.buffer_size * 4, BadSortConfig, "Sorting memory " << config_.total_memory << " is too small for four buffers (two read and two write).");
+ in >> BlockSorter<Compare>(offsets_, compare_) >> WriteAndRecycle(data_.get());
+ }
+
+ uint64_t Size() const {
+ return SizeOrThrow(data_.get());
+ }
+
+ // Do merge sort, terminating when lazy merge could be done with the
+ // specified memory. Return the minimum memory necessary to do lazy merge.
+ std::size_t Merge(std::size_t lazy_memory) {
+ if (offsets_.RemainingBlocks() <= 1) return 0;
+ const uint64_t lazy_arity = std::max<uint64_t>(1, lazy_memory / config_.buffer_size);
+ uint64_t size = Size();
+ /* No overflow because
+ * offsets_.RemainingBlocks() * config_.buffer_size <= lazy_memory ||
+ * size < lazy_memory
+ */
+ if (offsets_.RemainingBlocks() <= lazy_arity || size <= static_cast<uint64_t>(lazy_memory))
+ return std::min<std::size_t>(size, offsets_.RemainingBlocks() * config_.buffer_size);
+
+ scoped_fd data2(MakeTemp(config_.temp_prefix));
+ int fd_in = data_.get(), fd_out = data2.get();
+ scoped_fd offsets2_file(MakeTemp(config_.temp_prefix));
+ Offsets offsets2(offsets2_file.get());
+ Offsets *offsets_in = &offsets_, *offsets_out = &offsets2;
+
+ // Double buffered writing.
+ ChainConfig chain_config;
+ chain_config.entry_size = entry_size_;
+ chain_config.block_count = 2;
+ chain_config.total_memory = config_.buffer_size * 2;
+ Chain chain(chain_config);
+
+ while (offsets_in->RemainingBlocks() > lazy_arity) {
+ if (size <= static_cast<uint64_t>(lazy_memory)) break;
+ std::size_t reading_memory = config_.total_memory - 2 * config_.buffer_size;
+ if (size < static_cast<uint64_t>(reading_memory)) {
+ reading_memory = static_cast<std::size_t>(size);
+ }
+ SeekOrThrow(fd_in, 0);
+ chain >>
+ MergingReader<Compare, Combine>(
+ fd_in,
+ offsets_in, offsets_out,
+ config_.buffer_size,
+ reading_memory,
+ compare_, combine_) >>
+ WriteAndRecycle(fd_out);
+ chain.Wait();
+ offsets_out->FinishedAppending();
+ ResizeOrThrow(fd_in, 0);
+ offsets_in->Reset();
+ std::swap(fd_in, fd_out);
+ std::swap(offsets_in, offsets_out);
+ size = SizeOrThrow(fd_in);
+ }
+
+ SeekOrThrow(fd_in, 0);
+ if (fd_in == data2.get()) {
+ data_.reset(data2.release());
+ offsets_file_.reset(offsets2_file.release());
+ offsets_ = offsets2;
+ }
+ if (offsets_.RemainingBlocks() <= 1) return 0;
+ // No overflow because the while loop exited.
+ return std::min(size, offsets_.RemainingBlocks() * static_cast<uint64_t>(config_.buffer_size));
+ }
+
+ // Output to chain, using this amount of memory, maximum, for lazy merge
+ // sort.
+ void Output(Chain &out, std::size_t lazy_memory) {
+ Merge(lazy_memory);
+ out.SetProgressTarget(Size());
+ out >> OwningMergingReader<Compare, Combine>(data_.get(), offsets_, config_.buffer_size, lazy_memory, compare_, combine_);
+ data_.release();
+ offsets_file_.release();
+ }
+
+ /* If a pipeline step is reading sorted input and writing to a different
+ * sort order, then there's a trade-off between using RAM to read lazily
+ * (avoiding copying the file) and using RAM to increase block size and,
+ * therefore, decrease the number of merge sort passes in the next
+ * iteration.
+ *
+ * Merge sort takes log_{arity}(pieces) passes. Thus, each time the chain
+ * block size is multiplied by arity, the number of output passes decreases
+ * by one. Up to a constant, then, log_{arity}(chain) is the number of
+ * passes saved. Chain simply divides the memory evenly over all blocks.
+ *
+ * Lazy sort saves this many passes (up to a constant)
+ * log_{arity}((memory-lazy)/block_count) + 1
+ * Non-lazy sort saves this many passes (up to the same constant):
+ * log_{arity}(memory/block_count)
+ * Add log_{arity}(block_count) to both:
+ * log_{arity}(memory-lazy) + 1 versus log_{arity}(memory)
+ * Take arity to the power of both sizes (arity > 1)
+ * (memory - lazy)*arity versus memory
+ * Solve for lazy
+ * lazy = memory * (arity - 1) / arity
+ */
+ std::size_t DefaultLazy() {
+ float arity = static_cast<float>(config_.total_memory / config_.buffer_size);
+ return static_cast<std::size_t>(static_cast<float>(config_.total_memory) * (arity - 1.0) / arity);
+ }
+
+ // Same as Output with default lazy memory setting.
+ void Output(Chain &out) {
+ Output(out, DefaultLazy());
+ }
+
+ // Completely merge sort and transfer ownership to the caller.
+ int StealCompleted() {
+ // Merge all the way.
+ Merge(0);
+ SeekOrThrow(data_.get(), 0);
+ offsets_file_.reset();
+ return data_.release();
+ }
+
+ private:
+ SortConfig config_;
+
+ scoped_fd data_;
+
+ scoped_fd offsets_file_;
+ Offsets offsets_;
+
+ const Compare compare_;
+ const Combine combine_;
+ const std::size_t entry_size_;
+};
+
+// returns bytes to be read on demand.
+template <class Compare, class Combine> uint64_t BlockingSort(Chain &chain, const SortConfig &config, const Compare &compare = Compare(), const Combine &combine = NeverCombine()) {
+ Sort<Compare, Combine> sorter(chain, config, compare, combine);
+ chain.Wait(true);
+ uint64_t size = sorter.Size();
+ sorter.Output(chain);
+ return size;
+}
+
+} // namespace stream
+} // namespace util
+
+#endif // UTIL_STREAM_SORT__
diff --git a/klm/util/stream/sort_test.cc b/klm/util/stream/sort_test.cc
new file mode 100644
index 00000000..fd7705cd
--- /dev/null
+++ b/klm/util/stream/sort_test.cc
@@ -0,0 +1,62 @@
+#include "util/stream/sort.hh"
+
+#define BOOST_TEST_MODULE SortTest
+#include <boost/test/unit_test.hpp>
+
+#include <algorithm>
+
+#include <unistd.h>
+
+namespace util { namespace stream { namespace {
+
+struct CompareUInt64 : public std::binary_function<const void *, const void *, bool> {
+ bool operator()(const void *first, const void *second) const {
+ return *static_cast<const uint64_t*>(first) < *reinterpret_cast<const uint64_t*>(second);
+ }
+};
+
+const uint64_t kSize = 100000;
+
+struct Putter {
+ Putter(std::vector<uint64_t> &shuffled) : shuffled_(shuffled) {}
+
+ void Run(const ChainPosition &position) {
+ Stream put_shuffled(position);
+ for (uint64_t i = 0; i < shuffled_.size(); ++i, ++put_shuffled) {
+ *static_cast<uint64_t*>(put_shuffled.Get()) = shuffled_[i];
+ }
+ put_shuffled.Poison();
+ }
+ std::vector<uint64_t> &shuffled_;
+};
+
+BOOST_AUTO_TEST_CASE(FromShuffled) {
+ std::vector<uint64_t> shuffled;
+ shuffled.reserve(kSize);
+ for (uint64_t i = 0; i < kSize; ++i) {
+ shuffled.push_back(i);
+ }
+ std::random_shuffle(shuffled.begin(), shuffled.end());
+
+ ChainConfig config;
+ config.entry_size = 8;
+ config.total_memory = 800;
+ config.block_count = 3;
+
+ SortConfig merge_config;
+ merge_config.temp_prefix = "sort_test_temp";
+ merge_config.buffer_size = 800;
+ merge_config.total_memory = 3300;
+
+ Chain chain(config);
+ chain >> Putter(shuffled);
+ BlockingSort(chain, merge_config, CompareUInt64(), NeverCombine());
+ Stream sorted;
+ chain >> sorted >> kRecycle;
+ for (uint64_t i = 0; i < kSize; ++i, ++sorted) {
+ BOOST_CHECK_EQUAL(i, *static_cast<const uint64_t*>(sorted.Get()));
+ }
+ BOOST_CHECK(!sorted);
+}
+
+}}} // namespaces
diff --git a/klm/util/stream/stream.hh b/klm/util/stream/stream.hh
new file mode 100644
index 00000000..6ff45b82
--- /dev/null
+++ b/klm/util/stream/stream.hh
@@ -0,0 +1,74 @@
+#ifndef UTIL_STREAM_STREAM__
+#define UTIL_STREAM_STREAM__
+
+#include "util/stream/chain.hh"
+
+#include <boost/noncopyable.hpp>
+
+#include <assert.h>
+#include <stdint.h>
+
+namespace util {
+namespace stream {
+
+class Stream : boost::noncopyable {
+ public:
+ Stream() : current_(NULL), end_(NULL) {}
+
+ void Init(const ChainPosition &position) {
+ entry_size_ = position.GetChain().EntrySize();
+ block_size_ = position.GetChain().BlockSize();
+ block_it_.Init(position);
+ StartBlock();
+ }
+
+ explicit Stream(const ChainPosition &position) {
+ Init(position);
+ }
+
+ operator bool() const { return current_ != NULL; }
+ bool operator!() const { return current_ == NULL; }
+
+ const void *Get() const { return current_; }
+ void *Get() { return current_; }
+
+ void Poison() {
+ block_it_->SetValidSize(current_ - static_cast<uint8_t*>(block_it_->Get()));
+ ++block_it_;
+ block_it_.Poison();
+ }
+
+ Stream &operator++() {
+ assert(*this);
+ assert(current_ < end_);
+ current_ += entry_size_;
+ if (current_ == end_) {
+ ++block_it_;
+ StartBlock();
+ }
+ return *this;
+ }
+
+ private:
+ void StartBlock() {
+ for (; block_it_ && !block_it_->ValidSize(); ++block_it_) {}
+ current_ = static_cast<uint8_t*>(block_it_->Get());
+ end_ = current_ + block_it_->ValidSize();
+ }
+
+ uint8_t *current_, *end_;
+
+ std::size_t entry_size_;
+ std::size_t block_size_;
+
+ Link block_it_;
+};
+
+inline Chain &operator>>(Chain &chain, Stream &stream) {
+ stream.Init(chain.Add());
+ return chain;
+}
+
+} // namespace stream
+} // namespace util
+#endif // UTIL_STREAM_STREAM__
diff --git a/klm/util/stream/stream_test.cc b/klm/util/stream/stream_test.cc
new file mode 100644
index 00000000..6575d50d
--- /dev/null
+++ b/klm/util/stream/stream_test.cc
@@ -0,0 +1,35 @@
+#include "util/stream/io.hh"
+
+#include "util/stream/stream.hh"
+#include "util/file.hh"
+
+#define BOOST_TEST_MODULE StreamTest
+#include <boost/test/unit_test.hpp>
+
+#include <unistd.h>
+
+namespace util { namespace stream { namespace {
+
+BOOST_AUTO_TEST_CASE(StreamTest) {
+ scoped_fd in(MakeTemp("io_test_temp"));
+ for (uint64_t i = 0; i < 100000; ++i) {
+ WriteOrThrow(in.get(), &i, sizeof(uint64_t));
+ }
+ SeekOrThrow(in.get(), 0);
+
+ ChainConfig config;
+ config.entry_size = 8;
+ config.total_memory = 100;
+ config.block_count = 12;
+
+ Stream s;
+ Chain chain(config);
+ chain >> Read(in.get()) >> s >> kRecycle;
+ uint64_t i = 0;
+ for (; s; ++s, ++i) {
+ BOOST_CHECK_EQUAL(i, *static_cast<const uint64_t*>(s.Get()));
+ }
+ BOOST_CHECK_EQUAL(100000ULL, i);
+}
+
+}}} // namespaces
diff --git a/klm/util/stream/timer.hh b/klm/util/stream/timer.hh
new file mode 100644
index 00000000..7e1a5885
--- /dev/null
+++ b/klm/util/stream/timer.hh
@@ -0,0 +1,16 @@
+#ifndef UTIL_STREAM_TIMER__
+#define UTIL_STREAM_TIMER__
+
+// Sorry Jon, this was adding library dependencies in Moses and people complained.
+
+/*#include <boost/version.hpp>
+
+#if BOOST_VERSION >= 104800
+#include <boost/timer/timer.hpp>
+#define UTIL_TIMER(str) boost::timer::auto_cpu_timer timer(std::cerr, 1, (str))
+#else
+//#warning Using Boost older than 1.48. Timing information will not be available.*/
+#define UTIL_TIMER(str)
+//#endif
+
+#endif // UTIL_STREAM_TIMER__