diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2010-12-13 16:18:34 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2010-12-13 16:18:34 -0500 |
commit | be98f29f51350c24136c191f01af3fbfe340ef78 (patch) | |
tree | 2e104152110ca76b527147458050a41934e031f2 /klm | |
parent | 063c0623aaf5dad8d02e5eae5793c123cd7fc3fe (diff) |
new version of kenlm
Diffstat (limited to 'klm')
-rw-r--r-- | klm/lm/binary_format.cc | 1 | ||||
-rw-r--r-- | klm/lm/binary_format.hh | 9 | ||||
-rw-r--r-- | klm/lm/build_binary.cc | 112 | ||||
-rw-r--r-- | klm/lm/enumerate_vocab.hh | 7 | ||||
-rw-r--r-- | klm/lm/lm_exception.cc | 8 | ||||
-rw-r--r-- | klm/lm/lm_exception.hh | 18 | ||||
-rw-r--r-- | klm/lm/model.cc | 1 | ||||
-rw-r--r-- | klm/lm/model.hh | 23 | ||||
-rw-r--r-- | klm/lm/model_test.cc | 4 | ||||
-rw-r--r-- | klm/lm/read_arpa.cc | 4 | ||||
-rw-r--r-- | klm/lm/read_arpa.hh | 2 | ||||
-rw-r--r-- | klm/lm/search_trie.cc | 200 | ||||
-rw-r--r-- | klm/lm/trie.cc | 42 | ||||
-rw-r--r-- | klm/util/bit_packing.cc | 13 | ||||
-rw-r--r-- | klm/util/bit_packing.hh | 48 | ||||
-rw-r--r-- | klm/util/bit_packing_test.cc | 46 | ||||
-rw-r--r-- | klm/util/exception.cc | 5 | ||||
-rw-r--r-- | klm/util/file_piece.cc | 99 | ||||
-rw-r--r-- | klm/util/file_piece.hh | 5 | ||||
-rw-r--r-- | klm/util/file_piece_test.cc | 29 | ||||
-rw-r--r-- | klm/util/mmap.cc | 24 | ||||
-rw-r--r-- | klm/util/murmur_hash.hh | 2 | ||||
-rw-r--r-- | klm/util/scoped.cc | 14 | ||||
-rw-r--r-- | klm/util/string_piece.hh | 2 |
24 files changed, 547 insertions, 171 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 2a075b6b..69a06355 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -141,7 +141,6 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t } uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) { - if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0"); if (config.write_mmap) { std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size; // Write out an mmap file. diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index f95f05f7..a43c883c 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -67,9 +67,12 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to) if (detail::IsBinaryFormat(backing.file.get())) { detail::ReadHeader(backing.file.get(), params); detail::MatchCheck(To::kModelType, params); - std::size_t memory_size = To::Size(params.counts, config); - uint8_t *start = detail::SetupBinary(config, params, memory_size, backing); - to.InitializeFromBinary(start, params, config, backing.file.get()); + // Replace the probing_multiplier. + Config new_config(config); + new_config.probing_multiplier = params.fixed.probing_multiplier; + std::size_t memory_size = To::Size(params.counts, new_config); + uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); + to.InitializeFromBinary(start, params, new_config, backing.file.get()); } else { detail::ComplainAboutARPA(config, To::kModelType); util::FilePiece f(backing.file.release(), file, config.messages); diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 4db631a2..ec034640 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -1,13 +1,113 @@ #include "lm/model.hh" +#include "util/file_piece.hh" #include <iostream> +#include <iomanip> + +#include <math.h> +#include <stdlib.h> +#include <unistd.h> + +namespace lm { +namespace ngram { +namespace { + +void Usage(const char *name) { + std::cerr << "Usage: " << name << " [-u unknown_probability] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" +"Where type is one of probing, trie, or sorted:\n\n" +"probing uses a probing hash table. It is the fastest but uses the most memory.\n" +"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" +"trie is a straightforward trie with bit-level packing. It uses the least\n" +"memory and is still faster than SRI or IRST. Building the trie format uses an\n" +"on-disk sort to save memory.\n" +"-t is the temporary directory prefix. Default is the output file name.\n" +"-m is the amount of memory to use, in MB. Default is 1024MB (1GB).\n\n" +"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" +"It uses more memory than trie and is also slower, so there's no real reason to\n" +"use it.\n\n" +"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" +"Passing only an input file will print memory usage of each data structure.\n" +"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n"; + exit(1); +} + +// I could really use boost::lexical_cast right about now. +float ParseFloat(const char *from) { + char *end; + float ret = strtod(from, &end); + if (*end) throw util::ParseNumberException(from); + return ret; +} +unsigned long int ParseUInt(const char *from) { + char *end; + unsigned long int ret = strtoul(from, &end, 10); + if (*end) throw util::ParseNumberException(from); + return ret; +} + +void ShowSizes(const char *file, const lm::ngram::Config &config) { + std::vector<uint64_t> counts; + util::FilePiece f(file); + lm::ReadARPACounts(f, counts); + std::size_t probing_size = ProbingModel::Size(counts, config); + // probing is always largest so use it to determine number of columns. + long int length = std::max<long int>(5, lrint(ceil(log10(probing_size)))); + std::cout << "Memory usage:\ntype "; + // right align bytes. + for (long int i = 0; i < length - 5; ++i) std::cout << ' '; + std::cout << "bytes\n" + "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n" + "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n" + "sorted " << std::setw(length) << SortedModel::Size(counts, config) << "\n"; +} + +} // namespace ngram +} // namespace lm +} // namespace int main(int argc, char *argv[]) { - if (argc != 3) { - std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl; - return 1; - } + using namespace lm::ngram; + lm::ngram::Config config; - config.write_mmap = argv[2]; - lm::ngram::Model(argv[1], config); + int opt; + while ((opt = getopt(argc, argv, "u:p:t:m:")) != -1) { + switch(opt) { + case 'u': + config.unknown_missing_prob = ParseFloat(optarg); + break; + case 'p': + config.probing_multiplier = ParseFloat(optarg); + break; + case 't': + config.temporary_directory_prefix = optarg; + break; + case 'm': + config.building_memory = ParseUInt(optarg) * 1048576; + break; + default: + Usage(argv[0]); + } + } + if (optind + 1 == argc) { + ShowSizes(argv[optind], config); + } else if (optind + 2 == argc) { + config.write_mmap = argv[optind + 1]; + ProbingModel(argv[optind], config); + } else if (optind + 3 == argc) { + const char *model_type = argv[optind]; + const char *from_file = argv[optind + 1]; + config.write_mmap = argv[optind + 2]; + if (!strcmp(model_type, "probing")) { + ProbingModel(from_file, config); + } else if (!strcmp(model_type, "sorted")) { + SortedModel(from_file, config); + } else if (!strcmp(model_type, "trie")) { + TrieModel(from_file, config); + } else { + Usage(argv[0]); + } + } else { + Usage(argv[0]); + } + return 0; } diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh index 7a2f7d12..e734316b 100644 --- a/klm/lm/enumerate_vocab.hh +++ b/klm/lm/enumerate_vocab.hh @@ -8,9 +8,10 @@ namespace lm { namespace ngram { /* If you need the actual strings in the vocabulary, inherit from this class - * and implement Add. Then put a pointer in Config.enumerate_vocab. - * Add is called once per n-gram. index starts at 0 and increases by 1 each - * time. + * and implement Add. Then put a pointer in Config.enumerate_vocab; it does + * not take ownership. Add is called once per vocab word. index starts at 0 + * and increases by 1 each time. This is only used by the Model constructor; + * the pointer is not retained by the class. */ class EnumerateVocab { public: diff --git a/klm/lm/lm_exception.cc b/klm/lm/lm_exception.cc index ab2ec52f..473849d1 100644 --- a/klm/lm/lm_exception.cc +++ b/klm/lm/lm_exception.cc @@ -5,14 +5,18 @@ namespace lm { +ConfigException::ConfigException() throw() {} +ConfigException::~ConfigException() throw() {} + LoadException::LoadException() throw() {} LoadException::~LoadException() throw() {} -VocabLoadException::VocabLoadException() throw() {} -VocabLoadException::~VocabLoadException() throw() {} FormatLoadException::FormatLoadException() throw() {} FormatLoadException::~FormatLoadException() throw() {} +VocabLoadException::VocabLoadException() throw() {} +VocabLoadException::~VocabLoadException() throw() {} + SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() { *this << "Missing special word " << which; } diff --git a/klm/lm/lm_exception.hh b/klm/lm/lm_exception.hh index 1216c4c7..3773c572 100644 --- a/klm/lm/lm_exception.hh +++ b/klm/lm/lm_exception.hh @@ -11,6 +11,12 @@ namespace lm { +class ConfigException : public util::Exception { + public: + ConfigException() throw(); + ~ConfigException() throw(); +}; + class LoadException : public util::Exception { public: virtual ~LoadException() throw(); @@ -19,18 +25,18 @@ class LoadException : public util::Exception { LoadException() throw(); }; -class VocabLoadException : public LoadException { - public: - virtual ~VocabLoadException() throw(); - VocabLoadException() throw(); -}; - class FormatLoadException : public LoadException { public: FormatLoadException() throw(); ~FormatLoadException() throw(); }; +class VocabLoadException : public LoadException { + public: + virtual ~VocabLoadException() throw(); + VocabLoadException() throw(); +}; + class SpecialWordMissingException : public VocabLoadException { public: explicit SpecialWordMissingException(StringPiece which) throw(); diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 6921d4d9..421e72fa 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -23,6 +23,7 @@ namespace detail { template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } diff --git a/klm/lm/model.hh b/klm/lm/model.hh index e0eeee17..53e5773d 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -12,6 +12,8 @@ #include <algorithm> #include <vector> +#include <string.h> + namespace util { class FilePiece; } namespace lm { @@ -21,9 +23,10 @@ namespace ngram { // Having this limit means that State can be // (kMaxOrder - 1) * sizeof(float) bytes instead of // sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead -const std::size_t kMaxOrder = 6; +const unsigned char kMaxOrder = 6; -// This is a POD. +// This is a POD but if you want memcmp to return the same as operator==, call +// ZeroRemaining first. class State { public: bool operator==(const State &other) const { @@ -37,6 +40,22 @@ class State { return true; } + // Three way comparison function. + int Compare(const State &other) const { + if (valid_length_ == other.valid_length_) { + return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex)); + } + return (valid_length_ < other.valid_length_) ? -1 : 1; + } + + // Call this before using raw memcmp. + void ZeroRemaining() { + for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) { + history_[i] = 0; + backoff_[i] = 0.0; + } + } + // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. WordIndex history_[kMaxOrder - 1]; diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 159628d4..b5125a95 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -4,6 +4,7 @@ #define BOOST_TEST_MODULE ModelTest #include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> namespace lm { namespace ngram { @@ -123,7 +124,7 @@ class ExpectEnumerateVocab : public EnumerateVocab { } void Check(const base::Vocabulary &vocab) { - BOOST_CHECK_EQUAL(34, seen.size()); + BOOST_CHECK_EQUAL(34ULL, seen.size()); BOOST_REQUIRE(!seen.empty()); BOOST_CHECK_EQUAL("<unk>", seen[0]); for (WordIndex i = 0; i < seen.size(); ++i) { @@ -144,6 +145,7 @@ template <class ModelT> void LoadingTest() { config.messages = NULL; ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; + config.probing_multiplier = 2.0; ModelT m("test.arpa", config); enumerate.Check(m.GetVocabulary()); Starters(m); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 8e9a770d..262a9c6a 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -49,7 +49,7 @@ template <class F> void GenericReadNGramHeader(F &in, unsigned int length) { while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} std::stringstream expected; expected << '\\' << length << "-grams:"; - if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead. "); + if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead"); } template <class F> void GenericReadEnd(F &in) { @@ -110,7 +110,7 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) { { float got = in.ReadFloat(); if (got != 0.0) - UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff."); + UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff"); } break; case '\n': diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index cabdb195..571fcbc5 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -54,7 +54,7 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns } ReadBackoff(f, weights); } catch(util::Exception &e) { - e << " in the " << n << "-gram at byte " << f.Offset(); + e << " in the " << static_cast<unsigned int>(n) << "-gram at byte " << f.Offset(); throw; } } diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 182e27f5..12294682 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,3 +1,4 @@ +/* This is where the trie is built. It's on-disk. */ #include "lm/search_trie.hh" #include "lm/lm_exception.hh" @@ -8,6 +9,7 @@ #include "lm/word_index.hh" #include "util/ersatz_progress.hh" #include "util/file_piece.hh" +#include "util/proxy_iterator.hh" #include "util/scoped.hh" #include <algorithm> @@ -30,43 +32,119 @@ namespace ngram { namespace trie { namespace { -template <unsigned char Order> class FullEntry { +/* An entry is a n-gram with probability. It consists of: + * WordIndex[order] + * float probability + * backoff probability (omitted for highest order n-gram) + * These are stored consecutively in memory. We want to sort them. + * + * The problem is the length depends on order (but all n-grams being compared + * have the same order). Allocating each entry on the heap (i.e. std::vector + * or std::string) then sorting pointers is the normal solution. But that's + * too memory inefficient. A lot of this code is just here to force std::sort + * to work with records where length is specified at runtime (and avoid using + * Boost for LM code). I could have used qsort, but the point is to also + * support __gnu_cxx:parallel_sort which doesn't have a qsort version. + */ + +class EntryIterator { public: - typedef ProbBackoff Weights; - static const unsigned char kOrder = Order; + EntryIterator() {} - // reverse order - WordIndex words[Order]; - Weights weights; + EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast<uint8_t*>(ptr)), size_(size) {} - bool operator<(const FullEntry<Order> &other) const { - for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { - if (*i < *j) return true; - if (*i > *j) return false; - } - return false; + bool operator==(const EntryIterator &other) const { + return ptr_ == other.ptr_; + } + bool operator<(const EntryIterator &other) const { + return ptr_ < other.ptr_; + } + EntryIterator &operator+=(std::ptrdiff_t amount) { + ptr_ += amount * size_; + return *this; + } + std::ptrdiff_t operator-(const EntryIterator &other) const { + return (ptr_ - other.ptr_) / size_; } + + const void *Data() const { return ptr_; } + void *Data() { return ptr_; } + std::size_t EntrySize() const { return size_; } + + private: + uint8_t *ptr_; + std::size_t size_; }; -template <unsigned char Order> class ProbEntry { +class EntryProxy { public: - typedef Prob Weights; - static const unsigned char kOrder = Order; + EntryProxy() {} + + EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} + + operator std::string() const { + return std::string(reinterpret_cast<const char*>(inner_.Data()), inner_.EntrySize()); + } + + EntryProxy &operator=(const EntryProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); + return *this; + } + + EntryProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), inner_.EntrySize()); + return *this; + } + + const WordIndex *Indices() const { + return static_cast<const WordIndex*>(inner_.Data()); + } + + private: + friend class util::ProxyIterator<EntryProxy>; + + typedef std::string value_type; - // reverse order - WordIndex words[Order]; - Weights weights; + typedef EntryIterator InnerIterator; + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; - bool operator<(const ProbEntry<Order> &other) const { - for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { - if (*i < *j) return true; - if (*i > *j) return false; +typedef util::ProxyIterator<EntryProxy> NGramIter; + +class CompareRecords : public std::binary_function<const EntryProxy &, const EntryProxy &, bool> { + public: + explicit CompareRecords(unsigned char order) : order_(order) {} + + bool operator()(const EntryProxy &first, const EntryProxy &second) const { + return Compare(first.Indices(), second.Indices()); + } + bool operator()(const EntryProxy &first, const std::string &second) const { + return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data())); + } + bool operator()(const std::string &first, const EntryProxy &second) const { + return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices()); + } + bool operator()(const std::string &first, const std::string &second) const { + return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(first.data())); + } + + private: + bool Compare(const WordIndex *first, const WordIndex *second) const { + const WordIndex *end = first + order_; + for (; first != end; ++first, ++second) { + if (*first < *second) return true; + if (*first > *second) return false; } return false; } + + unsigned char order_; }; void WriteOrThrow(FILE *to, const void *data, size_t size) { + assert(size); if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); } @@ -84,21 +162,24 @@ void CopyOrThrow(FILE *from, FILE *to, size_t size) { } } -template <class Entry> std::string DiskFlush(const Entry *begin, const Entry *end, const std::string &file_prefix, std::size_t batch) { +std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) { + const std::size_t entry_size = sizeof(WordIndex) * order + weights_size; + const std::size_t prefix_size = sizeof(WordIndex) * (order - 1); std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << '_' << batch; + assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch; std::string ret(assembled.str()); util::scoped_FILE out(fopen(ret.c_str(), "w")); if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); - for (const Entry *group_begin = begin; group_begin != end;) { - const Entry *group_end = group_begin; - for (++group_end; (group_end != end) && !memcmp(group_begin->words, group_end->words, sizeof(WordIndex) * (Entry::kOrder - 1)); ++group_end) {} - WriteOrThrow(out.get(), group_begin->words, sizeof(WordIndex) * (Entry::kOrder - 1)); - WordIndex group_size = group_end - group_begin; + // Compress entries that being with the same (order-1) words. + for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) { + const uint8_t *group_end = group_begin; + for (group_end += entry_size; (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); group_end += entry_size) {} + WriteOrThrow(out.get(), group_begin, prefix_size); + WordIndex group_size = (group_end - group_begin) / entry_size; WriteOrThrow(out.get(), &group_size, sizeof(group_size)); - for (const Entry *i = group_begin; i != group_end; ++i) { - WriteOrThrow(out.get(), &i->words[Entry::kOrder - 1], sizeof(WordIndex)); - WriteOrThrow(out.get(), &i->weights, sizeof(typename Entry::Weights)); + for (const uint8_t *i = group_begin; i != group_end; i += entry_size) { + WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex)); + WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size); } group_begin = group_end; } @@ -219,25 +300,37 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha } } -template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix) { - ConvertToSorted<FullEntry<Entry::kOrder - 1> >(f, vocab, counts, mem, file_prefix); - - ReadNGramHeader(f, Entry::kOrder); - const size_t count = counts[Entry::kOrder - 1]; - const size_t batch_size = std::min(count, mem.size() / sizeof(Entry)); - Entry *const begin = reinterpret_cast<Entry*>(mem.get()); +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { + if (order == 1) return; + ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1); + + ReadNGramHeader(f, order); + const size_t count = counts[order - 1]; + // Size of weights. Does it include backoff? + const size_t words_size = sizeof(WordIndex) * order; + const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); + const size_t entry_size = words_size + weights_size; + const size_t batch_size = std::min(count, mem.size() / entry_size); + uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get()); std::deque<std::string> files; for (std::size_t batch = 0, done = 0; done < count; ++batch) { - Entry *out = begin; - Entry *out_end = out + std::min(count - done, batch_size); - for (; out != out_end; ++out) { - ReadNGram(f, Entry::kOrder, vocab, out->words, out->weights); + uint8_t *out = begin; + uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; + if (order == counts.size()) { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size)); + } + } else { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size)); + } } - //__gnu_parallel::sort(begin, out_end); - std::sort(begin, out_end); + // TODO: __gnu_parallel::sort here. + EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); + std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); - files.push_back(DiskFlush(begin, out_end, file_prefix, batch)); - done += out_end - begin; + files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); + done += (out_end - begin) / entry_size; } // All individual files created. Merge them. @@ -245,9 +338,9 @@ template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVoca std::size_t merge_count = 0; while (files.size() > 1) { std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merge_" << (merge_count++); + assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); - MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), sizeof(typename Entry::Weights), Entry::kOrder); + MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order); if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); files.pop_front(); if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); @@ -255,14 +348,12 @@ template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVoca } if (!files.empty()) { std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merged"; + assembled << file_prefix << static_cast<unsigned int>(order) << "_merged"; std::string merged_name(assembled.str()); if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); } } -template <> void ConvertToSorted<FullEntry<1> >(util::FilePiece &/*f*/, const SortedVocabulary &/*vocab*/, const std::vector<uint64_t> &/*counts*/, util::scoped_memory &/*mem*/, const std::string &/*file_prefix*/) {} - void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { { std::string unigram_name = file_prefix + "unigrams"; @@ -275,7 +366,7 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, util::scoped_memory mem; mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED); if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); - ConvertToSorted<ProbEntry<5> >(f, vocab, counts, mem, file_prefix); + ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size()); ReadEnd(f); } @@ -390,7 +481,8 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const temporary_directory.resize(strlen(temporary_directory.c_str())); // Add directory delimiter. Assumes a real operating system. temporary_directory += '/'; - ARPAToSortedFiles(f, counts, config.building_memory, temporary_directory.c_str(), vocab); + // At least 1MB sorting memory. + ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); BuildTrie(temporary_directory.c_str(), counts, config.messages, *this); if (rmdir(temporary_directory.c_str())) { std::cerr << "Failed to delete " << temporary_directory << std::endl; diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 8ed7b2a2..04bd2079 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -15,21 +15,21 @@ namespace { // Assumes key is first. class JustKeyProxy { public: - JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {} + JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {} operator uint64_t() const { return GetKey(); } uint64_t GetKey() const { uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_); - return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_mask_); + return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_); } private: friend class util::ProxyIterator<JustKeyProxy>; - friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); + friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); - JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits) - : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), total_bits_(total_bits) {} + JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) + : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} // This is a read-only iterator. JustKeyProxy &operator=(const JustKeyProxy &other); @@ -44,12 +44,12 @@ class JustKeyProxy { uint64_t inner_; const uint8_t *const base_; const uint64_t key_mask_; - const uint8_t total_bits_; + const uint8_t key_bits_, total_bits_; }; -bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { - util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, total_bits)); - util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, total_bits)); +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { + util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits)); + util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits)); util::ProxyIterator<JustKeyProxy> out; if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false; at_index = out.Inner(); @@ -96,67 +96,67 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t assert(next <= next_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word); + util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word); at_pointer += word_bits_; util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); at_pointer += prob_bits_; util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); at_pointer += backoff_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next); + util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next); ++insert_index_; } bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; at_pointer *= total_bits_; at_pointer += word_bits_; prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); at_pointer += prob_bits_; backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); return true; } bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; at_pointer *= total_bits_; at_pointer += word_bits_; at_pointer += prob_bits_; backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); return true; } void BitPackedMiddle::FinishedLoading(uint64_t next_end) { assert(next_end <= next_mask_); uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_end); + util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end); } void BitPackedLongest::Insert(WordIndex index, float prob) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, index); + util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index); at_pointer += word_bits_; util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); ++insert_index_; } -bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const { +bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; at_pointer = at_pointer * total_bits_ + word_bits_; prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); return true; diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc index dd14ffe1..9d4fdf27 100644 --- a/klm/util/bit_packing.cc +++ b/klm/util/bit_packing.cc @@ -1,12 +1,15 @@ #include "util/bit_packing.hh" #include "util/exception.hh" +#include <string.h> + namespace util { namespace { template <bool> struct StaticCheck {}; template <> struct StaticCheck<true> { typedef bool StaticAssertionPassed; }; +// If your float isn't 4 bytes, we're hosed. typedef StaticCheck<sizeof(float) == 4>::StaticAssertionPassed FloatSize; } // namespace @@ -21,6 +24,16 @@ uint8_t RequiredBits(uint64_t max_value) { void BitPackingSanity() { const detail::FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 }; if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000"); + char mem[57+8]; + memset(mem, 0, sizeof(mem)); + const uint64_t test57 = 0x123456789abcdefULL; + for (uint64_t b = 0; b < 57 * 8; b += 57) { + WriteInt57(mem + b / 8, b % 8, 57, test57); + } + for (uint64_t b = 0; b < 57 * 8; b += 57) { + if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)) + UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler."); + } // TODO: more checks. } diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 422ed873..0fd39d7f 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -6,56 +6,68 @@ #include <assert.h> #ifdef __APPLE__ #include <architecture/byte_order.h> -#else +#elif __linux__ #include <endian.h> -#endif +#else +#include <arpa/nameser_compat.h> +#endif #include <inttypes.h> -#if __BYTE_ORDER != __LITTLE_ENDIAN -#error The bit aligned storage functions assume little endian architecture -#endif - namespace util { /* WARNING WARNING WARNING: * The write functions assume that memory is zero initially. This makes them * faster and is the appropriate case for mmapped language model construction. * These routines assume that unaligned access to uint64_t is fast and that - * storage is little endian. This is the case on x86_64. It may not be the - * case on 32-bit x86 but my target audience is large language models for which - * 64-bit is necessary. + * storage is little endian. This is the case on x86_64. I'm not sure how + * fast unaligned 64-bit access is on x86 but my target audience is large + * language models for which 64-bit is necessary. + * + * Call the BitPackingSanity function to sanity check. Calling once suffices, + * but it may be called multiple times when that's inconvenient. */ +inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { +// Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct. +#if BYTE_ORDER == LITTLE_ENDIAN + return bit; +#elif BYTE_ORDER == BIG_ENDIAN + return 64 - length - bit; +#else +#error "Bit packing code isn't written for your byte order." +#endif +} + /* Pack integers up to 57 bits using their least significant digits. * The length is specified using mask: * Assumes mask == (1 << length) - 1 where length <= 57. */ -inline uint64_t ReadInt57(const void *base, uint8_t bit, uint64_t mask) { - return (*reinterpret_cast<const uint64_t*>(base) >> bit) & mask; +inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) { + return (*reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, length)) & mask; } -/* Assumes value <= mask and mask == (1 << length) - 1 where length <= 57. +/* Assumes value < (1 << length) and length <= 57. * Assumes the memory is zero initially. */ -inline void WriteInt57(void *base, uint8_t bit, uint64_t value) { - *reinterpret_cast<uint64_t*>(base) |= (value << bit); +inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) { + *reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length)); } namespace detail { typedef union { float f; uint32_t i; } FloatEnc; } inline float ReadFloat32(const void *base, uint8_t bit) { detail::FloatEnc encoded; - encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit; + encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32); return encoded.f; } inline void WriteFloat32(void *base, uint8_t bit, float value) { detail::FloatEnc encoded; encoded.f = value; - WriteInt57(base, bit, encoded.i); + WriteInt57(base, bit, 32, encoded.i); } inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { detail::FloatEnc encoded; - encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit; + encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31); // Sign bit set means negative. encoded.i |= 0x80000000; return encoded.f; @@ -65,7 +77,7 @@ inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { detail::FloatEnc encoded; encoded.f = value; encoded.i &= ~0x80000000; - WriteInt57(base, bit, encoded.i); + WriteInt57(base, bit, 31, encoded.i); } void BitPackingSanity(); diff --git a/klm/util/bit_packing_test.cc b/klm/util/bit_packing_test.cc new file mode 100644 index 00000000..c578ddd1 --- /dev/null +++ b/klm/util/bit_packing_test.cc @@ -0,0 +1,46 @@ +#include "util/bit_packing.hh" + +#define BOOST_TEST_MODULE BitPackingTest +#include <boost/test/unit_test.hpp> + +#include <string.h> + +namespace util { +namespace { + +const uint64_t test57 = 0x123456789abcdefULL; + +BOOST_AUTO_TEST_CASE(ZeroBit) { + char mem[16]; + memset(mem, 0, sizeof(mem)); + WriteInt57(mem, 0, 57, test57); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1)); +} + +BOOST_AUTO_TEST_CASE(EachBit) { + char mem[16]; + for (uint8_t b = 0; b < 8; ++b) { + memset(mem, 0, sizeof(mem)); + WriteInt57(mem, b, 57, test57); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); + } +} + +BOOST_AUTO_TEST_CASE(Consecutive) { + char mem[57+8]; + memset(mem, 0, sizeof(mem)); + for (uint64_t b = 0; b < 57 * 8; b += 57) { + WriteInt57(mem + (b / 8), b % 8, 57, test57); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + } + for (uint64_t b = 0; b < 57 * 8; b += 57) { + BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + } +} + +BOOST_AUTO_TEST_CASE(Sanity) { + BitPackingSanity(); +} + +} // namespace +} // namespace util diff --git a/klm/util/exception.cc b/klm/util/exception.cc index dd337a76..de6dd43c 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -24,7 +24,12 @@ const char *HandleStrerror(const char *ret, const char *buf) { ErrnoException::ErrnoException() throw() : errno_(errno) { char buf[200]; buf[0] = 0; +#ifdef sun + const char *add = strerror(errno); +#else const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf); +#endif + if (add) { *this << add << ' '; } diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index e7bd8659..5a667ebb 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -2,12 +2,12 @@ #include "util/exception.hh" +#include <iostream> #include <string> #include <limits> #include <assert.h> #include <ctype.h> -#include <err.h> #include <fcntl.h> #include <stdlib.h> #include <sys/mman.h> @@ -27,7 +27,7 @@ EndOfFileException::EndOfFileException() throw() { EndOfFileException::~EndOfFileException() throw() {} ParseNumberException::ParseNumberException(StringPiece value) throw() { - *this << "Could not parse \"" << value << "\" into a float"; + *this << "Could not parse \"" << value << "\" into a number"; } GZException::GZException(void *file) { @@ -68,12 +68,52 @@ FilePiece::~FilePiece() { file_.release(); int ret; if (Z_OK != (ret = gzclose(gz_file_))) { - errx(1, "could not close file %s using zlib", file_name_.c_str()); + std::cerr << "could not close file " << file_name_ << " using zlib" << std::endl; + abort(); } } #endif } +StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) { + const char *start = position_; + do { + for (const char *i = start; i < position_end_; ++i) { + if (*i == delim) { + StringPiece ret(position_, i - position_); + position_ = i + 1; + return ret; + } + } + size_t skip = position_end_ - position_; + Shift(); + start = position_ + skip; + } while (!at_end_); + StringPiece ret(position_, position_end_ - position_); + position_ = position_end_; + return ret; +} + +float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) { + return ReadNumber<float>(); +} +double FilePiece::ReadDouble() throw(GZException, EndOfFileException, ParseNumberException) { + return ReadNumber<double>(); +} +long int FilePiece::ReadLong() throw(GZException, EndOfFileException, ParseNumberException) { + return ReadNumber<long int>(); +} +unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException, ParseNumberException) { + return ReadNumber<unsigned long int>(); +} + +void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) { + for (; ; ++position_) { + if (position_ == position_end_) Shift(); + if (!isspace(*position_)) return; + } +} + void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) { #ifdef HAVE_ZLIB gz_file_ = NULL; @@ -108,14 +148,34 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t } } -float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) { +namespace { +void ParseNumber(const char *begin, char *&end, float &out) { +#ifdef sun + out = static_cast<float>(strtod(begin, &end)); +#else + out = strtof(begin, &end); +#endif +} +void ParseNumber(const char *begin, char *&end, double &out) { + out = strtod(begin, &end); +} +void ParseNumber(const char *begin, char *&end, long int &out) { + out = strtol(begin, &end, 10); +} +void ParseNumber(const char *begin, char *&end, unsigned long int &out) { + out = strtoul(begin, &end, 10); +} +} // namespace + +template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileException, ParseNumberException) { SkipSpaces(); while (last_space_ < position_) { if (at_end_) { // Hallucinate a null off the end of the file. std::string buffer(position_, position_end_); char *end; - float ret = strtof(buffer.c_str(), &end); + T ret; + ParseNumber(buffer.c_str(), end, ret); if (buffer.c_str() == end) throw ParseNumberException(buffer); position_ += end - buffer.c_str(); return ret; @@ -123,19 +183,13 @@ float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberE Shift(); } char *end; - float ret = strtof(position_, &end); + T ret; + ParseNumber(position_, end, ret); if (end == position_) throw ParseNumberException(ReadDelimited()); position_ = end; return ret; } -void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) { - for (; ; ++position_) { - if (position_ == position_end_) Shift(); - if (!isspace(*position_)) return; - } -} - const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileException) { for (const char *i = position_; i <= last_space_; ++i) { if (isspace(*i)) return i; @@ -150,25 +204,6 @@ const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileExcepti return position_end_; } -StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) { - const char *start = position_; - do { - for (const char *i = start; i < position_end_; ++i) { - if (*i == delim) { - StringPiece ret(position_, i - position_); - position_ = i + 1; - return ret; - } - } - size_t skip = position_end_ - position_; - Shift(); - start = position_ + skip; - } while (!at_end_); - StringPiece ret(position_, position_end_ - position_); - position_ = position_end_; - return ret; -} - void FilePiece::Shift() throw(GZException, EndOfFileException) { if (at_end_) { progress_.Finished(); diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index 11d4a751..b7697e71 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -68,6 +68,9 @@ class FilePiece { StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException); float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException); + double ReadDouble() throw(GZException, EndOfFileException, ParseNumberException); + long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException); + unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException); void SkipSpaces() throw (GZException, EndOfFileException); @@ -80,6 +83,8 @@ class FilePiece { private: void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException); + template <class T> T ReadNumber() throw(GZException, EndOfFileException, ParseNumberException); + StringPiece Consume(const char *to) { StringPiece ret(position_, to - position_); position_ = to; diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index 23e79fe0..dc9ec7e7 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -8,6 +8,8 @@ #include <iostream> #include <stdio.h> +#include <sys/types.h> +#include <sys/stat.h> namespace util { namespace { @@ -27,14 +29,18 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) { BOOST_CHECK_THROW(test.get(), EndOfFileException); } +#ifndef __APPLE__ +/* Apple isn't happy with the popen, fileno, dup. And I don't want to + * reimplement popen. This is an issue with the test. + */ /* read() implementation */ BOOST_AUTO_TEST_CASE(StreamReadLine) { std::fstream ref("file_piece.cc", std::ios::in); - scoped_FILE catter(popen("cat file_piece.cc", "r")); - BOOST_REQUIRE(catter.get()); + FILE *catter = popen("cat file_piece.cc", "r"); + BOOST_REQUIRE(catter); - FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1); + FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1); std::string ref_line; while (getline(ref, ref_line)) { StringPiece test_line(test.ReadLine()); @@ -44,7 +50,9 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) { } } BOOST_CHECK_THROW(test.get(), EndOfFileException); + BOOST_REQUIRE(!pclose(catter)); } +#endif // __APPLE__ #ifdef HAVE_ZLIB @@ -64,14 +72,17 @@ BOOST_AUTO_TEST_CASE(PlainZipReadLine) { } BOOST_CHECK_THROW(test.get(), EndOfFileException); } -// gzip stream + +// gzip stream. Apple doesn't like popen, fileno, dup. This is an issue with +// the test. +#ifndef __APPLE__ BOOST_AUTO_TEST_CASE(StreamZipReadLine) { std::fstream ref("file_piece.cc", std::ios::in); - scoped_FILE catter(popen("gzip <file_piece.cc", "r")); - BOOST_REQUIRE(catter.get()); + FILE * catter = popen("gzip <file_piece.cc", "r"); + BOOST_REQUIRE(catter); - FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1); + FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1); std::string ref_line; while (getline(ref, ref_line)) { StringPiece test_line(test.ReadLine()); @@ -81,9 +92,11 @@ BOOST_AUTO_TEST_CASE(StreamZipReadLine) { } } BOOST_CHECK_THROW(test.get(), EndOfFileException); + BOOST_REQUIRE(!pclose(catter)); } +#endif // __APPLE__ -#endif +#endif // HAVE_ZLIB } // namespace } // namespace util diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index 8685170f..5a810c64 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -2,8 +2,9 @@ #include "util/mmap.hh" #include "util/scoped.hh" +#include <iostream> + #include <assert.h> -#include <err.h> #include <fcntl.h> #include <sys/types.h> #include <sys/mman.h> @@ -14,8 +15,10 @@ namespace util { scoped_mmap::~scoped_mmap() { if (data_ != (void*)-1) { - if (munmap(data_, size_)) - err(1, "munmap failed "); + if (munmap(data_, size_)) { + std::cerr << "munmap failed for " << size_ << " bytes." << std::endl; + abort(); + } } } @@ -73,18 +76,27 @@ void ReadAll(int fd, void *to_void, std::size_t amount) { to += ret; } } + +const int kFileFlags = +#ifdef MAP_FILE + MAP_FILE | MAP_SHARED +#else + MAP_SHARED +#endif + ; + } // namespace void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out) { switch (method) { case LAZY: - out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, false, fd, offset), size, scoped_memory::MMAP_ALLOCATED); + out.reset(MapOrThrow(size, false, kFileFlags, false, fd, offset), size, scoped_memory::MMAP_ALLOCATED); break; case POPULATE_OR_LAZY: #ifdef MAP_POPULATE case POPULATE_OR_READ: #endif - out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, true, fd, offset), size, scoped_memory::MMAP_ALLOCATED); + out.reset(MapOrThrow(size, false, kFileFlags, true, fd, offset), size, scoped_memory::MMAP_ALLOCATED); break; #ifndef MAP_POPULATE case POPULATE_OR_READ: @@ -115,7 +127,7 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) { if (-1 == ftruncate(file.get(), size)) UTIL_THROW(ErrnoException, "ftruncate on " << name << " to " << size << " failed"); try { - return MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0); + return MapOrThrow(size, true, kFileFlags, false, file.get(), 0); } catch (ErrnoException &e) { e << " in file " << name; throw; diff --git a/klm/util/murmur_hash.hh b/klm/util/murmur_hash.hh index 638aaeb2..78fe583f 100644 --- a/klm/util/murmur_hash.hh +++ b/klm/util/murmur_hash.hh @@ -1,7 +1,7 @@ #ifndef UTIL_MURMUR_HASH__ #define UTIL_MURMUR_HASH__ #include <cstddef> -#include <stdint.h> +#include <inttypes.h> namespace util { diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc index 2c6d5394..a4cc5016 100644 --- a/klm/util/scoped.cc +++ b/klm/util/scoped.cc @@ -1,16 +1,24 @@ #include "util/scoped.hh" -#include <err.h> +#include <iostream> + +#include <stdlib.h> #include <unistd.h> namespace util { scoped_fd::~scoped_fd() { - if (fd_ != -1 && close(fd_)) err(1, "Could not close file %i", fd_); + if (fd_ != -1 && close(fd_)) { + std::cerr << "Could not close file " << fd_ << std::endl; + abort(); + } } scoped_FILE::~scoped_FILE() { - if (file_ && fclose(file_)) err(1, "Could not close file"); + if (file_ && fclose(file_)) { + std::cerr << "Could not close file " << std::endl; + abort(); + } } } // namespace util diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index 4557173b..3ac2f8a7 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -51,7 +51,7 @@ //Uncomment this line if you use ICU in your code. //#define HAVE_ICU //Uncomment this line if you want boost hashing for your StringPieces. -#define HAVE_BOOST +//#define HAVE_BOOST #include <cstring> #include <iosfwd> |