author    Chris Dyer <cdyer@cs.cmu.edu>  2010-12-13 16:18:34 -0500
committer Chris Dyer <cdyer@cs.cmu.edu>  2010-12-13 16:18:34 -0500
commit    be98f29f51350c24136c191f01af3fbfe340ef78 (patch)
tree      2e104152110ca76b527147458050a41934e031f2 /klm
parent    063c0623aaf5dad8d02e5eae5793c123cd7fc3fe (diff)
new version of kenlm
Diffstat (limited to 'klm')
-rw-r--r-- klm/lm/binary_format.cc | 1
-rw-r--r-- klm/lm/binary_format.hh | 9
-rw-r--r-- klm/lm/build_binary.cc | 112
-rw-r--r-- klm/lm/enumerate_vocab.hh | 7
-rw-r--r-- klm/lm/lm_exception.cc | 8
-rw-r--r-- klm/lm/lm_exception.hh | 18
-rw-r--r-- klm/lm/model.cc | 1
-rw-r--r-- klm/lm/model.hh | 23
-rw-r--r-- klm/lm/model_test.cc | 4
-rw-r--r-- klm/lm/read_arpa.cc | 4
-rw-r--r-- klm/lm/read_arpa.hh | 2
-rw-r--r-- klm/lm/search_trie.cc | 200
-rw-r--r-- klm/lm/trie.cc | 42
-rw-r--r-- klm/util/bit_packing.cc | 13
-rw-r--r-- klm/util/bit_packing.hh | 48
-rw-r--r-- klm/util/bit_packing_test.cc | 46
-rw-r--r-- klm/util/exception.cc | 5
-rw-r--r-- klm/util/file_piece.cc | 99
-rw-r--r-- klm/util/file_piece.hh | 5
-rw-r--r-- klm/util/file_piece_test.cc | 29
-rw-r--r-- klm/util/mmap.cc | 24
-rw-r--r-- klm/util/murmur_hash.hh | 2
-rw-r--r-- klm/util/scoped.cc | 14
-rw-r--r-- klm/util/string_piece.hh | 2
24 files changed, 547 insertions, 171 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 2a075b6b..69a06355 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -141,7 +141,6 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t
}
uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) {
- if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0");
if (config.write_mmap) {
std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size;
// Write out an mmap file.
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index f95f05f7..a43c883c 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -67,9 +67,12 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
if (detail::IsBinaryFormat(backing.file.get())) {
detail::ReadHeader(backing.file.get(), params);
detail::MatchCheck(To::kModelType, params);
- std::size_t memory_size = To::Size(params.counts, config);
- uint8_t *start = detail::SetupBinary(config, params, memory_size, backing);
- to.InitializeFromBinary(start, params, config, backing.file.get());
+ // Replace the probing_multiplier.
+ Config new_config(config);
+ new_config.probing_multiplier = params.fixed.probing_multiplier;
+ std::size_t memory_size = To::Size(params.counts, new_config);
+ uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
+ to.InitializeFromBinary(start, params, new_config, backing.file.get());
} else {
detail::ComplainAboutARPA(config, To::kModelType);
util::FilePiece f(backing.file.release(), file, config.messages);
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index 4db631a2..ec034640 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -1,13 +1,113 @@
#include "lm/model.hh"
+#include "util/file_piece.hh"
#include <iostream>
+#include <iomanip>
+
+#include <math.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+namespace lm {
+namespace ngram {
+namespace {
+
+void Usage(const char *name) {
+ std::cerr << "Usage: " << name << " [-u unknown_probability] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n"
+"Where type is one of probing, trie, or sorted:\n\n"
+"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
+"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
+"trie is a straightforward trie with bit-level packing. It uses the least\n"
+"memory and is still faster than SRI or IRST. Building the trie format uses an\n"
+"on-disk sort to save memory.\n"
+"-t is the temporary directory prefix. Default is the output file name.\n"
+"-m is the amount of memory to use, in MB. Default is 1024MB (1GB).\n\n"
+"sorted is like probing but uses a sorted uniform map instead of a hash table.\n"
+"It uses more memory than trie and is also slower, so there's no real reason to\n"
+"use it.\n\n"
+"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n"
+"Passing only an input file will print memory usage of each data structure.\n"
+"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n";
+ exit(1);
+}
+
+// I could really use boost::lexical_cast right about now.
+float ParseFloat(const char *from) {
+ char *end;
+ float ret = strtod(from, &end);
+ if (*end) throw util::ParseNumberException(from);
+ return ret;
+}
+unsigned long int ParseUInt(const char *from) {
+ char *end;
+ unsigned long int ret = strtoul(from, &end, 10);
+ if (*end) throw util::ParseNumberException(from);
+ return ret;
+}
+
+void ShowSizes(const char *file, const lm::ngram::Config &config) {
+ std::vector<uint64_t> counts;
+ util::FilePiece f(file);
+ lm::ReadARPACounts(f, counts);
+ std::size_t probing_size = ProbingModel::Size(counts, config);
+ // probing is always largest so use it to determine number of columns.
+ long int length = std::max<long int>(5, lrint(ceil(log10(probing_size))));
+ std::cout << "Memory usage:\ntype ";
+ // right align bytes.
+ for (long int i = 0; i < length - 5; ++i) std::cout << ' ';
+ std::cout << "bytes\n"
+ "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n"
+ "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"
+ "sorted " << std::setw(length) << SortedModel::Size(counts, config) << "\n";
+}
+
+} // namespace
+} // namespace ngram
+} // namespace lm
int main(int argc, char *argv[]) {
- if (argc != 3) {
- std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl;
- return 1;
- }
+ using namespace lm::ngram;
+
lm::ngram::Config config;
- config.write_mmap = argv[2];
- lm::ngram::Model(argv[1], config);
+ int opt;
+ while ((opt = getopt(argc, argv, "u:p:t:m:")) != -1) {
+ switch(opt) {
+ case 'u':
+ config.unknown_missing_prob = ParseFloat(optarg);
+ break;
+ case 'p':
+ config.probing_multiplier = ParseFloat(optarg);
+ break;
+ case 't':
+ config.temporary_directory_prefix = optarg;
+ break;
+ case 'm':
+ config.building_memory = ParseUInt(optarg) * 1048576;
+ break;
+ default:
+ Usage(argv[0]);
+ }
+ }
+ if (optind + 1 == argc) {
+ ShowSizes(argv[optind], config);
+ } else if (optind + 2 == argc) {
+ config.write_mmap = argv[optind + 1];
+ ProbingModel(argv[optind], config);
+ } else if (optind + 3 == argc) {
+ const char *model_type = argv[optind];
+ const char *from_file = argv[optind + 1];
+ config.write_mmap = argv[optind + 2];
+ if (!strcmp(model_type, "probing")) {
+ ProbingModel(from_file, config);
+ } else if (!strcmp(model_type, "sorted")) {
+ SortedModel(from_file, config);
+ } else if (!strcmp(model_type, "trie")) {
+ TrieModel(from_file, config);
+ } else {
+ Usage(argv[0]);
+ }
+ } else {
+ Usage(argv[0]);
+ }
+ return 0;
}
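
For context on how the binary produced by this tool is consumed, a minimal query sketch follows. It is not part of this patch; the query API names (GetVocabulary().Index, BeginSentenceState, FullScore) are assumed from the kenlm headers of this era, and the probing_multiplier stored in the file header overrides the caller's value per the binary_format.hh change above.

    // Minimal sketch, assuming the standard kenlm query API; not part of this patch.
    #include "lm/model.hh"

    #include <iostream>

    int main() {
      lm::ngram::Config config;
      // probing_multiplier recorded in output.mmap overrides whatever is set here.
      lm::ngram::ProbingModel model("output.mmap", config);
      lm::ngram::State state(model.BeginSentenceState()), out_state;
      lm::WordIndex word = model.GetVocabulary().Index("the");
      float log10_prob = model.FullScore(state, word, out_state).prob;
      std::cout << "log10 p(the | <s>) = " << log10_prob << std::endl;
      return 0;
    }
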
diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh
index 7a2f7d12..e734316b 100644
--- a/klm/lm/enumerate_vocab.hh
+++ b/klm/lm/enumerate_vocab.hh
@@ -8,9 +8,10 @@ namespace lm {
namespace ngram {
/* If you need the actual strings in the vocabulary, inherit from this class
- * and implement Add. Then put a pointer in Config.enumerate_vocab.
- * Add is called once per n-gram. index starts at 0 and increases by 1 each
- * time.
+ * and implement Add. Then put a pointer in Config.enumerate_vocab; it does
+ * not take ownership. Add is called once per vocab word. index starts at 0
+ * and increases by 1 each time. This is only used by the Model constructor;
+ * the pointer is not retained by the class.
*/
class EnumerateVocab {
public:
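
To illustrate the clarified contract, a subclass that collects the strings might look like the sketch below. It is illustrative only and assumes the usual Add(WordIndex, const StringPiece &) signature from the rest of this header.

    // Sketch of an EnumerateVocab subclass; CollectVocab is illustrative, not kenlm code.
    #include "lm/enumerate_vocab.hh"
    #include "lm/model.hh"

    #include <string>
    #include <vector>

    class CollectVocab : public lm::ngram::EnumerateVocab {
      public:
        // Called once per vocabulary word; index starts at 0 and increases by 1.
        void Add(lm::WordIndex index, const StringPiece &str) {
          if (words_.size() <= index) words_.resize(index + 1);
          // The StringPiece is only valid during the call, so copy it.
          words_[index].assign(str.data(), str.size());
        }
        const std::vector<std::string> &Words() const { return words_; }
      private:
        std::vector<std::string> words_;
    };

    // Usage: the pointer is read only inside the Model constructor.
    //   CollectVocab collect;
    //   lm::ngram::Config config;
    //   config.enumerate_vocab = &collect;
    //   lm::ngram::Model model("test.arpa", config);
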
diff --git a/klm/lm/lm_exception.cc b/klm/lm/lm_exception.cc
index ab2ec52f..473849d1 100644
--- a/klm/lm/lm_exception.cc
+++ b/klm/lm/lm_exception.cc
@@ -5,14 +5,18 @@
namespace lm {
+ConfigException::ConfigException() throw() {}
+ConfigException::~ConfigException() throw() {}
+
LoadException::LoadException() throw() {}
LoadException::~LoadException() throw() {}
-VocabLoadException::VocabLoadException() throw() {}
-VocabLoadException::~VocabLoadException() throw() {}
FormatLoadException::FormatLoadException() throw() {}
FormatLoadException::~FormatLoadException() throw() {}
+VocabLoadException::VocabLoadException() throw() {}
+VocabLoadException::~VocabLoadException() throw() {}
+
SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() {
*this << "Missing special word " << which;
}
diff --git a/klm/lm/lm_exception.hh b/klm/lm/lm_exception.hh
index 1216c4c7..3773c572 100644
--- a/klm/lm/lm_exception.hh
+++ b/klm/lm/lm_exception.hh
@@ -11,6 +11,12 @@
namespace lm {
+class ConfigException : public util::Exception {
+ public:
+ ConfigException() throw();
+ ~ConfigException() throw();
+};
+
class LoadException : public util::Exception {
public:
virtual ~LoadException() throw();
@@ -19,18 +25,18 @@ class LoadException : public util::Exception {
LoadException() throw();
};
-class VocabLoadException : public LoadException {
- public:
- virtual ~VocabLoadException() throw();
- VocabLoadException() throw();
-};
-
class FormatLoadException : public LoadException {
public:
FormatLoadException() throw();
~FormatLoadException() throw();
};
+class VocabLoadException : public LoadException {
+ public:
+ virtual ~VocabLoadException() throw();
+ VocabLoadException() throw();
+};
+
class SpecialWordMissingException : public VocabLoadException {
public:
explicit SpecialWordMissingException(StringPiece which) throw();
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 6921d4d9..421e72fa 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -23,6 +23,7 @@ namespace detail {
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+ if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index e0eeee17..53e5773d 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -12,6 +12,8 @@
#include <algorithm>
#include <vector>
+#include <string.h>
+
namespace util { class FilePiece; }
namespace lm {
@@ -21,9 +23,10 @@ namespace ngram {
// Having this limit means that State can be
// (kMaxOrder - 1) * sizeof(float) bytes instead of
// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
-const std::size_t kMaxOrder = 6;
+const unsigned char kMaxOrder = 6;
-// This is a POD.
+// This is a POD but if you want memcmp to return the same as operator==, call
+// ZeroRemaining first.
class State {
public:
bool operator==(const State &other) const {
@@ -37,6 +40,22 @@ class State {
return true;
}
+ // Three-way comparison function.
+ int Compare(const State &other) const {
+ if (valid_length_ == other.valid_length_) {
+ return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex));
+ }
+ return (valid_length_ < other.valid_length_) ? -1 : 1;
+ }
+
+ // Call this before using raw memcmp.
+ void ZeroRemaining() {
+ for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) {
+ history_[i] = 0;
+ backoff_[i] = 0.0;
+ }
+ }
+
// You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
// This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
WordIndex history_[kMaxOrder - 1];
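
Since the new Compare gives a three-way ordering and ZeroRemaining makes equal states bitwise identical, State can serve directly as a key. A brief sketch, with an illustrative comparator name:

    // Sketch: using ngram::State as a map key via the new Compare member.
    // StateCompare is illustrative; Compare/ZeroRemaining come from model.hh above.
    #include "lm/model.hh"

    #include <map>

    struct StateCompare {
      bool operator()(const lm::ngram::State &a, const lm::ngram::State &b) const {
        return a.Compare(b) < 0;
      }
    };

    // Caches a score per context state.
    typedef std::map<lm::ngram::State, float, StateCompare> StateScoreCache;

    // For raw memcmp or hashing of the whole struct, call ZeroRemaining() on the
    // states first, as the header comment recommends.
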
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 159628d4..b5125a95 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -4,6 +4,7 @@
#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
namespace lm {
namespace ngram {
@@ -123,7 +124,7 @@ class ExpectEnumerateVocab : public EnumerateVocab {
}
void Check(const base::Vocabulary &vocab) {
- BOOST_CHECK_EQUAL(34, seen.size());
+ BOOST_CHECK_EQUAL(34ULL, seen.size());
BOOST_REQUIRE(!seen.empty());
BOOST_CHECK_EQUAL("<unk>", seen[0]);
for (WordIndex i = 0; i < seen.size(); ++i) {
@@ -144,6 +145,7 @@ template <class ModelT> void LoadingTest() {
config.messages = NULL;
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
+ config.probing_multiplier = 2.0;
ModelT m("test.arpa", config);
enumerate.Check(m.GetVocabulary());
Starters(m);
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index 8e9a770d..262a9c6a 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -49,7 +49,7 @@ template <class F> void GenericReadNGramHeader(F &in, unsigned int length) {
while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
std::stringstream expected;
expected << '\\' << length << "-grams:";
- if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead. ");
+ if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead");
}
template <class F> void GenericReadEnd(F &in) {
@@ -110,7 +110,7 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
{
float got = in.ReadFloat();
if (got != 0.0)
- UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff.");
+ UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff");
}
break;
case '\n':
diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh
index cabdb195..571fcbc5 100644
--- a/klm/lm/read_arpa.hh
+++ b/klm/lm/read_arpa.hh
@@ -54,7 +54,7 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns
}
ReadBackoff(f, weights);
} catch(util::Exception &e) {
- e << " in the " << n << "-gram at byte " << f.Offset();
+ e << " in the " << static_cast<unsigned int>(n) << "-gram at byte " << f.Offset();
throw;
}
}
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 182e27f5..12294682 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -1,3 +1,4 @@
+/* This is where the trie is built. It's on-disk. */
#include "lm/search_trie.hh"
#include "lm/lm_exception.hh"
@@ -8,6 +9,7 @@
#include "lm/word_index.hh"
#include "util/ersatz_progress.hh"
#include "util/file_piece.hh"
+#include "util/proxy_iterator.hh"
#include "util/scoped.hh"
#include <algorithm>
@@ -30,43 +32,119 @@ namespace ngram {
namespace trie {
namespace {
-template <unsigned char Order> class FullEntry {
+/* An entry is an n-gram with probability. It consists of:
+ * WordIndex[order]
+ * float probability
+ * backoff probability (omitted for highest order n-gram)
+ * These are stored consecutively in memory. We want to sort them.
+ *
+ * The problem is the length depends on order (but all n-grams being compared
+ * have the same order). Allocating each entry on the heap (i.e. std::vector
+ * or std::string) then sorting pointers is the normal solution. But that's
+ * too memory inefficient. A lot of this code is just here to force std::sort
+ * to work with records where length is specified at runtime (and avoid using
+ * Boost for LM code). I could have used qsort, but the point is to also
+ * support __gnu_parallel::sort which doesn't have a qsort version.
+ */
+
+class EntryIterator {
public:
- typedef ProbBackoff Weights;
- static const unsigned char kOrder = Order;
+ EntryIterator() {}
- // reverse order
- WordIndex words[Order];
- Weights weights;
+ EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast<uint8_t*>(ptr)), size_(size) {}
- bool operator<(const FullEntry<Order> &other) const {
- for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) {
- if (*i < *j) return true;
- if (*i > *j) return false;
- }
- return false;
+ bool operator==(const EntryIterator &other) const {
+ return ptr_ == other.ptr_;
+ }
+ bool operator<(const EntryIterator &other) const {
+ return ptr_ < other.ptr_;
+ }
+ EntryIterator &operator+=(std::ptrdiff_t amount) {
+ ptr_ += amount * size_;
+ return *this;
+ }
+ std::ptrdiff_t operator-(const EntryIterator &other) const {
+ return (ptr_ - other.ptr_) / size_;
}
+
+ const void *Data() const { return ptr_; }
+ void *Data() { return ptr_; }
+ std::size_t EntrySize() const { return size_; }
+
+ private:
+ uint8_t *ptr_;
+ std::size_t size_;
};
-template <unsigned char Order> class ProbEntry {
+class EntryProxy {
public:
- typedef Prob Weights;
- static const unsigned char kOrder = Order;
+ EntryProxy() {}
+
+ EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {}
+
+ operator std::string() const {
+ return std::string(reinterpret_cast<const char*>(inner_.Data()), inner_.EntrySize());
+ }
+
+ EntryProxy &operator=(const EntryProxy &from) {
+ memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize());
+ return *this;
+ }
+
+ EntryProxy &operator=(const std::string &from) {
+ memcpy(inner_.Data(), from.data(), inner_.EntrySize());
+ return *this;
+ }
+
+ const WordIndex *Indices() const {
+ return static_cast<const WordIndex*>(inner_.Data());
+ }
+
+ private:
+ friend class util::ProxyIterator<EntryProxy>;
+
+ typedef std::string value_type;
- // reverse order
- WordIndex words[Order];
- Weights weights;
+ typedef EntryIterator InnerIterator;
+ InnerIterator &Inner() { return inner_; }
+ const InnerIterator &Inner() const { return inner_; }
+ InnerIterator inner_;
+};
- bool operator<(const ProbEntry<Order> &other) const {
- for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) {
- if (*i < *j) return true;
- if (*i > *j) return false;
+typedef util::ProxyIterator<EntryProxy> NGramIter;
+
+class CompareRecords : public std::binary_function<const EntryProxy &, const EntryProxy &, bool> {
+ public:
+ explicit CompareRecords(unsigned char order) : order_(order) {}
+
+ bool operator()(const EntryProxy &first, const EntryProxy &second) const {
+ return Compare(first.Indices(), second.Indices());
+ }
+ bool operator()(const EntryProxy &first, const std::string &second) const {
+ return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data()));
+ }
+ bool operator()(const std::string &first, const EntryProxy &second) const {
+ return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices());
+ }
+ bool operator()(const std::string &first, const std::string &second) const {
+ return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data()));
+ }
+
+ private:
+ bool Compare(const WordIndex *first, const WordIndex *second) const {
+ const WordIndex *end = first + order_;
+ for (; first != end; ++first, ++second) {
+ if (*first < *second) return true;
+ if (*first > *second) return false;
}
return false;
}
+
+ unsigned char order_;
};
void WriteOrThrow(FILE *to, const void *data, size_t size) {
+ assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
}
@@ -84,21 +162,24 @@ void CopyOrThrow(FILE *from, FILE *to, size_t size) {
}
}
-template <class Entry> std::string DiskFlush(const Entry *begin, const Entry *end, const std::string &file_prefix, std::size_t batch) {
+std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) {
+ const std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
+ const std::size_t prefix_size = sizeof(WordIndex) * (order - 1);
std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << '_' << batch;
+ assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
std::string ret(assembled.str());
util::scoped_FILE out(fopen(ret.c_str(), "w"));
if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing");
- for (const Entry *group_begin = begin; group_begin != end;) {
- const Entry *group_end = group_begin;
- for (++group_end; (group_end != end) && !memcmp(group_begin->words, group_end->words, sizeof(WordIndex) * (Entry::kOrder - 1)); ++group_end) {}
- WriteOrThrow(out.get(), group_begin->words, sizeof(WordIndex) * (Entry::kOrder - 1));
- WordIndex group_size = group_end - group_begin;
+ // Compress entries that begin with the same (order-1) words.
+ for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) {
+ const uint8_t *group_end = group_begin;
+ for (group_end += entry_size; (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); group_end += entry_size) {}
+ WriteOrThrow(out.get(), group_begin, prefix_size);
+ WordIndex group_size = (group_end - group_begin) / entry_size;
WriteOrThrow(out.get(), &group_size, sizeof(group_size));
- for (const Entry *i = group_begin; i != group_end; ++i) {
- WriteOrThrow(out.get(), &i->words[Entry::kOrder - 1], sizeof(WordIndex));
- WriteOrThrow(out.get(), &i->weights, sizeof(typename Entry::Weights));
+ for (const uint8_t *i = group_begin; i != group_end; i += entry_size) {
+ WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex));
+ WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size);
}
group_begin = group_end;
}
@@ -219,25 +300,37 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha
}
}
-template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix) {
- ConvertToSorted<FullEntry<Entry::kOrder - 1> >(f, vocab, counts, mem, file_prefix);
-
- ReadNGramHeader(f, Entry::kOrder);
- const size_t count = counts[Entry::kOrder - 1];
- const size_t batch_size = std::min(count, mem.size() / sizeof(Entry));
- Entry *const begin = reinterpret_cast<Entry*>(mem.get());
+void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
+ if (order == 1) return;
+ ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1);
+
+ ReadNGramHeader(f, order);
+ const size_t count = counts[order - 1];
+ // Sizes in bytes: the words, then the weights (probability plus backoff, except the highest order stores no backoff).
+ const size_t words_size = sizeof(WordIndex) * order;
+ const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float));
+ const size_t entry_size = words_size + weights_size;
+ const size_t batch_size = std::min(count, mem.size() / entry_size);
+ uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get());
std::deque<std::string> files;
for (std::size_t batch = 0, done = 0; done < count; ++batch) {
- Entry *out = begin;
- Entry *out_end = out + std::min(count - done, batch_size);
- for (; out != out_end; ++out) {
- ReadNGram(f, Entry::kOrder, vocab, out->words, out->weights);
+ uint8_t *out = begin;
+ uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
+ if (order == counts.size()) {
+ for (; out != out_end; out += entry_size) {
+ ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size));
+ }
+ } else {
+ for (; out != out_end; out += entry_size) {
+ ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size));
+ }
}
- //__gnu_parallel::sort(begin, out_end);
- std::sort(begin, out_end);
+ // TODO: __gnu_parallel::sort here.
+ EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
+ std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order));
- files.push_back(DiskFlush(begin, out_end, file_prefix, batch));
- done += out_end - begin;
+ files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size));
+ done += (out_end - begin) / entry_size;
}
// All individual files created. Merge them.
@@ -245,9 +338,9 @@ template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVoca
std::size_t merge_count = 0;
while (files.size() > 1) {
std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merge_" << (merge_count++);
+ assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
files.push_back(assembled.str());
- MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), sizeof(typename Entry::Weights), Entry::kOrder);
+ MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order);
if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
files.pop_front();
if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
@@ -255,14 +348,12 @@ template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVoca
}
if (!files.empty()) {
std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merged";
+ assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";
std::string merged_name(assembled.str());
if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
}
}
-template <> void ConvertToSorted<FullEntry<1> >(util::FilePiece &/*f*/, const SortedVocabulary &/*vocab*/, const std::vector<uint64_t> &/*counts*/, util::scoped_memory &/*mem*/, const std::string &/*file_prefix*/) {}
-
void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
{
std::string unigram_name = file_prefix + "unigrams";
@@ -275,7 +366,7 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts,
util::scoped_memory mem;
mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED);
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
- ConvertToSorted<ProbEntry<5> >(f, vocab, counts, mem, file_prefix);
+ ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size());
ReadEnd(f);
}
@@ -390,7 +481,8 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const
temporary_directory.resize(strlen(temporary_directory.c_str()));
// Add directory delimiter. Assumes a real operating system.
temporary_directory += '/';
- ARPAToSortedFiles(f, counts, config.building_memory, temporary_directory.c_str(), vocab);
+ // At least 1MB sorting memory.
+ ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
BuildTrie(temporary_directory.c_str(), counts, config.messages, *this);
if (rmdir(temporary_directory.c_str())) {
std::cerr << "Failed to delete " << temporary_directory << std::endl;
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 8ed7b2a2..04bd2079 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -15,21 +15,21 @@ namespace {
// Assumes key is first.
class JustKeyProxy {
public:
- JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {}
+ JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {}
operator uint64_t() const { return GetKey(); }
uint64_t GetKey() const {
uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_);
- return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_mask_);
+ return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_);
}
private:
friend class util::ProxyIterator<JustKeyProxy>;
- friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);
+ friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);
- JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits)
- : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), total_bits_(total_bits) {}
+ JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
+ : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
// This is a read-only iterator.
JustKeyProxy &operator=(const JustKeyProxy &other);
@@ -44,12 +44,12 @@ class JustKeyProxy {
uint64_t inner_;
const uint8_t *const base_;
const uint64_t key_mask_;
- const uint8_t total_bits_;
+ const uint8_t key_bits_, total_bits_;
};
-bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
- util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, total_bits));
- util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, total_bits));
+bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
+ util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits));
+ util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits));
util::ProxyIterator<JustKeyProxy> out;
if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false;
at_index = out.Inner();
@@ -96,67 +96,67 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t
assert(next <= next_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
- util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word);
+ util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word);
at_pointer += word_bits_;
util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
at_pointer += prob_bits_;
util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff);
at_pointer += backoff_bits_;
- util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next);
+ util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next);
++insert_index_;
}
bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
at_pointer *= total_bits_;
at_pointer += word_bits_;
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
at_pointer += prob_bits_;
backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
at_pointer += backoff_bits_;
- range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
// Read the next entry's pointer.
at_pointer += total_bits_;
- range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
return true;
}
bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
at_pointer *= total_bits_;
at_pointer += word_bits_;
at_pointer += prob_bits_;
backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
at_pointer += backoff_bits_;
- range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
// Read the next entry's pointer.
at_pointer += total_bits_;
- range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
return true;
}
void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
assert(next_end <= next_mask_);
uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
- util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_end);
+ util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end);
}
void BitPackedLongest::Insert(WordIndex index, float prob) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
- util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, index);
+ util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index);
at_pointer += word_bits_;
util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
++insert_index_;
}
-bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const {
+bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
at_pointer = at_pointer * total_bits_ + word_bits_;
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
return true;
diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc
index dd14ffe1..9d4fdf27 100644
--- a/klm/util/bit_packing.cc
+++ b/klm/util/bit_packing.cc
@@ -1,12 +1,15 @@
#include "util/bit_packing.hh"
#include "util/exception.hh"
+#include <string.h>
+
namespace util {
namespace {
template <bool> struct StaticCheck {};
template <> struct StaticCheck<true> { typedef bool StaticAssertionPassed; };
+// If your float isn't 4 bytes, we're hosed.
typedef StaticCheck<sizeof(float) == 4>::StaticAssertionPassed FloatSize;
} // namespace
@@ -21,6 +24,16 @@ uint8_t RequiredBits(uint64_t max_value) {
void BitPackingSanity() {
const detail::FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 };
if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000");
+ char mem[57+8];
+ memset(mem, 0, sizeof(mem));
+ const uint64_t test57 = 0x123456789abcdefULL;
+ for (uint64_t b = 0; b < 57 * 8; b += 57) {
+ WriteInt57(mem + b / 8, b % 8, 57, test57);
+ }
+ for (uint64_t b = 0; b < 57 * 8; b += 57) {
+ if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1))
+ UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler.");
+ }
// TODO: more checks.
}
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh
index 422ed873..0fd39d7f 100644
--- a/klm/util/bit_packing.hh
+++ b/klm/util/bit_packing.hh
@@ -6,56 +6,68 @@
#include <assert.h>
#ifdef __APPLE__
#include <architecture/byte_order.h>
-#else
+#elif __linux__
#include <endian.h>
-#endif
+#else
+#include <arpa/nameser_compat.h>
+#endif
#include <inttypes.h>
-#if __BYTE_ORDER != __LITTLE_ENDIAN
-#error The bit aligned storage functions assume little endian architecture
-#endif
-
namespace util {
/* WARNING WARNING WARNING:
* The write functions assume that memory is zero initially. This makes them
* faster and is the appropriate case for mmapped language model construction.
* These routines assume that unaligned access to uint64_t is fast and that
- * storage is little endian. This is the case on x86_64. It may not be the
- * case on 32-bit x86 but my target audience is large language models for which
- * 64-bit is necessary.
+ * storage is little endian. This is the case on x86_64. I'm not sure how
+ * fast unaligned 64-bit access is on x86 but my target audience is large
+ * language models for which 64-bit is necessary.
+ *
+ * Call the BitPackingSanity function to sanity check. Calling once suffices,
+ * but it may be called multiple times when that's inconvenient.
*/
+inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {
+// Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct.
+#if BYTE_ORDER == LITTLE_ENDIAN
+ return bit;
+#elif BYTE_ORDER == BIG_ENDIAN
+ return 64 - length - bit;
+#else
+#error "Bit packing code isn't written for your byte order."
+#endif
+}
+
/* Pack integers up to 57 bits using their least significant bits.
* The length is specified using mask:
* Assumes mask == (1 << length) - 1 where length <= 57.
*/
-inline uint64_t ReadInt57(const void *base, uint8_t bit, uint64_t mask) {
- return (*reinterpret_cast<const uint64_t*>(base) >> bit) & mask;
+inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) {
+ return (*reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, length)) & mask;
}
-/* Assumes value <= mask and mask == (1 << length) - 1 where length <= 57.
+/* Assumes value < (1 << length) and length <= 57.
* Assumes the memory is zero initially.
*/
-inline void WriteInt57(void *base, uint8_t bit, uint64_t value) {
- *reinterpret_cast<uint64_t*>(base) |= (value << bit);
+inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) {
+ *reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length));
}
namespace detail { typedef union { float f; uint32_t i; } FloatEnc; }
inline float ReadFloat32(const void *base, uint8_t bit) {
detail::FloatEnc encoded;
- encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit;
+ encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32);
return encoded.f;
}
inline void WriteFloat32(void *base, uint8_t bit, float value) {
detail::FloatEnc encoded;
encoded.f = value;
- WriteInt57(base, bit, encoded.i);
+ WriteInt57(base, bit, 32, encoded.i);
}
inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) {
detail::FloatEnc encoded;
- encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit;
+ encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31);
// Sign bit set means negative.
encoded.i |= 0x80000000;
return encoded.f;
@@ -65,7 +77,7 @@ inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) {
detail::FloatEnc encoded;
encoded.f = value;
encoded.i &= ~0x80000000;
- WriteInt57(base, bit, encoded.i);
+ WriteInt57(base, bit, 31, encoded.i);
}
void BitPackingSanity();
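
The widths now passed explicitly let BitPackShift place fields correctly on either byte order; packing one middle-trie record at an arbitrary bit offset then looks like the following sketch. The write functions and their signatures are the ones declared above; the wrapper itself and its field widths mirror BitPackedMiddle::Insert in trie.cc and are illustrative.

    // Sketch: write word/prob/backoff/next starting at bit offset `at` into
    // zero-initialized memory (the write functions assume zeroed memory).
    #include "util/bit_packing.hh"

    #include <inttypes.h>

    inline void PackMiddleRecord(uint8_t *base, uint64_t at,
                                 uint8_t word_bits, uint64_t word,
                                 float prob, float backoff,
                                 uint8_t next_bits, uint64_t next) {
      util::WriteInt57(base + (at >> 3), at & 7, word_bits, word);
      at += word_bits;
      util::WriteNonPositiveFloat31(base + (at >> 3), at & 7, prob);
      at += 31;  // prob is stored in 31 bits; its sign bit is implicit.
      util::WriteFloat32(base + (at >> 3), at & 7, backoff);
      at += 32;
      util::WriteInt57(base + (at >> 3), at & 7, next_bits, next);
    }
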
diff --git a/klm/util/bit_packing_test.cc b/klm/util/bit_packing_test.cc
new file mode 100644
index 00000000..c578ddd1
--- /dev/null
+++ b/klm/util/bit_packing_test.cc
@@ -0,0 +1,46 @@
+#include "util/bit_packing.hh"
+
+#define BOOST_TEST_MODULE BitPackingTest
+#include <boost/test/unit_test.hpp>
+
+#include <string.h>
+
+namespace util {
+namespace {
+
+const uint64_t test57 = 0x123456789abcdefULL;
+
+BOOST_AUTO_TEST_CASE(ZeroBit) {
+ char mem[16];
+ memset(mem, 0, sizeof(mem));
+ WriteInt57(mem, 0, 57, test57);
+ BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1));
+}
+
+BOOST_AUTO_TEST_CASE(EachBit) {
+ char mem[16];
+ for (uint8_t b = 0; b < 8; ++b) {
+ memset(mem, 0, sizeof(mem));
+ WriteInt57(mem, b, 57, test57);
+ BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1));
+ }
+}
+
+BOOST_AUTO_TEST_CASE(Consecutive) {
+ char mem[57+8];
+ memset(mem, 0, sizeof(mem));
+ for (uint64_t b = 0; b < 57 * 8; b += 57) {
+ WriteInt57(mem + (b / 8), b % 8, 57, test57);
+ BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1));
+ }
+ for (uint64_t b = 0; b < 57 * 8; b += 57) {
+ BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1));
+ }
+}
+
+BOOST_AUTO_TEST_CASE(Sanity) {
+ BitPackingSanity();
+}
+
+} // namespace
+} // namespace util
diff --git a/klm/util/exception.cc b/klm/util/exception.cc
index dd337a76..de6dd43c 100644
--- a/klm/util/exception.cc
+++ b/klm/util/exception.cc
@@ -24,7 +24,12 @@ const char *HandleStrerror(const char *ret, const char *buf) {
ErrnoException::ErrnoException() throw() : errno_(errno) {
char buf[200];
buf[0] = 0;
+#ifdef sun
+ const char *add = strerror(errno);
+#else
const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf);
+#endif
+
if (add) {
*this << add << ' ';
}
diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc
index e7bd8659..5a667ebb 100644
--- a/klm/util/file_piece.cc
+++ b/klm/util/file_piece.cc
@@ -2,12 +2,12 @@
#include "util/exception.hh"
+#include <iostream>
#include <string>
#include <limits>
#include <assert.h>
#include <ctype.h>
-#include <err.h>
#include <fcntl.h>
#include <stdlib.h>
#include <sys/mman.h>
@@ -27,7 +27,7 @@ EndOfFileException::EndOfFileException() throw() {
EndOfFileException::~EndOfFileException() throw() {}
ParseNumberException::ParseNumberException(StringPiece value) throw() {
- *this << "Could not parse \"" << value << "\" into a float";
+ *this << "Could not parse \"" << value << "\" into a number";
}
GZException::GZException(void *file) {
@@ -68,12 +68,52 @@ FilePiece::~FilePiece() {
file_.release();
int ret;
if (Z_OK != (ret = gzclose(gz_file_))) {
- errx(1, "could not close file %s using zlib", file_name_.c_str());
+ std::cerr << "could not close file " << file_name_ << " using zlib" << std::endl;
+ abort();
}
}
#endif
}
+StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) {
+ const char *start = position_;
+ do {
+ for (const char *i = start; i < position_end_; ++i) {
+ if (*i == delim) {
+ StringPiece ret(position_, i - position_);
+ position_ = i + 1;
+ return ret;
+ }
+ }
+ size_t skip = position_end_ - position_;
+ Shift();
+ start = position_ + skip;
+ } while (!at_end_);
+ StringPiece ret(position_, position_end_ - position_);
+ position_ = position_end_;
+ return ret;
+}
+
+float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) {
+ return ReadNumber<float>();
+}
+double FilePiece::ReadDouble() throw(GZException, EndOfFileException, ParseNumberException) {
+ return ReadNumber<double>();
+}
+long int FilePiece::ReadLong() throw(GZException, EndOfFileException, ParseNumberException) {
+ return ReadNumber<long int>();
+}
+unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException, ParseNumberException) {
+ return ReadNumber<unsigned long int>();
+}
+
+void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) {
+ for (; ; ++position_) {
+ if (position_ == position_end_) Shift();
+ if (!isspace(*position_)) return;
+ }
+}
+
void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) {
#ifdef HAVE_ZLIB
gz_file_ = NULL;
@@ -108,14 +148,34 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t
}
}
-float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) {
+namespace {
+void ParseNumber(const char *begin, char *&end, float &out) {
+#ifdef sun
+ out = static_cast<float>(strtod(begin, &end));
+#else
+ out = strtof(begin, &end);
+#endif
+}
+void ParseNumber(const char *begin, char *&end, double &out) {
+ out = strtod(begin, &end);
+}
+void ParseNumber(const char *begin, char *&end, long int &out) {
+ out = strtol(begin, &end, 10);
+}
+void ParseNumber(const char *begin, char *&end, unsigned long int &out) {
+ out = strtoul(begin, &end, 10);
+}
+} // namespace
+
+template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileException, ParseNumberException) {
SkipSpaces();
while (last_space_ < position_) {
if (at_end_) {
// Hallucinate a null off the end of the file.
std::string buffer(position_, position_end_);
char *end;
- float ret = strtof(buffer.c_str(), &end);
+ T ret;
+ ParseNumber(buffer.c_str(), end, ret);
if (buffer.c_str() == end) throw ParseNumberException(buffer);
position_ += end - buffer.c_str();
return ret;
@@ -123,19 +183,13 @@ float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberE
Shift();
}
char *end;
- float ret = strtof(position_, &end);
+ T ret;
+ ParseNumber(position_, end, ret);
if (end == position_) throw ParseNumberException(ReadDelimited());
position_ = end;
return ret;
}
-void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) {
- for (; ; ++position_) {
- if (position_ == position_end_) Shift();
- if (!isspace(*position_)) return;
- }
-}
-
const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileException) {
for (const char *i = position_; i <= last_space_; ++i) {
if (isspace(*i)) return i;
@@ -150,25 +204,6 @@ const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileExcepti
return position_end_;
}
-StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) {
- const char *start = position_;
- do {
- for (const char *i = start; i < position_end_; ++i) {
- if (*i == delim) {
- StringPiece ret(position_, i - position_);
- position_ = i + 1;
- return ret;
- }
- }
- size_t skip = position_end_ - position_;
- Shift();
- start = position_ + skip;
- } while (!at_end_);
- StringPiece ret(position_, position_end_ - position_);
- position_ = position_end_;
- return ret;
-}
-
void FilePiece::Shift() throw(GZException, EndOfFileException) {
if (at_end_) {
progress_.Finished();
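
ReadFloat, ReadDouble, ReadLong and ReadULong now all funnel into one ReadNumber<T> template, which picks the right strtof/strtod/strtol/strtoul through plain function overloading of ParseNumber. The same pattern in isolation, with illustrative names:

    // Sketch of the overload-dispatch pattern behind FilePiece::ReadNumber<T>;
    // ParseOrThrow and these overloads are illustrative, not kenlm code.
    #include <stdlib.h>

    #include <stdexcept>
    #include <string>

    namespace {
    void ParseNumber(const char *begin, char *&end, double &out) { out = strtod(begin, &end); }
    void ParseNumber(const char *begin, char *&end, long int &out) { out = strtol(begin, &end, 10); }
    } // namespace

    // One template serves every type that has a ParseNumber overload.
    template <class T> T ParseOrThrow(const std::string &text) {
      char *end;
      T ret;
      ParseNumber(text.c_str(), end, ret);
      if (end == text.c_str()) throw std::runtime_error("Could not parse \"" + text + "\"");
      return ret;
    }

    // Usage:
    //   double d = ParseOrThrow<double>("3.5");
    //   long int n = ParseOrThrow<long int>("42");
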
diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh
index 11d4a751..b7697e71 100644
--- a/klm/util/file_piece.hh
+++ b/klm/util/file_piece.hh
@@ -68,6 +68,9 @@ class FilePiece {
StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException);
float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException);
+ double ReadDouble() throw(GZException, EndOfFileException, ParseNumberException);
+ long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException);
+ unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException);
void SkipSpaces() throw (GZException, EndOfFileException);
@@ -80,6 +83,8 @@ class FilePiece {
private:
void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException);
+ template <class T> T ReadNumber() throw(GZException, EndOfFileException, ParseNumberException);
+
StringPiece Consume(const char *to) {
StringPiece ret(position_, to - position_);
position_ = to;
diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc
index 23e79fe0..dc9ec7e7 100644
--- a/klm/util/file_piece_test.cc
+++ b/klm/util/file_piece_test.cc
@@ -8,6 +8,8 @@
#include <iostream>
#include <stdio.h>
+#include <sys/types.h>
+#include <sys/stat.h>
namespace util {
namespace {
@@ -27,14 +29,18 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) {
BOOST_CHECK_THROW(test.get(), EndOfFileException);
}
+#ifndef __APPLE__
+/* Apple isn't happy with popen, fileno, dup. And I don't want to
+ * reimplement popen. This is an issue with the test.
+ */
/* read() implementation */
BOOST_AUTO_TEST_CASE(StreamReadLine) {
std::fstream ref("file_piece.cc", std::ios::in);
- scoped_FILE catter(popen("cat file_piece.cc", "r"));
- BOOST_REQUIRE(catter.get());
+ FILE *catter = popen("cat file_piece.cc", "r");
+ BOOST_REQUIRE(catter);
- FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1);
+ FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
StringPiece test_line(test.ReadLine());
@@ -44,7 +50,9 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) {
}
}
BOOST_CHECK_THROW(test.get(), EndOfFileException);
+ BOOST_REQUIRE(!pclose(catter));
}
+#endif // __APPLE__
#ifdef HAVE_ZLIB
@@ -64,14 +72,17 @@ BOOST_AUTO_TEST_CASE(PlainZipReadLine) {
}
BOOST_CHECK_THROW(test.get(), EndOfFileException);
}
-// gzip stream
+
+// gzip stream. Apple doesn't like popen, fileno, dup. This is an issue with
+// the test.
+#ifndef __APPLE__
BOOST_AUTO_TEST_CASE(StreamZipReadLine) {
std::fstream ref("file_piece.cc", std::ios::in);
- scoped_FILE catter(popen("gzip <file_piece.cc", "r"));
- BOOST_REQUIRE(catter.get());
+ FILE * catter = popen("gzip <file_piece.cc", "r");
+ BOOST_REQUIRE(catter);
- FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1);
+ FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
StringPiece test_line(test.ReadLine());
@@ -81,9 +92,11 @@ BOOST_AUTO_TEST_CASE(StreamZipReadLine) {
}
}
BOOST_CHECK_THROW(test.get(), EndOfFileException);
+ BOOST_REQUIRE(!pclose(catter));
}
+#endif // __APPLE__
-#endif
+#endif // HAVE_ZLIB
} // namespace
} // namespace util
diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc
index 8685170f..5a810c64 100644
--- a/klm/util/mmap.cc
+++ b/klm/util/mmap.cc
@@ -2,8 +2,9 @@
#include "util/mmap.hh"
#include "util/scoped.hh"
+#include <iostream>
+
#include <assert.h>
-#include <err.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/mman.h>
@@ -14,8 +15,10 @@ namespace util {
scoped_mmap::~scoped_mmap() {
if (data_ != (void*)-1) {
- if (munmap(data_, size_))
- err(1, "munmap failed ");
+ if (munmap(data_, size_)) {
+ std::cerr << "munmap failed for " << size_ << " bytes." << std::endl;
+ abort();
+ }
}
}
@@ -73,18 +76,27 @@ void ReadAll(int fd, void *to_void, std::size_t amount) {
to += ret;
}
}
+
+const int kFileFlags =
+#ifdef MAP_FILE
+ MAP_FILE | MAP_SHARED
+#else
+ MAP_SHARED
+#endif
+ ;
+
} // namespace
void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out) {
switch (method) {
case LAZY:
- out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, false, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
+ out.reset(MapOrThrow(size, false, kFileFlags, false, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
break;
case POPULATE_OR_LAZY:
#ifdef MAP_POPULATE
case POPULATE_OR_READ:
#endif
- out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, true, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
+ out.reset(MapOrThrow(size, false, kFileFlags, true, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
break;
#ifndef MAP_POPULATE
case POPULATE_OR_READ:
@@ -115,7 +127,7 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) {
if (-1 == ftruncate(file.get(), size))
UTIL_THROW(ErrnoException, "ftruncate on " << name << " to " << size << " failed");
try {
- return MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0);
+ return MapOrThrow(size, true, kFileFlags, false, file.get(), 0);
} catch (ErrnoException &e) {
e << " in file " << name;
throw;
diff --git a/klm/util/murmur_hash.hh b/klm/util/murmur_hash.hh
index 638aaeb2..78fe583f 100644
--- a/klm/util/murmur_hash.hh
+++ b/klm/util/murmur_hash.hh
@@ -1,7 +1,7 @@
#ifndef UTIL_MURMUR_HASH__
#define UTIL_MURMUR_HASH__
#include <cstddef>
-#include <stdint.h>
+#include <inttypes.h>
namespace util {
diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc
index 2c6d5394..a4cc5016 100644
--- a/klm/util/scoped.cc
+++ b/klm/util/scoped.cc
@@ -1,16 +1,24 @@
#include "util/scoped.hh"
-#include <err.h>
+#include <iostream>
+
+#include <stdlib.h>
#include <unistd.h>
namespace util {
scoped_fd::~scoped_fd() {
- if (fd_ != -1 && close(fd_)) err(1, "Could not close file %i", fd_);
+ if (fd_ != -1 && close(fd_)) {
+ std::cerr << "Could not close file " << fd_ << std::endl;
+ abort();
+ }
}
scoped_FILE::~scoped_FILE() {
- if (file_ && fclose(file_)) err(1, "Could not close file");
+ if (file_ && fclose(file_)) {
+ std::cerr << "Could not close file " << std::endl;
+ abort();
+ }
}
} // namespace util
diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh
index 4557173b..3ac2f8a7 100644
--- a/klm/util/string_piece.hh
+++ b/klm/util/string_piece.hh
@@ -51,7 +51,7 @@
//Uncomment this line if you use ICU in your code.
//#define HAVE_ICU
//Uncomment this line if you want boost hashing for your StringPieces.
-#define HAVE_BOOST
+//#define HAVE_BOOST
#include <cstring>
#include <iosfwd>