From 5aee54869aa19cfe9be965e67a472e94449d16da Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 24 Apr 2013 10:12:41 +0100 Subject: KenLM 0831569c3137536165b107c6841603c725dfa2b1 --- klm/lm/builder/corpus_count.cc | 82 ++++++++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 23 deletions(-) (limited to 'klm/lm/builder/corpus_count.cc') diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc index abea4ed0..aea93ad1 100644 --- a/klm/lm/builder/corpus_count.cc +++ b/klm/lm/builder/corpus_count.cc @@ -3,6 +3,7 @@ #include "lm/builder/ngram.hh" #include "lm/lm_exception.hh" #include "lm/word_index.hh" +#include "util/fake_ofstream.hh" #include "util/file.hh" #include "util/file_piece.hh" #include "util/murmur_hash.hh" @@ -23,39 +24,71 @@ namespace lm { namespace builder { namespace { +#pragma pack(push) +#pragma pack(4) +struct VocabEntry { + typedef uint64_t Key; + + uint64_t GetKey() const { return key; } + void SetKey(uint64_t to) { key = to; } + + uint64_t key; + lm::WordIndex value; +}; +#pragma pack(pop) + +const float kProbingMultiplier = 1.5; + class VocabHandout { public: - explicit VocabHandout(int fd) { - util::scoped_fd duped(util::DupOrThrow(fd)); - word_list_.reset(util::FDOpenOrThrow(duped)); - + static std::size_t MemUsage(WordIndex initial_guess) { + if (initial_guess < 2) initial_guess = 2; + return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier)); + } + + explicit VocabHandout(int fd, WordIndex initial_guess) : + table_backing_(util::CallocOrThrow(MemUsage(initial_guess))), + table_(table_backing_.get(), MemUsage(initial_guess)), + double_cutoff_(std::max(initial_guess * 1.1, 1)), + word_list_(fd) { Lookup(""); // Force 0 Lookup(""); // Force 1 Lookup(""); // Force 2 } WordIndex Lookup(const StringPiece &word) { - uint64_t hashed = util::MurmurHashNative(word.data(), word.size()); - std::pair ret(seen_.insert(std::pair(hashed, seen_.size()))); - if (ret.second) { - char null_delimit = 0; - util::WriteOrThrow(word_list_.get(), word.data(), word.size()); - util::WriteOrThrow(word_list_.get(), &null_delimit, 1); - UTIL_THROW_IF(seen_.size() >= std::numeric_limits::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); + VocabEntry entry; + entry.key = util::MurmurHashNative(word.data(), word.size()); + entry.value = table_.SizeNoSerialization(); + + Table::MutableIterator it; + if (table_.FindOrInsert(entry, it)) + return it->value; + word_list_ << word << '\0'; + UTIL_THROW_IF(Size() >= std::numeric_limits::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); + if (Size() >= double_cutoff_) { + table_backing_.call_realloc(table_.DoubleTo()); + table_.Double(table_backing_.get()); + double_cutoff_ *= 2; } - return ret.first->second; + return entry.value; } WordIndex Size() const { - return seen_.size(); + return table_.SizeNoSerialization(); } private: - typedef boost::unordered_map Seen; + // TODO: factor out a resizable probing hash table. + // TODO: use mremap on linux to get all zeros on resizes. + util::scoped_malloc table_backing_; - Seen seen_; + typedef util::ProbingHashTable Table; + Table table_; - util::scoped_FILE word_list_; + std::size_t double_cutoff_; + + util::FakeOFStream word_list_; }; class DedupeHash : public std::unary_function { @@ -85,6 +118,7 @@ class DedupeEquals : public std::binary_function Dedupe; -const float kProbingMultiplier = 1.5; - class Writer { public: Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size) @@ -105,7 +137,7 @@ class Writer { dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)), buffer_(new WordIndex[order - 1]), block_size_(position.GetChain().BlockSize()) { - dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + dedupe_.Clear(); assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size); if (order == 1) { // Add special words. AdjustCounts is responsible if order != 1. @@ -149,7 +181,7 @@ class Writer { } // Block end. Need to store the context in a temporary buffer. std::copy(gram_.begin() + 1, gram_.end(), buffer_.get()); - dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + dedupe_.Clear(); block_->SetValidSize(block_size_); gram_.ReBase((++block_)->Get()); std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin()); @@ -187,18 +219,22 @@ float CorpusCount::DedupeMultiplier(std::size_t order) { return kProbingMultiplier * static_cast(sizeof(DedupeEntry)) / static_cast(NGram::TotalSize(order)); } +std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { + return VocabHandout::MemUsage(vocab_estimate); +} + CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { - token_count_ = 0; - type_count_ = 0; } void CorpusCount::Run(const util::stream::ChainPosition &position) { UTIL_TIMER("(%w s) Counted n-grams\n"); - VocabHandout vocab(vocab_write_); + VocabHandout vocab(vocab_write_, type_count_); + token_count_ = 0; + type_count_ = 0; const WordIndex end_sentence = vocab.Lookup(""); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; -- cgit v1.2.3