summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2013-04-24 10:12:41 +0100
committerKenneth Heafield <github@kheafield.com>2013-04-24 10:12:41 +0100
commit5aee54869aa19cfe9be965e67a472e94449d16da (patch)
tree8f2111e58c4cc4108b8daabb8c38f16adcaff2f0 /klm
parent9957c8a43354fe0a81b83659de965d6d0934adf8 (diff)
KenLM 0831569c3137536165b107c6841603c725dfa2b1
Diffstat (limited to 'klm')
-rw-r--r--klm/lm/builder/corpus_count.cc82
-rw-r--r--klm/lm/builder/corpus_count.hh5
-rw-r--r--klm/lm/builder/corpus_count_test.cc2
-rw-r--r--klm/lm/builder/lmplz_main.cc17
-rw-r--r--klm/lm/builder/pipeline.cc7
-rw-r--r--klm/lm/builder/pipeline.hh9
-rw-r--r--klm/lm/builder/print.cc74
-rw-r--r--klm/lm/builder/print.hh3
-rw-r--r--klm/lm/filter/filter_main.cc4
-rw-r--r--klm/lm/kenlm_max_order_main.cc6
-rw-r--r--klm/lm/query_main.cc1
-rw-r--r--klm/util/fake_ofstream.hh94
-rw-r--r--klm/util/file.cc37
-rw-r--r--klm/util/file_piece.cc32
-rw-r--r--klm/util/file_piece.hh5
-rw-r--r--klm/util/mmap.cc14
-rw-r--r--klm/util/probing_hash_table.hh92
-rw-r--r--klm/util/probing_hash_table_test.cc52
-rw-r--r--klm/util/read_compressed.cc100
-rw-r--r--klm/util/scoped.cc28
-rw-r--r--klm/util/scoped.hh1
-rw-r--r--klm/util/sized_iterator.hh8
-rw-r--r--klm/util/usage.cc12
23 files changed, 484 insertions, 201 deletions
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
index abea4ed0..aea93ad1 100644
--- a/klm/lm/builder/corpus_count.cc
+++ b/klm/lm/builder/corpus_count.cc
@@ -3,6 +3,7 @@
#include "lm/builder/ngram.hh"
#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
+#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/murmur_hash.hh"
@@ -23,39 +24,71 @@ namespace lm {
namespace builder {
namespace {
+#pragma pack(push)
+#pragma pack(4)
+struct VocabEntry {
+ typedef uint64_t Key;
+
+ uint64_t GetKey() const { return key; }
+ void SetKey(uint64_t to) { key = to; }
+
+ uint64_t key;
+ lm::WordIndex value;
+};
+#pragma pack(pop)
+
+const float kProbingMultiplier = 1.5;
+
class VocabHandout {
public:
- explicit VocabHandout(int fd) {
- util::scoped_fd duped(util::DupOrThrow(fd));
- word_list_.reset(util::FDOpenOrThrow(duped));
-
+ static std::size_t MemUsage(WordIndex initial_guess) {
+ if (initial_guess < 2) initial_guess = 2;
+ return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
+ }
+
+ explicit VocabHandout(int fd, WordIndex initial_guess) :
+ table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
+ table_(table_backing_.get(), MemUsage(initial_guess)),
+ double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)),
+ word_list_(fd) {
Lookup("<unk>"); // Force 0
Lookup("<s>"); // Force 1
Lookup("</s>"); // Force 2
}
WordIndex Lookup(const StringPiece &word) {
- uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
- std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
- if (ret.second) {
- char null_delimit = 0;
- util::WriteOrThrow(word_list_.get(), word.data(), word.size());
- util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
- UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ VocabEntry entry;
+ entry.key = util::MurmurHashNative(word.data(), word.size());
+ entry.value = table_.SizeNoSerialization();
+
+ Table::MutableIterator it;
+ if (table_.FindOrInsert(entry, it))
+ return it->value;
+ word_list_ << word << '\0';
+ UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ if (Size() >= double_cutoff_) {
+ table_backing_.call_realloc(table_.DoubleTo());
+ table_.Double(table_backing_.get());
+ double_cutoff_ *= 2;
}
- return ret.first->second;
+ return entry.value;
}
WordIndex Size() const {
- return seen_.size();
+ return table_.SizeNoSerialization();
}
private:
- typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+ // TODO: factor out a resizable probing hash table.
+ // TODO: use mremap on linux to get all zeros on resizes.
+ util::scoped_malloc table_backing_;
- Seen seen_;
+ typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
+ Table table_;
- util::scoped_FILE word_list_;
+ std::size_t double_cutoff_;
+
+ util::FakeOFStream word_list_;
};
class DedupeHash : public std::unary_function<const WordIndex *, bool> {
@@ -85,6 +118,7 @@ class DedupeEquals : public std::binary_function<const WordIndex *, const WordIn
struct DedupeEntry {
typedef WordIndex *Key;
Key GetKey() const { return key; }
+ void SetKey(WordIndex *to) { key = to; }
Key key;
static DedupeEntry Construct(WordIndex *at) {
DedupeEntry ret;
@@ -95,8 +129,6 @@ struct DedupeEntry {
typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
-const float kProbingMultiplier = 1.5;
-
class Writer {
public:
Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
@@ -105,7 +137,7 @@ class Writer {
dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
buffer_(new WordIndex[order - 1]),
block_size_(position.GetChain().BlockSize()) {
- dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ dedupe_.Clear();
assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
if (order == 1) {
// Add special words. AdjustCounts is responsible if order != 1.
@@ -149,7 +181,7 @@ class Writer {
}
// Block end. Need to store the context in a temporary buffer.
std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
- dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ dedupe_.Clear();
block_->SetValidSize(block_size_);
gram_.ReBase((++block_)->Get());
std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
@@ -187,18 +219,22 @@ float CorpusCount::DedupeMultiplier(std::size_t order) {
return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
}
+std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
+ return VocabHandout::MemUsage(vocab_estimate);
+}
+
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
- token_count_ = 0;
- type_count_ = 0;
}
void CorpusCount::Run(const util::stream::ChainPosition &position) {
UTIL_TIMER("(%w s) Counted n-grams\n");
- VocabHandout vocab(vocab_write_);
+ VocabHandout vocab(vocab_write_, type_count_);
+ token_count_ = 0;
+ type_count_ = 0;
const WordIndex end_sentence = vocab.Lookup("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh
index e255bad1..aa0ed8ed 100644
--- a/klm/lm/builder/corpus_count.hh
+++ b/klm/lm/builder/corpus_count.hh
@@ -23,6 +23,11 @@ class CorpusCount {
// Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size
static float DedupeMultiplier(std::size_t order);
+ // How much memory vocabulary will use based on estimated size of the vocab.
+ static std::size_t VocabUsage(std::size_t vocab_estimate);
+
+ // token_count: out.
+ // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value.
CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
void Run(const util::stream::ChainPosition &position);
diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc
index 8d53ca9d..6d325ef5 100644
--- a/klm/lm/builder/corpus_count_test.cc
+++ b/klm/lm/builder/corpus_count_test.cc
@@ -44,7 +44,7 @@ BOOST_AUTO_TEST_CASE(Short) {
util::stream::Chain chain(config);
NGramStream stream;
uint64_t token_count;
- WordIndex type_count;
+ WordIndex type_count = 10;
CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
index 90b9dca2..1e086dcc 100644
--- a/klm/lm/builder/lmplz_main.cc
+++ b/klm/lm/builder/lmplz_main.cc
@@ -6,6 +6,7 @@
#include <iostream>
#include <boost/program_options.hpp>
+#include <boost/version.hpp>
namespace {
class SizeNotify {
@@ -33,13 +34,17 @@ int main(int argc, char *argv[]) {
lm::builder::PipelineConfig pipeline;
options.add_options()
- ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model")
+ ("order,o", po::value<std::size_t>(&pipeline.order)
+#if BOOST_VERSION >= 104200
+ ->required()
+#endif
+ , "Order of the model")
("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
- ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+ ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
@@ -68,6 +73,14 @@ int main(int argc, char *argv[]) {
po::store(po::parse_command_line(argc, argv, options), vm);
po::notify(vm);
+ // required() appeared in Boost 1.42.0.
+#if BOOST_VERSION < 104200
+ if (!vm.count("order")) {
+ std::cerr << "the option '--order' is required but missing" << std::endl;
+ return 1;
+ }
+#endif
+
util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc
index 14a1f721..b89ea6ba 100644
--- a/klm/lm/builder/pipeline.cc
+++ b/klm/lm/builder/pipeline.cc
@@ -207,17 +207,18 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
const PipelineConfig &config = master.Config();
std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
- UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory());
+ const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate);
+ UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory());
std::size_t memory_for_chain =
// This much memory to work with after vocab hash table.
- static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) /
+ static_cast<float>(config.TotalMemory() - vocab_usage) /
// Solve for block size including the dedupe multiplier for one block.
(static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
// Chain likes memory expressed in terms of total memory.
static_cast<float>(config.block_count);
util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
- WordIndex type_count;
+ WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh
index f1d6c5f6..845e5481 100644
--- a/klm/lm/builder/pipeline.hh
+++ b/klm/lm/builder/pipeline.hh
@@ -3,6 +3,7 @@
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/header_info.hh"
+#include "lm/word_index.hh"
#include "util/stream/config.hh"
#include "util/file_piece.hh"
@@ -19,9 +20,9 @@ struct PipelineConfig {
util::stream::ChainConfig read_backoffs;
bool verbose_header;
- // Amount of memory to assume that the vocabulary hash table will use. This
- // is subtracted from total memory for CorpusCount.
- std::size_t assume_vocab_hash_size;
+ // Estimated vocabulary size. Used for sizing CorpusCount memory and
+ // initial probing hash table sizing, also in CorpusCount.
+ lm::WordIndex vocab_estimate;
// Minimum block size to tolerate.
std::size_t minimum_block;
@@ -33,7 +34,7 @@ struct PipelineConfig {
std::size_t TotalMemory() const { return sort.total_memory; }
};
-// Takes ownership of text_file.
+// Takes ownership of text_file and out_arpa.
void Pipeline(PipelineConfig config, int text_file, int out_arpa);
}} // namespaces
diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc
index b0323221..84bd81ca 100644
--- a/klm/lm/builder/print.cc
+++ b/klm/lm/builder/print.cc
@@ -1,15 +1,11 @@
#include "lm/builder/print.hh"
-#include "util/double-conversion/double-conversion.h"
-#include "util/double-conversion/utils.h"
+#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/mmap.hh"
#include "util/scoped.hh"
#include "util/stream/timer.hh"
-#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
-#include <boost/lexical_cast.hpp>
-
#include <sstream>
#include <string.h>
@@ -28,71 +24,6 @@ VocabReconstitute::VocabReconstitute(int fd) {
map_.push_back(i);
}
-namespace {
-class OutputManager {
- public:
- static const std::size_t kOutBuf = 1048576;
-
- // Does not take ownership of out.
- explicit OutputManager(int out)
- : buf_(util::MallocOrThrow(kOutBuf)),
- builder_(static_cast<char*>(buf_.get()), kOutBuf),
- // Mostly the default but with inf instead. And no flags.
- convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
- fd_(out) {}
-
- ~OutputManager() {
- Flush();
- }
-
- OutputManager &operator<<(float value) {
- // Odd, but this is the largest number found in the comments.
- EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
- convert_.ToShortestSingle(value, &builder_);
- return *this;
- }
-
- OutputManager &operator<<(StringPiece str) {
- if (str.size() > kOutBuf) {
- Flush();
- util::WriteOrThrow(fd_, str.data(), str.size());
- } else {
- EnsureRemaining(str.size());
- builder_.AddSubstring(str.data(), str.size());
- }
- return *this;
- }
-
- // Inefficient!
- OutputManager &operator<<(unsigned val) {
- return *this << boost::lexical_cast<std::string>(val);
- }
-
- OutputManager &operator<<(char c) {
- EnsureRemaining(1);
- builder_.AddCharacter(c);
- return *this;
- }
-
- void Flush() {
- util::WriteOrThrow(fd_, buf_.get(), builder_.position());
- builder_.Reset();
- }
-
- private:
- void EnsureRemaining(std::size_t amount) {
- if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) {
- Flush();
- }
- }
-
- util::scoped_malloc buf_;
- double_conversion::StringBuilder builder_;
- double_conversion::DoubleToStringConverter convert_;
- int fd_;
-};
-} // namespace
-
PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
: vocab_(vocab), out_fd_(out_fd) {
std::stringstream stream;
@@ -112,8 +43,9 @@ PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t>
}
void PrintARPA::Run(const ChainPositions &positions) {
+ util::scoped_fd closer(out_fd_);
UTIL_TIMER("(%w s) Wrote ARPA file\n");
- OutputManager out(out_fd_);
+ util::FakeOFStream out(out_fd_);
for (unsigned order = 1; order <= positions.size(); ++order) {
out << "\\" << order << "-grams:" << '\n';
for (NGramStream stream(positions[order - 1]); stream; ++stream) {
diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh
index aa932e75..adbbb94a 100644
--- a/klm/lm/builder/print.hh
+++ b/klm/lm/builder/print.hh
@@ -88,7 +88,8 @@ template <class V> class Print {
class PrintARPA {
public:
- // header_info may be NULL to disable the header
+ // header_info may be NULL to disable the header.
+ // Takes ownership of out_fd upon Run().
explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);
void Run(const ChainPositions &positions);
diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc
index 1a4ba84f..1736bc40 100644
--- a/klm/lm/filter/filter_main.cc
+++ b/klm/lm/filter/filter_main.cc
@@ -25,8 +25,8 @@ void DisplayHelp(const char *name) {
" parser.\n"
"single mode treats the entire input as a single sentence.\n"
"multiple mode filters to multiple sentences in parallel. Each sentence is on\n"
- " a separate line. A separate file is created for each file by appending the\n"
- " 0-indexed line number to the output file name.\n"
+ " a separate line. A separate file is created for each sentence by appending\n"
+ " the 0-indexed line number to the output file name.\n"
"union mode produces one filtered model that is the union of models created by\n"
" multiple mode.\n\n"
"context means only the context (all but last word) has to pass the filter, but\n"
diff --git a/klm/lm/kenlm_max_order_main.cc b/klm/lm/kenlm_max_order_main.cc
deleted file mode 100644
index 94221201..00000000
--- a/klm/lm/kenlm_max_order_main.cc
+++ /dev/null
@@ -1,6 +0,0 @@
-#include "lm/max_order.hh"
-#include <iostream>
-
-int main(int argc, char *argv[]) {
- std::cerr << "KenLM was compiled with a maximum supported n-gram order set to " << KENLM_MAX_ORDER << "." << std::endl;
-}
diff --git a/klm/lm/query_main.cc b/klm/lm/query_main.cc
index 49757d9a..27d3a1a5 100644
--- a/klm/lm/query_main.cc
+++ b/klm/lm/query_main.cc
@@ -2,6 +2,7 @@
int main(int argc, char *argv[]) {
if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) {
+ std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl;
std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl;
std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl;
return 1;
diff --git a/klm/util/fake_ofstream.hh b/klm/util/fake_ofstream.hh
new file mode 100644
index 00000000..bcdebe45
--- /dev/null
+++ b/klm/util/fake_ofstream.hh
@@ -0,0 +1,94 @@
+/* Like std::ofstream but without being incredibly slow. Backed by a raw fd.
+ * Does not support many data types. Currently, it's targeted at writing ARPA
+ * files quickly.
+ */
+#include "util/double-conversion/double-conversion.h"
+#include "util/double-conversion/utils.h"
+#include "util/file.hh"
+#include "util/scoped.hh"
+#include "util/string_piece.hh"
+
+#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
+#include <boost/lexical_cast.hpp>
+
+namespace util {
+class FakeOFStream {
+ public:
+ static const std::size_t kOutBuf = 1048576;
+
+ // Does not take ownership of out.
+ explicit FakeOFStream(int out)
+ : buf_(util::MallocOrThrow(kOutBuf)),
+ builder_(static_cast<char*>(buf_.get()), kOutBuf),
+ // Mostly the default but with inf instead. And no flags.
+ convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
+ fd_(out) {}
+
+ ~FakeOFStream() {
+ if (buf_.get()) Flush();
+ }
+
+ FakeOFStream &operator<<(float value) {
+ // Odd, but this is the largest number found in the comments.
+ EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
+ convert_.ToShortestSingle(value, &builder_);
+ return *this;
+ }
+
+ FakeOFStream &operator<<(double value) {
+ EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
+ convert_.ToShortest(value, &builder_);
+ return *this;
+ }
+
+ FakeOFStream &operator<<(StringPiece str) {
+ if (str.size() > kOutBuf) {
+ Flush();
+ util::WriteOrThrow(fd_, str.data(), str.size());
+ } else {
+ EnsureRemaining(str.size());
+ builder_.AddSubstring(str.data(), str.size());
+ }
+ return *this;
+ }
+
+ // Inefficient! TODO: more efficient implementation
+ FakeOFStream &operator<<(unsigned value) {
+ return *this << boost::lexical_cast<std::string>(value);
+ }
+
+ FakeOFStream &operator<<(char c) {
+ EnsureRemaining(1);
+ builder_.AddCharacter(c);
+ return *this;
+ }
+
+ // Note this does not sync.
+ void Flush() {
+ util::WriteOrThrow(fd_, buf_.get(), builder_.position());
+ builder_.Reset();
+ }
+
+ // Not necessary, but does assure the data is cleared.
+ void Finish() {
+ Flush();
+ // It will segfault trying to null terminate otherwise.
+ builder_.Finalize();
+ buf_.reset();
+ util::FSyncOrThrow(fd_);
+ }
+
+ private:
+ void EnsureRemaining(std::size_t amount) {
+ if (static_cast<std::size_t>(builder_.size() - builder_.position()) <= amount) {
+ Flush();
+ }
+ }
+
+ util::scoped_malloc buf_;
+ double_conversion::StringBuilder builder_;
+ double_conversion::DoubleToStringConverter convert_;
+ int fd_;
+};
+
+} // namespace
diff --git a/klm/util/file.cc b/klm/util/file.cc
index 86d9b12d..c7d8e23b 100644
--- a/klm/util/file.cc
+++ b/klm/util/file.cc
@@ -111,15 +111,26 @@ void ResizeOrThrow(int fd, uint64_t to) {
UTIL_THROW_IF_ARG(ret, FDException, (fd), "while resizing to " << to << " bytes");
}
+namespace {
+std::size_t GuardLarge(std::size_t size) {
+ // The following operating systems have broken read/write/pread/pwrite that
+ // only supports up to 2^31.
+#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID)
+ return std::min(static_cast<std::size_t>(INT_MAX), size);
+#else
+ return size;
+#endif
+}
+}
+
std::size_t PartialRead(int fd, void *to, std::size_t amount) {
#if defined(_WIN32) || defined(_WIN64)
- amount = min(static_cast<std::size_t>(INT_MAX), amount);
- int ret = _read(fd, to, amount);
+ int ret = _read(fd, to, GuardLarge(amount));
#else
errno = 0;
ssize_t ret;
do {
- ret = read(fd, to, amount);
+ ret = read(fd, to, GuardLarge(amount));
} while (ret == -1 && errno == EINTR);
#endif
UTIL_THROW_IF_ARG(ret < 0, FDException, (fd), "while reading " << amount << " bytes");
@@ -169,11 +180,13 @@ void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) {
ssize_t ret;
errno = 0;
do {
+ ret =
#ifdef OS_ANDROID
- ret = pread64(fd, to, size, off);
+ pread64
#else
- ret = pread(fd, to, size, off);
+ pread
#endif
+ (fd, to, GuardLarge(size), off);
} while (ret == -1 && errno == EINTR);
if (ret <= 0) {
UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd));
@@ -190,14 +203,20 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
const uint8_t *data = static_cast<const uint8_t*>(data_void);
while (size) {
#if defined(_WIN32) || defined(_WIN64)
- int ret = write(fd, data, min(static_cast<std::size_t>(INT_MAX), size));
+ int ret;
#else
- errno = 0;
ssize_t ret;
+#endif
+ errno = 0;
do {
- ret = write(fd, data, size);
- } while (ret == -1 && errno == EINTR);
+ ret =
+#if defined(_WIN32) || defined(_WIN64)
+ _write
+#else
+ write
#endif
+ (fd, data, GuardLarge(size));
+ } while (ret == -1 && errno == EINTR);
UTIL_THROW_IF_ARG(ret < 1, FDException, (fd), "while writing " << size << " bytes");
data += ret;
size -= ret;
diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc
index 9de30fc4..b5961bea 100644
--- a/klm/util/file_piece.cc
+++ b/klm/util/file_piece.cc
@@ -51,7 +51,7 @@ FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std:
FilePiece::FilePiece(std::istream &stream, const char *name, std::size_t min_buffer) :
total_size_(kBadSize), page_(SizePage()) {
- InitializeNoRead(name ? name : "istream", min_buffer);
+ InitializeNoRead("istream", min_buffer);
fallback_to_read_ = true;
data_.reset(MallocOrThrow(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED);
@@ -95,32 +95,6 @@ unsigned long int FilePiece::ReadULong() {
return ReadNumber<unsigned long int>();
}
-std::size_t FilePiece::Raw(void *to, std::size_t limit) {
- if (!limit) return 0;
- std::size_t in_buf = static_cast<std::size_t>(position_end_ - position_);
- if (in_buf) {
- std::size_t amount = std::min(in_buf, limit);
- memcpy(to, position_, amount);
- position_ += amount;
- return amount;
- }
-
- std::size_t read_return;
- if (fallback_to_read_) {
- read_return = fell_back_.Read(to, limit);
- progress_.Set(fell_back_.RawAmount());
- } else {
- uint64_t desired_begin = mapped_offset_ + static_cast<uint64_t>(position_ - data_.begin());
- SeekOrThrow(file_.get(), desired_begin);
- read_return = ReadOrEOF(file_.get(), to, limit);
- // Good thing we never rewind. This makes desired_begin calculate the right way the next time.
- mapped_offset_ += static_cast<uint64_t>(read_return);
- progress_ += read_return;
- }
- at_end_ |= (read_return == 0);
- return read_return;
-}
-
// Factored out so that istream can call this.
void FilePiece::InitializeNoRead(const char *name, std::size_t min_buffer) {
file_name_ = name;
@@ -146,7 +120,7 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s
}
Shift();
// gzip detect.
- if ((position_end_ - position_) >= ReadCompressed::kMagicSize && ReadCompressed::DetectCompressedMagic(position_)) {
+ if ((position_end_ >= position_ + ReadCompressed::kMagicSize) && ReadCompressed::DetectCompressedMagic(position_)) {
if (!fallback_to_read_) {
at_end_ = false;
TransitionToRead();
@@ -244,7 +218,7 @@ void FilePiece::MMapShift(uint64_t desired_begin) {
// Use mmap.
uint64_t ignore = desired_begin % page_;
// Duplicate request for Shift means give more data.
- if (position_ == data_.begin() + ignore) {
+ if (position_ == data_.begin() + ignore && position_) {
default_map_size_ *= 2;
}
// Local version so that in case of failure it doesn't overwrite the class variable.
diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh
index 1b110287..c07c6011 100644
--- a/klm/util/file_piece.hh
+++ b/klm/util/file_piece.hh
@@ -64,10 +64,7 @@ class FilePiece {
long int ReadLong();
unsigned long int ReadULong();
- // Fake read() function. Reads up to limit bytes, returning the amount read. Returns 0 on EOF || limit == 0.
- std::size_t Raw(void *to, std::size_t limit);
-
- // Skip spaces defined by being in delim.
+ // Skip spaces defined by isspace.
void SkipSpaces(const bool *delim = kSpaces) {
for (; ; ++position_) {
if (position_ == position_end_) Shift();
diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc
index bc9e3f81..6f79f26f 100644
--- a/klm/util/mmap.cc
+++ b/klm/util/mmap.cc
@@ -6,6 +6,7 @@
#include "util/exception.hh"
#include "util/file.hh"
+#include "util/scoped.hh"
#include <iostream>
@@ -110,8 +111,14 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int
UTIL_THROW_IF(!ret, ErrnoException, "MapViewOfFile failed");
#else
int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ;
- void *ret = mmap(NULL, size, protect, flags, fd, offset);
- UTIL_THROW_IF(ret == MAP_FAILED, ErrnoException, "mmap failed for size " << size << " at offset " << offset);
+ void *ret;
+ UTIL_THROW_IF((ret = mmap(NULL, size, protect, flags, fd, offset)) == MAP_FAILED, ErrnoException, "mmap failed for size " << size << " at offset " << offset);
+# ifdef MADV_HUGEPAGE
+ /* We like huge pages but it's fine if we can't have them. Note that huge
+ * pages are not supported for file-backed mmap on linux.
+ */
+ madvise(ret, size, MADV_HUGEPAGE);
+# endif
#endif
return ret;
}
@@ -141,8 +148,7 @@ void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scope
case POPULATE_OR_READ:
#endif
case READ:
- out.reset(malloc(size), size, scoped_memory::MALLOC_ALLOCATED);
- if (!out.get()) UTIL_THROW(util::ErrnoException, "Allocating " << size << " bytes with malloc");
+ out.reset(MallocOrThrow(size), size, scoped_memory::MALLOC_ALLOCATED);
SeekOrThrow(fd, offset);
ReadOrThrow(fd, out.get(), size);
break;
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 6780489d..57866ff9 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -6,6 +6,7 @@
#include <algorithm>
#include <cstddef>
#include <functional>
+#include <vector>
#include <assert.h>
#include <stdint.h>
@@ -73,10 +74,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
assert(initialized_);
#endif
UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full.");
- for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
- if (equal_(i->GetKey(), invalid_)) { *i = t; return i; }
- if (++i == end_) { i = begin_; }
- }
+ return UncheckedInsert(t);
}
// Return true if the value was found (and not inserted). This is consistent with Find but the opposite if hash_map!
@@ -126,12 +124,96 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
}
}
- void Clear(Entry invalid) {
+ void Clear() {
+ Entry invalid;
+ invalid.SetKey(invalid_);
std::fill(begin_, end_, invalid);
entries_ = 0;
}
+ // Return number of entries assuming no serialization went on.
+ std::size_t SizeNoSerialization() const {
+ return entries_;
+ }
+
+ // Return memory size expected by Double.
+ std::size_t DoubleTo() const {
+ return buckets_ * 2 * sizeof(Entry);
+ }
+
+ // Inform the table that it has double the amount of memory.
+ // Pass clear_new = false if you are sure the new memory is initialized
+ // properly (to invalid_) i.e. by mremap.
+ void Double(void *new_base, bool clear_new = true) {
+ begin_ = static_cast<MutableIterator>(new_base);
+ MutableIterator old_end = begin_ + buckets_;
+ buckets_ *= 2;
+ end_ = begin_ + buckets_;
+ if (clear_new) {
+ Entry invalid;
+ invalid.SetKey(invalid_);
+ std::fill(old_end, end_, invalid);
+ }
+ std::vector<Entry> rolled_over;
+ // Move roll-over entries to a buffer because they might not roll over anymore. This should be small.
+ for (MutableIterator i = begin_; i != old_end && !equal_(i->GetKey(), invalid_); ++i) {
+ rolled_over.push_back(*i);
+ i->SetKey(invalid_);
+ }
+ /* Re-insert everything. Entries might go backwards to take over a
+ * recently opened gap, stay, move to new territory, or wrap around. If
+ * an entry wraps around, it might go to a pointer greater than i (which
+ * can happen at the beginning) and it will be revisited to possibly fill
+ * in a gap created later.
+ */
+ Entry temp;
+ for (MutableIterator i = begin_; i != old_end; ++i) {
+ if (!equal_(i->GetKey(), invalid_)) {
+ temp = *i;
+ i->SetKey(invalid_);
+ UncheckedInsert(temp);
+ }
+ }
+ // Put the roll-over entries back in.
+ for (typename std::vector<Entry>::const_iterator i(rolled_over.begin()); i != rolled_over.end(); ++i) {
+ UncheckedInsert(*i);
+ }
+ }
+
+ // Mostly for tests, check consistency of every entry.
+ void CheckConsistency() {
+ MutableIterator last;
+ for (last = end_ - 1; last >= begin_ && !equal_(last->GetKey(), invalid_); --last) {}
+ UTIL_THROW_IF(last == begin_, ProbingSizeException, "Completely full");
+ MutableIterator i;
+ // Beginning can be wrap-arounds.
+ for (i = begin_; !equal_(i->GetKey(), invalid_); ++i) {
+ MutableIterator ideal = Ideal(*i);
+ UTIL_THROW_IF(ideal > i && ideal <= last, Exception, "Inconsistency at position " << (i - begin_) << " should be at " << (ideal - begin_));
+ }
+ MutableIterator pre_gap = i;
+ for (; i != end_; ++i) {
+ if (equal_(i->GetKey(), invalid_)) {
+ pre_gap = i;
+ continue;
+ }
+ MutableIterator ideal = Ideal(*i);
+ UTIL_THROW_IF(ideal > i || ideal <= pre_gap, Exception, "Inconsistency at position " << (i - begin_) << " with ideal " << (ideal - begin_));
+ }
+ }
+
private:
+ template <class T> MutableIterator Ideal(const T &t) {
+ return begin_ + (hash_(t.GetKey()) % buckets_);
+ }
+
+ template <class T> MutableIterator UncheckedInsert(const T &t) {
+ for (MutableIterator i(Ideal(t));;) {
+ if (equal_(i->GetKey(), invalid_)) { *i = t; return i; }
+ if (++i == end_) { i = begin_; }
+ }
+ }
+
MutableIterator begin_;
std::size_t buckets_;
MutableIterator end_;
diff --git a/klm/util/probing_hash_table_test.cc b/klm/util/probing_hash_table_test.cc
index be0fa859..9f7948ce 100644
--- a/klm/util/probing_hash_table_test.cc
+++ b/klm/util/probing_hash_table_test.cc
@@ -1,10 +1,14 @@
#include "util/probing_hash_table.hh"
+#include "util/murmur_hash.hh"
+#include "util/scoped.hh"
+
#define BOOST_TEST_MODULE ProbingHashTableTest
#include <boost/test/unit_test.hpp>
#include <boost/scoped_array.hpp>
#include <boost/functional/hash.hpp>
#include <stdio.h>
+#include <stdlib.h>
#include <string.h>
#include <stdint.h>
@@ -19,6 +23,10 @@ struct Entry {
return key;
}
+ void SetKey(unsigned char to) {
+ key = to;
+ }
+
uint64_t GetValue() const {
return value;
}
@@ -46,5 +54,49 @@ BOOST_AUTO_TEST_CASE(simple) {
BOOST_CHECK(!table.Find(2, i));
}
+struct Entry64 {
+ uint64_t key;
+ typedef uint64_t Key;
+
+ Entry64() {}
+
+ explicit Entry64(uint64_t key_in) {
+ key = key_in;
+ }
+
+ Key GetKey() const { return key; }
+ void SetKey(uint64_t to) { key = to; }
+};
+
+struct MurmurHashEntry64 {
+ std::size_t operator()(uint64_t value) const {
+ return util::MurmurHash64A(&value, 8);
+ }
+};
+
+typedef ProbingHashTable<Entry64, MurmurHashEntry64> Table64;
+
+BOOST_AUTO_TEST_CASE(Double) {
+ for (std::size_t initial = 19; initial < 30; ++initial) {
+ size_t size = Table64::Size(initial, 1.2);
+ scoped_malloc mem(MallocOrThrow(size));
+ Table64 table(mem.get(), size, std::numeric_limits<uint64_t>::max());
+ table.Clear();
+ for (uint64_t i = 0; i < 19; ++i) {
+ table.Insert(Entry64(i));
+ }
+ table.CheckConsistency();
+ mem.call_realloc(table.DoubleTo());
+ table.Double(mem.get());
+ table.CheckConsistency();
+ for (uint64_t i = 20; i < 40 ; ++i) {
+ table.Insert(Entry64(i));
+ }
+ mem.call_realloc(table.DoubleTo());
+ table.Double(mem.get());
+ table.CheckConsistency();
+ }
+}
+
} // namespace
} // namespace util
diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc
index b81549e4..b62a6e83 100644
--- a/klm/util/read_compressed.cc
+++ b/klm/util/read_compressed.cc
@@ -180,12 +180,73 @@ class GZip : public ReadBase {
};
#endif // HAVE_ZLIB
+const uint8_t kBZMagic[3] = {'B', 'Z', 'h'};
+
#ifdef HAVE_BZLIB
class BZip : public ReadBase {
public:
- explicit BZip(int fd, void *already_data, std::size_t already_size) {
+ BZip(int fd, void *already_data, std::size_t already_size) {
scoped_fd hold(fd);
closer_.reset(FDOpenReadOrThrow(hold));
+ file_ = NULL;
+ Open(already_data, already_size);
+ }
+
+ BZip(FILE *file, void *already_data, std::size_t already_size) {
+ closer_.reset(file);
+ file_ = NULL;
+ Open(already_data, already_size);
+ }
+
+ ~BZip() {
+ Close(file_);
+ }
+
+ std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
+ assert(file_);
+ int bzerror = BZ_OK;
+ int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount));
+ long pos = ftell(closer_.get());
+ if (pos != -1) ReadCount(thunk) = pos;
+ switch (bzerror) {
+ case BZ_STREAM_END:
+ /* bzip2 files can be concatenated by e.g. pbzip2. Annoyingly, the
+ * library doesn't handle this internally. This gets the trailing
+ * data, grows it up to magic as needed, validates the magic, and
+ * reopens.
+ */
+ {
+ bzerror = BZ_OK;
+ void *trailing_data;
+ int trailing_size;
+ BZ2_bzReadGetUnused(&bzerror, file_, &trailing_data, &trailing_size);
+ UTIL_THROW_IF(bzerror != BZ_OK, BZException, "bzip2 error in BZ2_bzReadGetUnused " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
+ std::string trailing(static_cast<const char*>(trailing_data), trailing_size);
+ Close(file_);
+
+ if (trailing_size < (int)sizeof(kBZMagic)) {
+ trailing.resize(sizeof(kBZMagic));
+ if (1 != fread(&trailing[trailing_size], sizeof(kBZMagic) - trailing_size, 1, closer_.get())) {
+ UTIL_THROW_IF(trailing_size, BZException, "File has trailing cruft");
+ // Legitimate end of file.
+ ReplaceThis(new Complete(), thunk);
+ return ret;
+ }
+ }
+ UTIL_THROW_IF(memcmp(trailing.data(), kBZMagic, sizeof(kBZMagic)), BZException, "Trailing cruft is not another bzip2 stream");
+ Open(&trailing[0], trailing.size());
+ }
+ return ret;
+ case BZ_OK:
+ return ret;
+ default:
+ UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
+ }
+ }
+
+ private:
+ void Open(void *already_data, std::size_t already_size) {
+ assert(!file_);
int bzerror = BZ_OK;
file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size);
switch (bzerror) {
@@ -199,38 +260,23 @@ class BZip : public ReadBase {
UTIL_THROW(BZException, "IO error reading file");
case BZ_MEM_ERROR:
throw std::bad_alloc();
+ default:
+ UTIL_THROW(BZException, "Unknown bzip2 error code " << bzerror);
}
+ assert(file_);
}
- ~BZip() {
+ static void Close(BZFILE *&file) {
+ if (file == NULL) return;
int bzerror = BZ_OK;
- BZ2_bzReadClose(&bzerror, file_);
+ BZ2_bzReadClose(&bzerror, file);
if (bzerror != BZ_OK) {
- std::cerr << "bz2 readclose error" << std::endl;
+ std::cerr << "bz2 readclose error number " << bzerror << std::endl;
abort();
}
+ file = NULL;
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- int bzerror = BZ_OK;
- int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount));
- long pos;
- switch (bzerror) {
- case BZ_STREAM_END:
- pos = ftell(closer_.get());
- if (pos != -1) ReadCount(thunk) = pos;
- ReplaceThis(new Complete(), thunk);
- return ret;
- case BZ_OK:
- pos = ftell(closer_.get());
- if (pos != -1) ReadCount(thunk) = pos;
- return ret;
- default:
- UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
- }
- }
-
- private:
scoped_FILE closer_;
BZFILE *file_;
};
@@ -346,11 +392,11 @@ MagicResult DetectMagic(const void *from_void) {
if (header[0] == 0x1f && header[1] == 0x8b) {
return GZIP;
}
- if (header[0] == 'B' && header[1] == 'Z' && header[2] == 'h') {
+ if (!memcmp(header, kBZMagic, sizeof(kBZMagic))) {
return BZIP;
}
- const uint8_t xzmagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 };
- if (!memcmp(header, xzmagic, 6)) {
+ const uint8_t kXZMagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 };
+ if (!memcmp(header, kXZMagic, sizeof(kXZMagic))) {
return XZIP;
}
return UNKNOWN;
diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc
index e7066ee4..6c5b0c2d 100644
--- a/klm/util/scoped.cc
+++ b/klm/util/scoped.cc
@@ -1,6 +1,9 @@
#include "util/scoped.hh"
#include <cstdlib>
+#if !defined(_WIN32) && !defined(_WIN64)
+#include <sys/mman.h>
+#endif
namespace util {
@@ -10,20 +13,31 @@ MallocException::MallocException(std::size_t requested) throw() {
MallocException::~MallocException() throw() {}
+namespace {
+void *InspectAddr(void *addr, std::size_t requested, const char *func_name) {
+ UTIL_THROW_IF_ARG(!addr && requested, MallocException, (requested), "in " << func_name);
+ // These routines are often used for large chunks of memory where huge pages help.
+#if MADV_HUGEPAGE
+ madvise(addr, requested, MADV_HUGEPAGE);
+#endif
+ return addr;
+}
+} // namespace
+
void *MallocOrThrow(std::size_t requested) {
- void *ret;
- UTIL_THROW_IF_ARG(!(ret = std::malloc(requested)), MallocException, (requested), "in malloc");
- return ret;
+ return InspectAddr(std::malloc(requested), requested, "malloc");
+}
+
+void *CallocOrThrow(std::size_t requested) {
+ return InspectAddr(std::calloc(1, requested), requested, "calloc");
}
scoped_malloc::~scoped_malloc() {
std::free(p_);
}
-void scoped_malloc::call_realloc(std::size_t to) {
- void *ret;
- UTIL_THROW_IF_ARG(!(ret = std::realloc(p_, to)) && to, MallocException, (to), "in realloc");
- p_ = ret;
+void scoped_malloc::call_realloc(std::size_t requested) {
+ p_ = InspectAddr(std::realloc(p_, requested), requested, "realloc");
}
} // namespace util
diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh
index d0a5aabd..b642d064 100644
--- a/klm/util/scoped.hh
+++ b/klm/util/scoped.hh
@@ -14,6 +14,7 @@ class MallocException : public ErrnoException {
};
void *MallocOrThrow(std::size_t requested);
+void *CallocOrThrow(std::size_t requested);
class scoped_malloc {
public:
diff --git a/klm/util/sized_iterator.hh b/klm/util/sized_iterator.hh
index aabcc531..cf998953 100644
--- a/klm/util/sized_iterator.hh
+++ b/klm/util/sized_iterator.hh
@@ -3,6 +3,7 @@
#include "util/proxy_iterator.hh"
+#include <algorithm>
#include <functional>
#include <string>
@@ -63,6 +64,13 @@ class SizedProxy {
const void *Data() const { return inner_.Data(); }
void *Data() { return inner_.Data(); }
+ friend void swap(SizedProxy &first, SizedProxy &second) {
+ std::swap_ranges(
+ static_cast<char*>(first.inner_.Data()),
+ static_cast<char*>(first.inner_.Data()) + first.inner_.EntrySize(),
+ static_cast<char*>(second.inner_.Data()));
+ }
+
private:
friend class util::ProxyIterator<SizedProxy>;
diff --git a/klm/util/usage.cc b/klm/util/usage.cc
index b8e125d0..ad4dc7b4 100644
--- a/klm/util/usage.cc
+++ b/klm/util/usage.cc
@@ -22,6 +22,11 @@ float FloatSec(const struct timeval &tv) {
return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000.0);
}
#endif
+
+const char *SkipSpaces(const char *at) {
+ for (; *at == ' '; ++at) {}
+ return at;
+}
} // namespace
void PrintUsage(std::ostream &out) {
@@ -32,18 +37,19 @@ void PrintUsage(std::ostream &out) {
return;
}
out << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n';
-
+ out << "CPU\t" << (FloatSec(usage.ru_utime) + FloatSec(usage.ru_stime)) << '\n';
// Linux doesn't set memory usage :-(.
std::ifstream status("/proc/self/status", std::ios::in);
std::string line;
while (getline(status, line)) {
if (!strncmp(line.c_str(), "VmRSS:\t", 7)) {
- out << "VmRSS: " << (line.c_str() + 7) << '\n';
+ out << "RSSCur\t" << SkipSpaces(line.c_str() + 7) << '\n';
break;
} else if (!strncmp(line.c_str(), "VmPeak:\t", 8)) {
- out << "VmPeak: " << (line.c_str() + 8) << '\n';
+ out << "VmPeak\t" << SkipSpaces(line.c_str() + 8) << '\n';
}
}
+ out << "RSSMax\t" << usage.ru_maxrss << " kB" << '\n';
#endif
}