From e0ef743090038ee02d656cee11debd2246624ba0 Mon Sep 17 00:00:00 2001 From: redpony Date: Mon, 18 Oct 2010 23:24:01 +0000 Subject: kenneth's LM preliminary integration git-svn-id: https://ws10smt.googlecode.com/svn/trunk@681 ec762483-ff6d-05da-a07a-a48fb63a330f --- klm/util/Makefile.am | 18 +++ klm/util/ersatz_progress.cc | 47 +++++++ klm/util/ersatz_progress.hh | 50 +++++++ klm/util/exception.cc | 35 +++++ klm/util/exception.hh | 72 ++++++++++ klm/util/file_piece.cc | 224 +++++++++++++++++++++++++++++++ klm/util/file_piece.hh | 105 +++++++++++++++ klm/util/file_piece_test.cc | 41 ++++++ klm/util/joint_sort.hh | 145 ++++++++++++++++++++ klm/util/joint_sort_test.cc | 50 +++++++ klm/util/key_value_packing.hh | 122 +++++++++++++++++ klm/util/key_value_packing_test.cc | 75 +++++++++++ klm/util/mmap.cc | 95 +++++++++++++ klm/util/mmap.hh | 101 ++++++++++++++ klm/util/murmur_hash.cc | 129 ++++++++++++++++++ klm/util/murmur_hash.hh | 14 ++ klm/util/probing_hash_table.hh | 97 ++++++++++++++ klm/util/probing_hash_table_test.cc | 30 +++++ klm/util/proxy_iterator.hh | 94 +++++++++++++ klm/util/scoped.cc | 12 ++ klm/util/scoped.hh | 66 +++++++++ klm/util/sorted_uniform.hh | 139 +++++++++++++++++++ klm/util/sorted_uniform_test.cc | 116 ++++++++++++++++ klm/util/string_piece.cc | 57 ++++++++ klm/util/string_piece.hh | 260 ++++++++++++++++++++++++++++++++++++ 25 files changed, 2194 insertions(+) create mode 100644 klm/util/Makefile.am create mode 100644 klm/util/ersatz_progress.cc create mode 100644 klm/util/ersatz_progress.hh create mode 100644 klm/util/exception.cc create mode 100644 klm/util/exception.hh create mode 100644 klm/util/file_piece.cc create mode 100644 klm/util/file_piece.hh create mode 100644 klm/util/file_piece_test.cc create mode 100644 klm/util/joint_sort.hh create mode 100644 klm/util/joint_sort_test.cc create mode 100644 klm/util/key_value_packing.hh create mode 100644 klm/util/key_value_packing_test.cc create mode 100644 klm/util/mmap.cc create mode 100644 
klm/util/mmap.hh create mode 100644 klm/util/murmur_hash.cc create mode 100644 klm/util/murmur_hash.hh create mode 100644 klm/util/probing_hash_table.hh create mode 100644 klm/util/probing_hash_table_test.cc create mode 100644 klm/util/proxy_iterator.hh create mode 100644 klm/util/scoped.cc create mode 100644 klm/util/scoped.hh create mode 100644 klm/util/sorted_uniform.hh create mode 100644 klm/util/sorted_uniform_test.cc create mode 100644 klm/util/string_piece.cc create mode 100644 klm/util/string_piece.hh (limited to 'klm/util') diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am new file mode 100644 index 00000000..d3aea6b7 --- /dev/null +++ b/klm/util/Makefile.am @@ -0,0 +1,18 @@ +if HAVE_GTEST +noinst_PROGRAMS = \ + scorer_test +TESTS = scorer_test +endif + +noinst_LIBRARIES = libklm_util.a + +libklm_util_a_SOURCES = \ + ersatz_progress.cc \ + exception.cc \ + file_piece.cc \ + mmap.cc \ + murmur_hash.cc \ + scoped.cc \ + string_piece.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. 
diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc new file mode 100644 index 00000000..09e3a106 --- /dev/null +++ b/klm/util/ersatz_progress.cc @@ -0,0 +1,47 @@ +#include "util/ersatz_progress.hh" + +#include +#include +#include +#include + +namespace util { + +namespace { const unsigned char kWidth = 100; } + +ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits::max()), complete_(next_), out_(NULL) {} + +ErsatzProgress::~ErsatzProgress() { + if (!out_) return; + for (; stones_written_ < kWidth; ++stones_written_) { + (*out_) << '*'; + } + *out_ << '\n'; +} + +ErsatzProgress::ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete) + : current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) { + if (!out_) { + next_ = std::numeric_limits::max(); + return; + } + *out_ << message << "\n----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"; +} + +void ErsatzProgress::Milestone() { + if (!out_) { current_ = 0; return; } + if (!complete_) return; + unsigned char stone = std::min(static_cast(kWidth), (current_ * kWidth) / complete_); + + for (; stones_written_ < stone; ++stones_written_) { + (*out_) << '*'; + } + + if (current_ >= complete_) { + next_ = std::numeric_limits::max(); + } else { + next_ = std::max(next_, (stone * complete_) / kWidth); + } +} + +} // namespace util diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh new file mode 100644 index 00000000..ea6c3bb9 --- /dev/null +++ b/klm/util/ersatz_progress.hh @@ -0,0 +1,50 @@ +#ifndef UTIL_ERSATZ_PROGRESS__ +#define UTIL_ERSATZ_PROGRESS__ + +#include +#include + +// Ersatz version of boost::progress so core language model doesn't depend on +// boost. Also adds option to print nothing. + +namespace util { +class ErsatzProgress { + public: + // No output. + ErsatzProgress(); + + // Null means no output. 
The null value is useful for passing along the ostream pointer from another caller. + ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete); + + ~ErsatzProgress(); + + ErsatzProgress &operator++() { + if (++current_ == next_) Milestone(); + return *this; + } + + ErsatzProgress &operator+=(std::size_t amount) { + if ((current_ += amount) >= next_) Milestone(); + return *this; + } + + void Set(std::size_t to) { + if ((current_ = to) >= next_) Milestone(); + Milestone(); + } + + private: + void Milestone(); + + std::size_t current_, next_, complete_; + unsigned char stones_written_; + std::ostream *out_; + + // noncopyable + ErsatzProgress(const ErsatzProgress &other); + ErsatzProgress &operator=(const ErsatzProgress &other); +}; + +} // namespace util + +#endif // UTIL_ERSATZ_PROGRESS__ diff --git a/klm/util/exception.cc b/klm/util/exception.cc new file mode 100644 index 00000000..dd337a76 --- /dev/null +++ b/klm/util/exception.cc @@ -0,0 +1,35 @@ +#include "util/exception.hh" + +#include +#include + +namespace util { + +Exception::Exception() throw() {} +Exception::~Exception() throw() {} + +namespace { +// The XOPEN version. +const char *HandleStrerror(int ret, const char *buf) { + if (!ret) return buf; + return NULL; +} + +// The GNU version. 
+const char *HandleStrerror(const char *ret, const char *buf) { + return ret; +} +} // namespace + +ErrnoException::ErrnoException() throw() : errno_(errno) { + char buf[200]; + buf[0] = 0; + const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf); + if (add) { + *this << add << ' '; + } +} + +ErrnoException::~ErrnoException() throw() {} + +} // namespace util diff --git a/klm/util/exception.hh b/klm/util/exception.hh new file mode 100644 index 00000000..124689cf --- /dev/null +++ b/klm/util/exception.hh @@ -0,0 +1,72 @@ +#ifndef UTIL_EXCEPTION__ +#define UTIL_EXCEPTION__ + +#include "util/string_piece.hh" + +#include +#include +#include + +namespace util { + +class Exception : public std::exception { + public: + Exception() throw(); + virtual ~Exception() throw(); + + const char *what() const throw() { return what_.c_str(); } + + // This helps restrict operator<< defined below. + template struct ExceptionTag { + typedef T Identity; + }; + + std::string &Str() { + return what_; + } + + protected: + std::string what_; +}; + +/* This implements the normal operator<< for Exception and all its children. + * SFINAE means it only applies to Exception. Think of this as an ersatz + * boost::enable_if. + */ +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const Data &data) { + // Argh I had a stringstream in the exception, but the only way to get the string is by calling str(). But that's a temporary string, so virtual const char *what() const can't actually return it. 
+ std::stringstream stream; + stream << data; + e.Str() += stream.str(); + return e; +} +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const char *data) { + e.Str() += data; + return e; +} +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const std::string &data) { + e.Str() += data; + return e; +} +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const StringPiece &str) { + e.Str().append(str.data(), str.length()); + return e; +} + +#define UTIL_THROW(Exception, Modify) { Exception UTIL_e; {UTIL_e << Modify;} throw UTIL_e; } + +class ErrnoException : public Exception { + public: + ErrnoException() throw(); + + virtual ~ErrnoException() throw(); + + int Error() { return errno_; } + + private: + int errno_; +}; + +} // namespace util + +#endif // UTIL_EXCEPTION__ diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc new file mode 100644 index 00000000..2b439499 --- /dev/null +++ b/klm/util/file_piece.cc @@ -0,0 +1,224 @@ +#include "util/file_piece.hh" + +#include "util/exception.hh" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace util { + +EndOfFileException::EndOfFileException() throw() { + *this << "End of file"; +} +EndOfFileException::~EndOfFileException() throw() {} + +ParseNumberException::ParseNumberException(StringPiece value) throw() { + *this << "Could not parse \"" << value << "\" into a float"; +} + +int OpenReadOrThrow(const char *name) { + int ret = open(name, O_RDONLY); + if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading"); + return ret; +} + +off_t SizeFile(int fd) { + struct stat sb; + if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; + return sb.st_size; +} + +FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) : + file_(OpenReadOrThrow(name)), 
total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), + progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { + Initialize(name, show_progress, min_buffer); +} + +FilePiece::FilePiece(const char *name, int fd, std::ostream *show_progress, off_t min_buffer) : + file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), + progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { + Initialize(name, show_progress, min_buffer); +} + +void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) { + if (total_size_ == kBadSize) { + fallback_to_read_ = true; + if (show_progress) + *show_progress << "File " << name << " isn't normal. Using slower read() instead of mmap(). No progress bar." << std::endl; + } else { + fallback_to_read_ = false; + } + default_map_size_ = page_ * std::max((min_buffer / page_ + 1), 2); + position_ = NULL; + position_end_ = NULL; + mapped_offset_ = 0; + at_end_ = false; + Shift(); +} + +float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) { + SkipSpaces(); + while (last_space_ < position_) { + if (at_end_) { + // Hallucinate a null off the end of the file. 
+ std::string buffer(position_, position_end_); + char *end; + float ret = std::strtof(buffer.c_str(), &end); + if (buffer.c_str() == end) throw ParseNumberException(buffer); + position_ += end - buffer.c_str(); + return ret; + } + Shift(); + } + char *end; + float ret = std::strtof(position_, &end); + if (end == position_) throw ParseNumberException(ReadDelimited()); + position_ = end; + return ret; +} + +void FilePiece::SkipSpaces() throw (EndOfFileException) { + for (; ; ++position_) { + if (position_ == position_end_) Shift(); + if (!isspace(*position_)) return; + } +} + +const char *FilePiece::FindDelimiterOrEOF() throw (EndOfFileException) { + for (const char *i = position_; i <= last_space_; ++i) { + if (isspace(*i)) return i; + } + while (!at_end_) { + size_t skip = position_end_ - position_; + Shift(); + for (const char *i = position_ + skip; i <= last_space_; ++i) { + if (isspace(*i)) return i; + } + } + return position_end_; +} + +StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) { + const char *start = position_; + do { + for (const char *i = start; i < position_end_; ++i) { + if (*i == delim) { + StringPiece ret(position_, i - position_); + position_ = i + 1; + return ret; + } + } + size_t skip = position_end_ - position_; + Shift(); + start = position_ + skip; + } while (!at_end_); + StringPiece ret(position_, position_end_ - position_); + position_ = position_end_; + return position_; +} + +void FilePiece::Shift() throw(EndOfFileException) { + if (at_end_) throw EndOfFileException(); + off_t desired_begin = position_ - data_.begin() + mapped_offset_; + progress_.Set(desired_begin); + + if (!fallback_to_read_) MMapShift(desired_begin); + // Notice an mmap failure might set the fallback. 
+ if (fallback_to_read_) ReadShift(desired_begin); + + for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) { + if (isspace(*last_space_)) break; + } +} + +void FilePiece::MMapShift(off_t desired_begin) throw() { + // Use mmap. + off_t ignore = desired_begin % page_; + // Duplicate request for Shift means give more data. + if (position_ == data_.begin() + ignore) { + default_map_size_ *= 2; + } + // Local version so that in case of failure it doesn't overwrite the class variable. + off_t mapped_offset = desired_begin - ignore; + + off_t mapped_size; + if (default_map_size_ >= static_cast(total_size_ - mapped_offset)) { + at_end_ = true; + mapped_size = total_size_ - mapped_offset; + } else { + mapped_size = default_map_size_; + } + + // Forcibly clear the existing mmap first. + data_.reset(); + data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED); + if (data_.get() == MAP_FAILED) { + fallback_to_read_ = true; + if (desired_begin) { + if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either."); + } + return; + } + mapped_offset_ = mapped_offset; + position_ = data_.begin() + ignore; + position_end_ = data_.begin() + mapped_size; +} + +void FilePiece::ReadShift(off_t desired_begin) throw() { + assert(fallback_to_read_); + if (data_.source() != scoped_memory::MALLOC_ALLOCATED) { + // First call. + data_.reset(); + data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); + if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_); + position_ = data_.begin(); + position_end_ = position_; + } + + // Bytes [data_.begin(), position_) have been consumed. + // Bytes [position_, position_end_) have been read into the buffer. 
+ + // Start at the beginning of the buffer if there's nothing useful in it. + if (position_ == position_end_) { + mapped_offset_ += (position_end_ - data_.begin()); + position_ = data_.begin(); + position_end_ = position_; + } + + std::size_t already_read = position_end_ - data_.begin(); + + if (already_read == default_map_size_) { + if (position_ == data_.begin()) { + // Buffer too small. + std::size_t valid_length = position_end_ - position_; + default_map_size_ *= 2; + data_.call_realloc(default_map_size_); + if (!data_.get()) UTIL_THROW(ErrnoException, "realloc failed for " << default_map_size_); + position_ = data_.begin(); + position_end_ = position_ + valid_length; + } else { + size_t moving = position_end_ - position_; + memmove(data_.get(), position_, moving); + position_ = data_.begin(); + position_end_ = position_ + moving; + already_read = moving; + } + } + + ssize_t read_return = read(file_.get(), static_cast(data_.get()) + already_read, default_map_size_ - already_read); + if (read_return == -1) UTIL_THROW(ErrnoException, "read failed"); + if (read_return == 0) at_end_ = true; + position_end_ += read_return; +} + +} // namespace util diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh new file mode 100644 index 00000000..704f0ac6 --- /dev/null +++ b/klm/util/file_piece.hh @@ -0,0 +1,105 @@ +#ifndef UTIL_FILE_PIECE__ +#define UTIL_FILE_PIECE__ + +#include "util/ersatz_progress.hh" +#include "util/exception.hh" +#include "util/mmap.hh" +#include "util/scoped.hh" +#include "util/string_piece.hh" + +#include + +#include + +namespace util { + +class EndOfFileException : public Exception { + public: + EndOfFileException() throw(); + ~EndOfFileException() throw(); +}; + +class ParseNumberException : public Exception { + public: + explicit ParseNumberException(StringPiece value) throw(); + ~ParseNumberException() throw() {} +}; + +int OpenReadOrThrow(const char *name); + +// Return value for SizeFile when it can't size properly. 
+const off_t kBadSize = -1; +off_t SizeFile(int fd); + +class FilePiece { + public: + // 32 MB default. + explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); + // Takes ownership of fd. name is used for messages. + explicit FilePiece(const char *name, int fd, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); + + char get() throw(EndOfFileException) { + if (position_ == position_end_) Shift(); + return *(position_++); + } + + // Memory backing the returned StringPiece may vanish on the next call. + // Leaves the delimiter, if any, to be returned by get(). + StringPiece ReadDelimited() throw(EndOfFileException) { + SkipSpaces(); + return Consume(FindDelimiterOrEOF()); + } + // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. + // It is similar to getline in that way. + StringPiece ReadLine(char delim = '\n') throw(EndOfFileException); + + float ReadFloat() throw(EndOfFileException, ParseNumberException); + + void SkipSpaces() throw (EndOfFileException); + + off_t Offset() const { + return position_ - data_.begin() + mapped_offset_; + } + + // Only for testing. + void ForceFallbackToRead() { + fallback_to_read_ = true; + } + + private: + void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer); + + StringPiece Consume(const char *to) { + StringPiece ret(position_, to - position_); + position_ = to; + return ret; + } + + const char *FindDelimiterOrEOF() throw(EndOfFileException); + + void Shift() throw (EndOfFileException); + // Backends to Shift(). + void MMapShift(off_t desired_begin) throw (); + void ReadShift(off_t desired_begin) throw (); + + const char *position_, *last_space_, *position_end_; + + scoped_fd file_; + const off_t total_size_; + const off_t page_; + + size_t default_map_size_; + off_t mapped_offset_; + + // Order matters: file_ should always be destroyed after this. 
+ scoped_memory data_; + + bool at_end_; + bool fallback_to_read_; + + ErsatzProgress progress_; +}; + +} // namespace util + +#endif // UTIL_FILE_PIECE__ diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc new file mode 100644 index 00000000..befb7866 --- /dev/null +++ b/klm/util/file_piece_test.cc @@ -0,0 +1,41 @@ +#include "util/file_piece.hh" + +#define BOOST_TEST_MODULE FilePieceTest +#include +#include +#include + +namespace util { +namespace { + +/* mmap implementation */ +BOOST_AUTO_TEST_CASE(MMapLine) { + std::fstream ref("file_piece.cc", std::ios::in); + FilePiece test("file_piece.cc", NULL, 1); + std::string ref_line; + while (getline(ref, ref_line)) { + StringPiece test_line(test.ReadLine()); + // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924 + if (!test_line.empty() || !ref_line.empty()) { + BOOST_CHECK_EQUAL(ref_line, test_line); + } + } +} + +/* read() implementation */ +BOOST_AUTO_TEST_CASE(ReadLine) { + std::fstream ref("file_piece.cc", std::ios::in); + FilePiece test("file_piece.cc", NULL, 1); + test.ForceFallbackToRead(); + std::string ref_line; + while (getline(ref, ref_line)) { + StringPiece test_line(test.ReadLine()); + // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924 + if (!test_line.empty() || !ref_line.empty()) { + BOOST_CHECK_EQUAL(ref_line, test_line); + } + } +} + +} // namespace +} // namespace util diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh new file mode 100644 index 00000000..a2f1c01d --- /dev/null +++ b/klm/util/joint_sort.hh @@ -0,0 +1,145 @@ +#ifndef UTIL_JOINT_SORT__ +#define UTIL_JOINT_SORT__ + +/* A terrifying amount of C++ to coax std::sort into sorting one range while + * also permuting another range the same way. 
+ */ + +#include "util/proxy_iterator.hh" + +#include +#include +#include + +namespace util { + +namespace detail { + +template class JointProxy; + +template class JointIter { + public: + JointIter() {} + + JointIter(const KeyIter &key_iter, const ValueIter &value_iter) : key_(key_iter), value_(value_iter) {} + + bool operator==(const JointIter &other) const { return key_ == other.key_; } + + bool operator<(const JointIter &other) const { return (key_ < other.key_); } + + std::ptrdiff_t operator-(const JointIter &other) const { return key_ - other.key_; } + + JointIter &operator+=(std::ptrdiff_t amount) { + key_ += amount; + value_ += amount; + return *this; + } + + void swap(const JointIter &other) { + std::swap(key_, other.key_); + std::swap(value_, other.value_); + } + + private: + friend class JointProxy; + KeyIter key_; + ValueIter value_; +}; + +template class JointProxy { + private: + typedef JointIter InnerIterator; + + public: + typedef struct { + typename std::iterator_traits::value_type key; + typename std::iterator_traits::value_type value; + const typename std::iterator_traits::value_type &GetKey() const { return key; } + } value_type; + + JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {} + JointProxy(const JointProxy &other) : inner_(other.inner_) {} + + operator const value_type() const { + value_type ret; + ret.key = *inner_.key_; + ret.value = *inner_.value_; + return ret; + } + + JointProxy &operator=(const JointProxy &other) { + *inner_.key_ = *other.inner_.key_; + *inner_.value_ = *other.inner_.value_; + return *this; + } + + JointProxy &operator=(const value_type &other) { + *inner_.key_ = other.key; + *inner_.value_ = other.value; + return *this; + } + + typename std::iterator_traits::reference GetKey() const { + return *(inner_.key_); + } + + void swap(JointProxy &other) { + std::swap(*inner_.key_, *other.inner_.key_); + std::swap(*inner_.value_, *other.inner_.value_); + } + + private: + friend 
class ProxyIterator >; + + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +template class LessWrapper : public std::binary_function { + public: + explicit LessWrapper(const Less &less) : less_(less) {} + + bool operator()(const Proxy &left, const Proxy &right) const { + return less_(left.GetKey(), right.GetKey()); + } + bool operator()(const Proxy &left, const typename Proxy::value_type &right) const { + return less_(left.GetKey(), right.GetKey()); + } + bool operator()(const typename Proxy::value_type &left, const Proxy &right) const { + return less_(left.GetKey(), right.GetKey()); + } + bool operator()(const typename Proxy::value_type &left, const typename Proxy::value_type &right) const { + return less_(left.GetKey(), right.GetKey()); + } + + private: + const Less less_; +}; + +} // namespace detail + +template void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin, const Less &less) { + ProxyIterator > full_begin(detail::JointProxy(key_begin, value_begin)); + detail::LessWrapper, Less> less_wrap(less); + std::sort(full_begin, full_begin + (key_end - key_begin), less_wrap); +} + + +template void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin) { + JointSort(key_begin, key_end, value_begin, std::less::value_type>()); +} + +} // namespace util + +namespace std { +template void swap(util::detail::JointIter &left, util::detail::JointIter &right) { + left.swap(right); +} + +template void swap(util::detail::JointProxy &left, util::detail::JointProxy &right) { + left.swap(right); +} +} // namespace std + +#endif // UTIL_JOINT_SORT__ diff --git a/klm/util/joint_sort_test.cc b/klm/util/joint_sort_test.cc new file mode 100644 index 00000000..4dc85916 --- /dev/null +++ b/klm/util/joint_sort_test.cc @@ -0,0 +1,50 @@ +#include "util/joint_sort.hh" + +#define BOOST_TEST_MODULE JointSortTest +#include + +namespace util 
{ namespace { + +BOOST_AUTO_TEST_CASE(just_flip) { + char keys[2]; + int values[2]; + keys[0] = 1; values[0] = 327; + keys[1] = 0; values[1] = 87897; + JointSort(keys + 0, keys + 2, values + 0); + BOOST_CHECK_EQUAL(0, keys[0]); + BOOST_CHECK_EQUAL(87897, values[0]); + BOOST_CHECK_EQUAL(1, keys[1]); + BOOST_CHECK_EQUAL(327, values[1]); +} + +BOOST_AUTO_TEST_CASE(three) { + char keys[3]; + int values[3]; + keys[0] = 1; values[0] = 327; + keys[1] = 2; values[1] = 87897; + keys[2] = 0; values[2] = 10; + JointSort(keys + 0, keys + 3, values + 0); + BOOST_CHECK_EQUAL(0, keys[0]); + BOOST_CHECK_EQUAL(1, keys[1]); + BOOST_CHECK_EQUAL(2, keys[2]); +} + +BOOST_AUTO_TEST_CASE(char_int) { + char keys[4]; + int values[4]; + keys[0] = 3; values[0] = 327; + keys[1] = 1; values[1] = 87897; + keys[2] = 2; values[2] = 10; + keys[3] = 0; values[3] = 24347; + JointSort(keys + 0, keys + 4, values + 0); + BOOST_CHECK_EQUAL(0, keys[0]); + BOOST_CHECK_EQUAL(24347, values[0]); + BOOST_CHECK_EQUAL(1, keys[1]); + BOOST_CHECK_EQUAL(87897, values[1]); + BOOST_CHECK_EQUAL(2, keys[2]); + BOOST_CHECK_EQUAL(10, values[2]); + BOOST_CHECK_EQUAL(3, keys[3]); + BOOST_CHECK_EQUAL(327, values[3]); +} + +}} // namespace anonymous util diff --git a/klm/util/key_value_packing.hh b/klm/util/key_value_packing.hh new file mode 100644 index 00000000..450512ac --- /dev/null +++ b/klm/util/key_value_packing.hh @@ -0,0 +1,122 @@ +#ifndef UTIL_KEY_VALUE_PACKING__ +#define UTIL_KEY_VALUE_PACKING__ + +/* Why such a general interface? I'm planning on doing bit-level packing. 
*/ + +#include +#include +#include + +#include + +namespace util { + +template struct Entry { + Key key; + Value value; + + const Key &GetKey() const { return key; } + const Value &GetValue() const { return value; } + + void Set(const Key &key_in, const Value &value_in) { + SetKey(key_in); + SetValue(value_in); + } + void SetKey(const Key &key_in) { key = key_in; } + void SetValue(const Value &value_in) { value = value_in; } + + bool operator<(const Entry &other) const { return GetKey() < other.GetKey(); } +}; + +// And now for a brief interlude to specialize std::swap. +} // namespace util +namespace std { +template void swap(util::Entry &first, util::Entry &second) { + swap(first.key, second.key); + swap(first.value, second.value); +} +}// namespace std +namespace util { + +template class AlignedPacking { + public: + typedef KeyT Key; + typedef ValueT Value; + + public: + static const std::size_t kBytes = sizeof(Entry); + static const std::size_t kBits = kBytes * 8; + + typedef Entry * MutableIterator; + typedef const Entry * ConstIterator; + typedef const Entry & ConstReference; + + static MutableIterator FromVoid(void *start) { + return reinterpret_cast(start); + } + + static Entry Make(const Key &key, const Value &value) { + Entry ret; + ret.Set(key, value); + return ret; + } +}; + +template class ByteAlignedPacking { + public: + typedef KeyT Key; + typedef ValueT Value; + + private: +#pragma pack(push) +#pragma pack(1) + struct RawEntry { + Key key; + Value value; + + const Key &GetKey() const { return key; } + const Value &GetValue() const { return value; } + + void Set(const Key &key_in, const Value &value_in) { + SetKey(key_in); + SetValue(value_in); + } + void SetKey(const Key &key_in) { key = key_in; } + void SetValue(const Value &value_in) { value = value_in; } + + bool operator<(const RawEntry &other) const { return GetKey() < other.GetKey(); } + }; +#pragma pack(pop) + + friend void std::swap<>(RawEntry&, RawEntry&); + + public: + typedef RawEntry 
*MutableIterator; + typedef const RawEntry *ConstIterator; + typedef RawEntry &ConstReference; + + static const std::size_t kBytes = sizeof(RawEntry); + static const std::size_t kBits = kBytes * 8; + + static MutableIterator FromVoid(void *start) { + return MutableIterator(reinterpret_cast(start)); + } + + static RawEntry Make(const Key &key, const Value &value) { + RawEntry ret; + ret.Set(key, value); + return ret; + } +}; + +} // namespace util +namespace std { +template void swap( + typename util::ByteAlignedPacking::RawEntry &first, + typename util::ByteAlignedPacking::RawEntry &second) { + swap(first.key, second.key); + swap(first.value, second.value); +} +}// namespace std + +#endif // UTIL_KEY_VALUE_PACKING__ diff --git a/klm/util/key_value_packing_test.cc b/klm/util/key_value_packing_test.cc new file mode 100644 index 00000000..a0d33fd7 --- /dev/null +++ b/klm/util/key_value_packing_test.cc @@ -0,0 +1,75 @@ +#include "util/key_value_packing.hh" + +#include +#include +#include +#include +#define BOOST_TEST_MODULE KeyValueStoreTest +#include + +#include +#include + +namespace util { +namespace { + +BOOST_AUTO_TEST_CASE(basic_in_out) { + typedef ByteAlignedPacking Packing; + void *backing = malloc(Packing::kBytes * 2); + Packing::MutableIterator i(Packing::FromVoid(backing)); + i->SetKey(10); + BOOST_CHECK_EQUAL(10, i->GetKey()); + i->SetValue(3); + BOOST_CHECK_EQUAL(3, i->GetValue()); + ++i; + i->SetKey(5); + BOOST_CHECK_EQUAL(5, i->GetKey()); + i->SetValue(42); + BOOST_CHECK_EQUAL(42, i->GetValue()); + + Packing::ConstIterator c(i); + BOOST_CHECK_EQUAL(5, c->GetKey()); + --c; + BOOST_CHECK_EQUAL(10, c->GetKey()); + BOOST_CHECK_EQUAL(42, i->GetValue()); + + BOOST_CHECK_EQUAL(5, i->GetKey()); + free(backing); +} + +BOOST_AUTO_TEST_CASE(simple_sort) { + typedef ByteAlignedPacking Packing; + char foo[Packing::kBytes * 4]; + Packing::MutableIterator begin(Packing::FromVoid(foo)); + Packing::MutableIterator i = begin; + i->SetKey(0); ++i; + i->SetKey(2); ++i; + 
i->SetKey(3); ++i; + i->SetKey(1); ++i; + std::sort(begin, i); + BOOST_CHECK_EQUAL(0, begin[0].GetKey()); + BOOST_CHECK_EQUAL(1, begin[1].GetKey()); + BOOST_CHECK_EQUAL(2, begin[2].GetKey()); + BOOST_CHECK_EQUAL(3, begin[3].GetKey()); +} + +BOOST_AUTO_TEST_CASE(big_sort) { + typedef ByteAlignedPacking Packing; + boost::scoped_array memory(new char[Packing::kBytes * 1000]); + Packing::MutableIterator begin(Packing::FromVoid(memory.get())); + + boost::mt19937 rng; + boost::uniform_int range(0, std::numeric_limits::max()); + boost::variate_generator > gen(rng, range); + + for (size_t i = 0; i < 1000; ++i) { + (begin + i)->SetKey(gen()); + } + std::sort(begin, begin + 1000); + for (size_t i = 0; i < 999; ++i) { + BOOST_CHECK(begin[i] < begin[i+1]); + } +} + +} // namespace +} // namespace util diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc new file mode 100644 index 00000000..648b5d0a --- /dev/null +++ b/klm/util/mmap.cc @@ -0,0 +1,95 @@ +#include "util/exception.hh" +#include "util/mmap.hh" +#include "util/scoped.hh" + +#include +#include +#include +#include +#include +#include +#include + +namespace util { + +scoped_mmap::~scoped_mmap() { + if (data_ != (void*)-1) { + if (munmap(data_, size_)) + err(1, "munmap failed "); + } +} + +void scoped_memory::reset(void *data, std::size_t size, Alloc source) { + switch(source_) { + case MMAP_ALLOCATED: + scoped_mmap(data_, size_); + break; + case ARRAY_ALLOCATED: + delete [] reinterpret_cast(data_); + break; + case MALLOC_ALLOCATED: + free(data_); + break; + case NONE_ALLOCATED: + break; + } + data_ = data; + size_ = size; + source_ = source; +} + +void scoped_memory::call_realloc(std::size_t size) { + assert(source_ == MALLOC_ALLOCATED || source_ == NONE_ALLOCATED); + void *new_data = realloc(data_, size); + if (!new_data) { + reset(); + } else { + reset(new_data, size, MALLOC_ALLOCATED); + } +} + +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset) { +#ifdef MAP_POPULATE // 
Linux specific + if (prefault) { + flags |= MAP_POPULATE; + } + int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; +#else + int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; +#endif + void *ret = mmap(NULL, size, protect, flags, fd, offset); + if (ret == MAP_FAILED) { + UTIL_THROW(ErrnoException, "mmap failed for size " << size << " at offset " << offset); + } + return ret; +} + +void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset) { + return MapOrThrow(size, false, MAP_FILE | MAP_PRIVATE, prefault, fd, offset); +} + +void *MapAnonymous(std::size_t size) { + return MapOrThrow(size, true, +#ifdef MAP_ANONYMOUS + MAP_ANONYMOUS // Linux +#else + MAP_ANON // BSD +#endif + | MAP_PRIVATE, false, -1, 0); +} + +void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem) { + file.reset(open(name, O_CREAT | O_RDWR | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)); + if (-1 == file.get()) + UTIL_THROW(ErrnoException, "Failed to open " << name << " for writing"); + if (-1 == ftruncate(file.get(), size)) + UTIL_THROW(ErrnoException, "ftruncate on " << name << " to " << size << " failed"); + try { + mem.reset(MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0), size); + } catch (ErrnoException &e) { + e << " in file " << name; + throw; + } +} + +} // namespace util diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh new file mode 100644 index 00000000..c9068ec9 --- /dev/null +++ b/klm/util/mmap.hh @@ -0,0 +1,101 @@ +#ifndef UTIL_MMAP__ +#define UTIL_MMAP__ +// Utilities for mmaped files. + +#include "util/scoped.hh" + +#include + +#include + +namespace util { + +// (void*)-1 is MAP_FAILED; this is done to avoid including the mmap header here. 
+class scoped_mmap { + public: + scoped_mmap() : data_((void*)-1), size_(0) {} + scoped_mmap(void *data, std::size_t size) : data_(data), size_(size) {} + ~scoped_mmap(); + + void *get() const { return data_; } + + const char *begin() const { return reinterpret_cast(data_); } + const char *end() const { return reinterpret_cast(data_) + size_; } + std::size_t size() const { return size_; } + + void reset(void *data, std::size_t size) { + scoped_mmap other(data_, size_); + data_ = data; + size_ = size; + } + + void reset() { + reset((void*)-1, 0); + } + + private: + void *data_; + std::size_t size_; + + scoped_mmap(const scoped_mmap &); + scoped_mmap &operator=(const scoped_mmap &); +}; + +/* For when the memory might come from mmap, new char[], or malloc. Uses NULL + * and 0 for blanks (even though mmap signals errors with (void*)-1); the blank + * state is tagged NONE_ALLOCATED, for which reset releases nothing. + */ +class scoped_memory { + public: + typedef enum {MMAP_ALLOCATED, ARRAY_ALLOCATED, MALLOC_ALLOCATED, NONE_ALLOCATED} Alloc; + + scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {} + + ~scoped_memory() { reset(); } + + void *get() const { return data_; } + const char *begin() const { return reinterpret_cast(data_); } + const char *end() const { return reinterpret_cast(data_) + size_; } + std::size_t size() const { return size_; } + + Alloc source() const { return source_; } + + void reset() { reset(NULL, 0, NONE_ALLOCATED); } + + void reset(void *data, std::size_t size, Alloc from); + + // realloc allows the current data to escape hence the need for this call + // If realloc fails, destroys the original too and get() returns NULL. 
+ void call_realloc(std::size_t to); + + private: + + void *data_; + std::size_t size_; + + Alloc source_; + + scoped_memory(const scoped_memory &); + scoped_memory &operator=(const scoped_memory &); +}; + +struct scoped_mapped_file { + scoped_fd fd; + scoped_mmap mem; +}; + +// Wrapper around mmap to check it worked and hide some platform macros. +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset = 0); +void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset = 0); + +void *MapAnonymous(std::size_t size); + +// Open file name with mmap of size bytes, all of which are initially zero. +void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem); +inline void MapZeroedWrite(const char *name, std::size_t size, scoped_mapped_file &out) { + MapZeroedWrite(name, size, out.fd, out.mem); +} + +} // namespace util + +#endif // UTIL_MMAP__ diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc new file mode 100644 index 00000000..d58a0727 --- /dev/null +++ b/klm/util/murmur_hash.cc @@ -0,0 +1,129 @@ +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to std::size_t. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. 
+ +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return 
MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util diff --git a/klm/util/murmur_hash.hh b/klm/util/murmur_hash.hh new file mode 100644 index 00000000..638aaeb2 --- /dev/null +++ b/klm/util/murmur_hash.hh @@ -0,0 +1,14 @@ +#ifndef UTIL_MURMUR_HASH__ +#define UTIL_MURMUR_HASH__ +#include +#include + +namespace util { + +uint64_t MurmurHash64A(const void * key, std::size_t len, unsigned int seed = 0); +uint64_t MurmurHash64B(const void * key, std::size_t len, unsigned int seed = 0); +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed = 0); + +} // namespace util + +#endif // UTIL_MURMUR_HASH__ diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh new file mode 100644 index 00000000..c3529a7e --- /dev/null +++ b/klm/util/probing_hash_table.hh @@ -0,0 +1,97 @@ +#ifndef UTIL_PROBING_HASH_TABLE__ +#define UTIL_PROBING_HASH_TABLE__ + +#include +#include +#include + +#include + +namespace util { + +/* Non-standard hash table + * Buckets must be set at the beginning and must be greater than maximum number + * of elements, else an infinite loop happens. + * Memory management and initialization is externalized to make it easier to + * serialize these to disk and load them quickly. + * Uses linear probing to find value. + * Only insert and lookup operations. + */ + +template > class ProbingHashTable { + public: + typedef PackingT Packing; + typedef typename Packing::Key Key; + typedef typename Packing::MutableIterator MutableIterator; + typedef typename Packing::ConstIterator ConstIterator; + + typedef HashT Hash; + typedef EqualT Equal; + + static std::size_t Size(std::size_t entries, float multiplier) { + return std::max(entries + 1, static_cast(multiplier * static_cast(entries))) * Packing::kBytes; + } + + // Must be assigned to later. 
+ ProbingHashTable() +#ifdef DEBUG + : initialized_(false), entries_(0) +#endif + {} + + ProbingHashTable(void *start, std::size_t allocated, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) + : begin_(Packing::FromVoid(start)), + buckets_(allocated / Packing::kBytes), + end_(begin_ + (allocated / Packing::kBytes)), + invalid_(invalid), + hash_(hash_func), + equal_(equal_func) +#ifdef DEBUG + , initialized_(true), + entries_(0) +#endif + {} + + template void Insert(const T &t) { +#ifdef DEBUG + assert(initialized_); + assert(++entries_ < buckets_); +#endif + for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) { + if (equal_(i->GetKey(), invalid_)) { *i = t; return; } + if (++i == end_) { i = begin_; } + } + } + + void FinishedInserting() {} + + void LoadedBinary() {} + + template bool Find(const Key key, ConstIterator &out) const { +#ifdef DEBUG + assert(initialized_); +#endif + for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) { + Key got(i->GetKey()); + if (equal_(got, key)) { out = i; return true; } + if (equal_(got, invalid_)) { return false; } + if (++i == end_) { i = begin_; } + } + } + + private: + MutableIterator begin_; + std::size_t buckets_; + MutableIterator end_; + Key invalid_; + Hash hash_; + Equal equal_; +#ifdef DEBUG + bool initialized_; + std::size_t entries_; +#endif +}; + +} // namespace util + +#endif // UTIL_PROBING_HASH_TABLE__ diff --git a/klm/util/probing_hash_table_test.cc b/klm/util/probing_hash_table_test.cc new file mode 100644 index 00000000..ff2f5af3 --- /dev/null +++ b/klm/util/probing_hash_table_test.cc @@ -0,0 +1,30 @@ +#include "util/probing_hash_table.hh" + +#include "util/key_value_packing.hh" + +#define BOOST_TEST_MODULE ProbingHashTableTest +#include +#include + +namespace util { +namespace { + +typedef AlignedPacking Packing; +typedef ProbingHashTable > Table; + +BOOST_AUTO_TEST_CASE(simple) { + char mem[Table::Size(10, 1.2)]; + memset(mem, 0, sizeof(mem)); 
+ + Table table(mem, sizeof(mem)); + Packing::ConstIterator i = Packing::ConstIterator(); + BOOST_CHECK(!table.Find(2, i)); + table.Insert(Packing::Make(3, 328920)); + BOOST_REQUIRE(table.Find(3, i)); + BOOST_CHECK_EQUAL(3, i->GetKey()); + BOOST_CHECK_EQUAL(static_cast(328920), i->GetValue()); + BOOST_CHECK(!table.Find(2, i)); +} + +} // namespace +} // namespace util diff --git a/klm/util/proxy_iterator.hh b/klm/util/proxy_iterator.hh new file mode 100644 index 00000000..1c5b7089 --- /dev/null +++ b/klm/util/proxy_iterator.hh @@ -0,0 +1,94 @@ +#ifndef UTIL_PROXY_ITERATOR__ +#define UTIL_PROXY_ITERATOR__ + +#include +#include + +/* This is a RandomAccessIterator that uses a proxy to access the underlying + * data. Useful for packing data at bit offsets but still using STL + * algorithms. + * + * Normally I would use boost::iterator_facade but some people are too lazy to + * install boost and still want to use my language model. It's amazing how + * many operators an iterator has. + * + * The Proxy needs to provide: + * class InnerIterator; + * InnerIterator &Inner(); + * const InnerIterator &Inner() const; + * + * InnerIterator has to implement: + * operator==(InnerIterator) + * operator<(InnerIterator) + * operator+=(std::ptrdiff_t) + * operator-(InnerIterator) + * and of course whatever Proxy needs to dereference it. + * + * It's also a good idea to specialize std::swap for Proxy. + */ + +namespace util { +template class ProxyIterator { + private: + // Self. + typedef ProxyIterator S; + typedef typename Proxy::InnerIterator InnerIterator; + + public: + typedef std::random_access_iterator_tag iterator_category; + typedef typename Proxy::value_type value_type; + typedef std::ptrdiff_t difference_type; + typedef Proxy reference; + typedef Proxy * pointer; + + ProxyIterator() {} + + // For cast from non const to const. 
+ template ProxyIterator(const ProxyIterator &in) : p_(*in) {} + explicit ProxyIterator(const Proxy &p) : p_(p) {} + + // p_'s operator= does value copying, but here we want iterator copying. + S &operator=(const S &other) { + I() = other.I(); + return *this; + } + + bool operator==(const S &other) const { return I() == other.I(); } + bool operator!=(const S &other) const { return !(*this == other); } + bool operator<(const S &other) const { return I() < other.I(); } + bool operator>(const S &other) const { return other < *this; } + bool operator<=(const S &other) const { return !(*this > other); } + bool operator>=(const S &other) const { return !(*this < other); } + + S &operator++() { return *this += 1; } + S operator++(int) { S ret(*this); ++*this; return ret; } + S &operator+=(std::ptrdiff_t amount) { I() += amount; return *this; } + S operator+(std::ptrdiff_t amount) const { S ret(*this); ret += amount; return ret; } + + S &operator--() { return *this -= 1; } + S operator--(int) { S ret(*this); --*this; return ret; } + S &operator-=(std::ptrdiff_t amount) { I() += (-amount); return *this; } + S operator-(std::ptrdiff_t amount) const { S ret(*this); ret -= amount; return ret; } + + std::ptrdiff_t operator-(const S &other) const { return I() - other.I(); } + + Proxy operator*() { return p_; } + const Proxy operator*() const { return p_; } + Proxy *operator->() { return &p_; } + const Proxy *operator->() const { return &p_; } + Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); } + + private: + InnerIterator &I() { return p_.Inner(); } + const InnerIterator &I() const { return p_.Inner(); } + + Proxy p_; +}; + +template ProxyIterator operator+(std::ptrdiff_t amount, const ProxyIterator &it) { + return it + amount; +} + +} // namespace util + +#endif // UTIL_PROXY_ITERATOR__ diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc new file mode 100644 index 00000000..61394ffc --- /dev/null +++ b/klm/util/scoped.cc @@ -0,0 +1,12 @@ +#include 
"util/scoped.hh" + +#include +#include + +namespace util { + +scoped_fd::~scoped_fd() { + if (fd_ != -1 && close(fd_)) err(1, "Could not close file %i", fd_); +} + +} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh new file mode 100644 index 00000000..ef62a74f --- /dev/null +++ b/klm/util/scoped.hh @@ -0,0 +1,66 @@ +#ifndef UTIL_SCOPED__ +#define UTIL_SCOPED__ + +/* Other scoped objects in the style of scoped_ptr. */ + +#include + +namespace util { + +template class scoped_thing { + public: + explicit scoped_thing(T *c = static_cast(0)) : c_(c) {} + + ~scoped_thing() { if (c_) Free(c_); } + + void reset(T *c) { + if (c_) Free(c_); + c_ = c; + } + + T &operator*() { return *c_; } + T *operator->() { return c_; } // operator-> must yield a pointer so that p->member resolves on the held object + + T *get() { return c_; } + const T *get() const { return c_; } + + private: + T *c_; + + scoped_thing(const scoped_thing &); + scoped_thing &operator=(const scoped_thing &); +}; + +class scoped_fd { + public: + scoped_fd() : fd_(-1) {} + + explicit scoped_fd(int fd) : fd_(fd) {} + + ~scoped_fd(); + + void reset(int to) { + scoped_fd other(fd_); + fd_ = to; + } + + int get() const { return fd_; } + + int operator*() const { return fd_; } + + int release() { + int ret = fd_; + fd_ = -1; + return ret; + } + + private: + int fd_; + + scoped_fd(const scoped_fd &); + scoped_fd &operator=(const scoped_fd &); +}; + +} // namespace util + +#endif // UTIL_SCOPED__ diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh new file mode 100644 index 00000000..96ec4866 --- /dev/null +++ b/klm/util/sorted_uniform.hh @@ -0,0 +1,139 @@ +#ifndef UTIL_SORTED_UNIFORM__ +#define UTIL_SORTED_UNIFORM__ + +#include +#include + +#include +#include + +namespace util { + +inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) { + std::size_t ret = static_cast(static_cast(off) / static_cast(range) * static_cast(width)); + // Cap for floating point rounding + return (ret < width) ? 
ret : width - 1; +} +/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) { + return static_cast(static_cast(off) * static_cast(width) / static_cast(range)); +} +inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) { + return static_cast(static_cast(off) * width / static_cast(range)); +} +inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) { + return static_cast(static_cast(off) * width / static_cast(range)); +}*/ + +template bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) { + if (begin == end) return false; + Key below(begin->GetKey()); + if (key <= below) { + if (key == below) { out = begin; return true; } + return false; + } + // Make the range [begin, end]. + --end; + Key above(end->GetKey()); + if (key >= above) { + if (key == above) { out = end; return true; } + return false; + } + + // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above. + while (end - begin > 1) { + Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast(end - begin - 1)))); + Key mid(pivot->GetKey()); + if (mid < key) { + begin = pivot; + below = mid; + } else if (mid > key) { + end = pivot; + above = mid; + } else { + out = pivot; + return true; + } + } + return false; +} + +// To use this template, you need to define a Pivot function to match Key. +template class SortedUniformMap { + public: + typedef PackingT Packing; + typedef typename Packing::ConstIterator ConstIterator; + + public: + // Offer consistent API with probing hash. 
+ static std::size_t Size(std::size_t entries, float ignore = 0.0) { + return sizeof(uint64_t) + entries * Packing::kBytes; + } + + SortedUniformMap() +#ifdef DEBUG + : initialized_(false), loaded_(false) +#endif + {} + + SortedUniformMap(void *start, std::size_t allocated) : + begin_(Packing::FromVoid(reinterpret_cast(start) + 1)), + end_(begin_), size_ptr_(reinterpret_cast(start)) +#ifdef DEBUG + , initialized_(true), loaded_(false) +#endif + {} + + void LoadedBinary() { +#ifdef DEBUG + assert(initialized_); + assert(!loaded_); + loaded_ = true; +#endif + // Restore the size. + end_ = begin_ + *size_ptr_; + } + + // Caller responsible for not exceeding specified size. Do not call after FinishedInserting. + template void Insert(const T &t) { +#ifdef DEBUG + assert(initialized_); + assert(!loaded_); +#endif + *end_ = t; + ++end_; + } + + void FinishedInserting() { +#ifdef DEBUG + assert(initialized_); + assert(!loaded_); + loaded_ = true; +#endif + std::sort(begin_, end_); + *size_ptr_ = (end_ - begin_); + } + + // Do not call before FinishedInserting. 
+ template bool Find(const Key key, ConstIterator &out) const { +#ifdef DEBUG + assert(initialized_); + assert(loaded_); +#endif + return SortedUniformFind(ConstIterator(begin_), ConstIterator(end_), key, out); + } + + ConstIterator begin() const { return begin_; } + ConstIterator end() const { return end_; } + + private: + typename Packing::MutableIterator begin_, end_; + uint64_t *size_ptr_; +#ifdef DEBUG + bool initialized_; + bool loaded_; +#endif +}; + +} // namespace util + +#endif // UTIL_SORTED_UNIFORM__ diff --git a/klm/util/sorted_uniform_test.cc b/klm/util/sorted_uniform_test.cc new file mode 100644 index 00000000..4aa4c8aa --- /dev/null +++ b/klm/util/sorted_uniform_test.cc @@ -0,0 +1,116 @@ +#include "util/sorted_uniform.hh" + +#include "util/key_value_packing.hh" + +#include +#include +#include +#include +#include +#define BOOST_TEST_MODULE SortedUniformTest +#include + +#include +#include +#include + +namespace util { +namespace { + +template void Check(const Map &map, const boost::unordered_map &reference, const Key key) { + typename boost::unordered_map::const_iterator ref = reference.find(key); + typename Map::ConstIterator i = typename Map::ConstIterator(); + if (ref == reference.end()) { + BOOST_CHECK(!map.Find(key, i)); + } else { + // g++ can't tell that require will crash and burn. 
+ BOOST_REQUIRE(map.Find(key, i)); + BOOST_CHECK_EQUAL(ref->second, i->GetValue()); + } +} + +typedef SortedUniformMap > TestMap; + +BOOST_AUTO_TEST_CASE(empty) { + char buf[TestMap::Size(0)]; + TestMap map(buf, TestMap::Size(0)); + map.FinishedInserting(); + TestMap::ConstIterator i; + BOOST_CHECK(!map.Find(42, i)); +} + +BOOST_AUTO_TEST_CASE(one) { + char buf[TestMap::Size(1)]; + TestMap map(buf, sizeof(buf)); + Entry e; + e.Set(42,2); + map.Insert(e); + map.FinishedInserting(); + TestMap::ConstIterator i = TestMap::ConstIterator(); + BOOST_REQUIRE(map.Find(42, i)); + BOOST_CHECK(i == map.begin()); + BOOST_CHECK(!map.Find(43, i)); + BOOST_CHECK(!map.Find(41, i)); +} + +template void RandomTest(Key upper, size_t entries, size_t queries) { + typedef unsigned char Value; + typedef SortedUniformMap > Map; + boost::scoped_array buffer(new char[Map::Size(entries)]); + Map map(buffer.get(), entries); + boost::mt19937 rng; + boost::uniform_int range_key(0, upper); + boost::uniform_int range_value(0, 255); + boost::variate_generator > gen_key(rng, range_key); + boost::variate_generator > gen_value(rng, range_value); + + boost::unordered_map reference; + Entry ent; + for (size_t i = 0; i < entries; ++i) { + Key key = gen_key(); + unsigned char value = gen_value(); + if (reference.insert(std::make_pair(key, value)).second) { + ent.Set(key, value); + map.Insert(Entry(ent)); + } + } + map.FinishedInserting(); + + // Random queries. 
+ for (size_t i = 0; i < queries; ++i) { + const Key key = gen_key(); + Check(map, reference, key); + } + + typename boost::unordered_map::const_iterator it = reference.begin(); + for (size_t i = 0; (i < queries) && (it != reference.end()); ++i, ++it) { + Check(map, reference, it->first); // walk keys known to be in the reference map (the stored value is not a key) + } +} + +BOOST_AUTO_TEST_CASE(basic) { + RandomTest(11, 10, 200); +} + +BOOST_AUTO_TEST_CASE(tiny_dense_random) { + RandomTest(11, 50, 200); +} + +BOOST_AUTO_TEST_CASE(small_dense_random) { + RandomTest(100, 100, 200); +} + +BOOST_AUTO_TEST_CASE(small_sparse_random) { + RandomTest(200, 15, 200); +} + +BOOST_AUTO_TEST_CASE(medium_sparse_random) { + RandomTest(32000, 1000, 2000); +} + +BOOST_AUTO_TEST_CASE(sparse_random) { + RandomTest(std::numeric_limits::max(), 100000, 2000); +} + +} // namespace +} // namespace util diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc new file mode 100644 index 00000000..6917a6bc --- /dev/null +++ b/klm/util/string_piece.cc @@ -0,0 +1,57 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// Copied from strings/stringpiece.cc with modifications + +#include "util/string_piece.hh" + +#ifdef USE_BOOST +#include +#endif + +#include +#include + +#ifdef USE_ICU +U_NAMESPACE_BEGIN +#endif + +std::ostream& operator<<(std::ostream& o, const StringPiece& piece) { + o.write(piece.data(), static_cast(piece.size())); + return o; +} + +#ifdef USE_BOOST +size_t hash_value(const StringPiece &str) { + return boost::hash_range(str.data(), str.data() + str.length()); +} +#endif + +#ifdef USE_ICU +U_NAMESPACE_END +#endif diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh new file mode 100644 index 00000000..58008d13 --- /dev/null +++ b/klm/util/string_piece.hh @@ -0,0 +1,260 @@ +/* If you use ICU in your program, then compile with -DUSE_ICU -licui18n. If + * you don't use ICU, then this will use the Google implementation from Chrome. + * This has been modified from the original version to let you choose. + */ + +// Copyright 2008, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// Copied from strings/stringpiece.h with modifications +// +// A string-like object that points to a sized piece of memory. +// +// Functions or methods may use const StringPiece& parameters to accept either +// a "const char*" or a "string" value that will be implicitly converted to +// a StringPiece. 
The implicit conversion means that it is often appropriate +// to include this .h file in other files rather than forward-declaring +// StringPiece as would be appropriate for most other Google classes. +// +// Systematic usage of StringPiece is encouraged as it will reduce unnecessary +// conversions from "const char*" to "string" and back again. +// + +#ifndef BASE_STRING_PIECE_H__ +#define BASE_STRING_PIECE_H__ + +//Uncomment this line if you use ICU in your code. +//#define USE_ICU +//Uncomment this line if you want boost hashing for your StringPieces. +//#define USE_BOOST + +#include +#include + +#ifdef USE_ICU +#include +U_NAMESPACE_BEGIN +#else + +#include +#include +#include +#include + +class StringPiece { + public: + typedef size_t size_type; + + private: + const char* ptr_; + size_type length_; + + public: + // We provide non-explicit singleton constructors so users can pass + // in a "const char*" or a "string" wherever a "StringPiece" is + // expected. + StringPiece() : ptr_(NULL), length_(0) { } + StringPiece(const char* str) + : ptr_(str), length_((str == NULL) ? 0 : strlen(str)) { } + StringPiece(const std::string& str) + : ptr_(str.data()), length_(str.size()) { } + StringPiece(const char* offset, size_type len) + : ptr_(offset), length_(len) { } + + // data() may return a pointer to a buffer with embedded NULs, and the + // returned buffer may or may not be null terminated. Therefore it is + // typically a mistake to pass data() to a routine that expects a NUL + // terminated string. + const char* data() const { return ptr_; } + size_type size() const { return length_; } + size_type length() const { return length_; } + bool empty() const { return length_ == 0; } + + void clear() { ptr_ = NULL; length_ = 0; } + void set(const char* data, size_type len) { ptr_ = data; length_ = len; } + void set(const char* str) { + ptr_ = str; + length_ = str ? 
strlen(str) : 0; + } + void set(const void* data, size_type len) { + ptr_ = reinterpret_cast(data); + length_ = len; + } + + char operator[](size_type i) const { return ptr_[i]; } + + void remove_prefix(size_type n) { + ptr_ += n; + length_ -= n; + } + + void remove_suffix(size_type n) { + length_ -= n; + } + + int compare(const StringPiece& x) const { + int r = wordmemcmp(ptr_, x.ptr_, std::min(length_, x.length_)); + if (r == 0) { + if (length_ < x.length_) r = -1; + else if (length_ > x.length_) r = +1; + } + return r; + } + + std::string as_string() const { + // std::string doesn't like to take a NULL pointer even with a 0 size. + return std::string(!empty() ? data() : "", size()); + } + + void CopyToString(std::string* target) const; + void AppendToString(std::string* target) const; + + // Does "this" start with "x" + bool starts_with(const StringPiece& x) const { + return ((length_ >= x.length_) && + (wordmemcmp(ptr_, x.ptr_, x.length_) == 0)); + } + + // Does "this" end with "x" + bool ends_with(const StringPiece& x) const { + return ((length_ >= x.length_) && + (wordmemcmp(ptr_ + (length_-x.length_), x.ptr_, x.length_) == 0)); + } + + // standard STL container boilerplate + typedef char value_type; + typedef const char* pointer; + typedef const char& reference; + typedef const char& const_reference; + typedef ptrdiff_t difference_type; + static const size_type npos; + typedef const char* const_iterator; + typedef const char* iterator; + typedef std::reverse_iterator const_reverse_iterator; + typedef std::reverse_iterator reverse_iterator; + iterator begin() const { return ptr_; } + iterator end() const { return ptr_ + length_; } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(ptr_ + length_); + } + const_reverse_iterator rend() const { + return const_reverse_iterator(ptr_); + } + + size_type max_size() const { return length_; } + size_type capacity() const { return length_; } + + size_type copy(char* buf, size_type n, size_type pos 
= 0) const; + + size_type find(const StringPiece& s, size_type pos = 0) const; + size_type find(char c, size_type pos = 0) const; + size_type rfind(const StringPiece& s, size_type pos = npos) const; + size_type rfind(char c, size_type pos = npos) const; + + size_type find_first_of(const StringPiece& s, size_type pos = 0) const; + size_type find_first_of(char c, size_type pos = 0) const { + return find(c, pos); + } + size_type find_first_not_of(const StringPiece& s, size_type pos = 0) const; + size_type find_first_not_of(char c, size_type pos = 0) const; + size_type find_last_of(const StringPiece& s, size_type pos = npos) const; + size_type find_last_of(char c, size_type pos = npos) const { + return rfind(c, pos); + } + size_type find_last_not_of(const StringPiece& s, size_type pos = npos) const; + size_type find_last_not_of(char c, size_type pos = npos) const; + + StringPiece substr(size_type pos, size_type n = npos) const; + + static int wordmemcmp(const char* p, const char* p2, size_type N) { + return memcmp(p, p2, N); + } +}; + +inline bool operator==(const StringPiece& x, const StringPiece& y) { + if (x.size() != y.size()) + return false; + + return std::memcmp(x.data(), y.data(), x.size()) == 0; +} + +inline bool operator!=(const StringPiece& x, const StringPiece& y) { + return !(x == y); +} + +#endif + +inline bool operator<(const StringPiece& x, const StringPiece& y) { + const int r = std::memcmp(x.data(), y.data(), + std::min(x.size(), y.size())); + return ((r < 0) || ((r == 0) && (x.size() < y.size()))); +} + +inline bool operator>(const StringPiece& x, const StringPiece& y) { + return y < x; +} + +inline bool operator<=(const StringPiece& x, const StringPiece& y) { + return !(x > y); +} + +inline bool operator>=(const StringPiece& x, const StringPiece& y) { + return !(x < y); +} + +// allow StringPiece to be logged (needed for unit testing). 
+extern std::ostream& operator<<(std::ostream& o, const StringPiece& piece); + +#ifdef USE_BOOST +size_t hash_value(const StringPiece &str); + +/* Support for lookup of StringPiece in boost::unordered_map */ +struct StringPieceCompatibleHash : public std::unary_function { + size_t operator()(const StringPiece &str) const { + return hash_value(str); + } +}; + +struct StringPieceCompatibleEquals : public std::binary_function { + bool operator()(const StringPiece &first, const StringPiece &second) const { + return first == second; + } +}; +template typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) { + return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +} +template typename T::iterator FindStringPiece(T &t, const StringPiece &key) { + return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +} +#endif + +#ifdef USE_ICU +U_NAMESPACE_END +#endif + +#endif // BASE_STRING_PIECE_H__ -- cgit v1.2.3