diff options
Diffstat (limited to 'klm/lm/builder/corpus_count.cc')
-rw-r--r-- | klm/lm/builder/corpus_count.cc | 82 |
1 files changed, 59 insertions, 23 deletions
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc index abea4ed0..aea93ad1 100644 --- a/klm/lm/builder/corpus_count.cc +++ b/klm/lm/builder/corpus_count.cc @@ -3,6 +3,7 @@ #include "lm/builder/ngram.hh" #include "lm/lm_exception.hh" #include "lm/word_index.hh" +#include "util/fake_ofstream.hh" #include "util/file.hh" #include "util/file_piece.hh" #include "util/murmur_hash.hh" @@ -23,39 +24,71 @@ namespace lm { namespace builder { namespace { +#pragma pack(push) +#pragma pack(4) +struct VocabEntry { + typedef uint64_t Key; + + uint64_t GetKey() const { return key; } + void SetKey(uint64_t to) { key = to; } + + uint64_t key; + lm::WordIndex value; +}; +#pragma pack(pop) + +const float kProbingMultiplier = 1.5; + class VocabHandout { public: - explicit VocabHandout(int fd) { - util::scoped_fd duped(util::DupOrThrow(fd)); - word_list_.reset(util::FDOpenOrThrow(duped)); - + static std::size_t MemUsage(WordIndex initial_guess) { + if (initial_guess < 2) initial_guess = 2; + return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier)); + } + + explicit VocabHandout(int fd, WordIndex initial_guess) : + table_backing_(util::CallocOrThrow(MemUsage(initial_guess))), + table_(table_backing_.get(), MemUsage(initial_guess)), + double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)), + word_list_(fd) { Lookup("<unk>"); // Force 0 Lookup("<s>"); // Force 1 Lookup("</s>"); // Force 2 } WordIndex Lookup(const StringPiece &word) { - uint64_t hashed = util::MurmurHashNative(word.data(), word.size()); - std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size()))); - if (ret.second) { - char null_delimit = 0; - util::WriteOrThrow(word_list_.get(), word.data(), word.size()); - util::WriteOrThrow(word_list_.get(), &null_delimit, 1); - UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); + VocabEntry entry; + entry.key = util::MurmurHashNative(word.data(), word.size()); + entry.value = table_.SizeNoSerialization(); + + Table::MutableIterator it; + if (table_.FindOrInsert(entry, it)) + return it->value; + word_list_ << word << '\0'; + UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); + if (Size() >= double_cutoff_) { + table_backing_.call_realloc(table_.DoubleTo()); + table_.Double(table_backing_.get()); + double_cutoff_ *= 2; } - return ret.first->second; + return entry.value; } WordIndex Size() const { - return seen_.size(); + return table_.SizeNoSerialization(); } private: - typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen; + // TODO: factor out a resizable probing hash table. + // TODO: use mremap on linux to get all zeros on resizes. + util::scoped_malloc table_backing_; - Seen seen_; + typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table; + Table table_; - util::scoped_FILE word_list_; + std::size_t double_cutoff_; + + util::FakeOFStream word_list_; }; class DedupeHash : public std::unary_function<const WordIndex *, bool> { @@ -85,6 +118,7 @@ class DedupeEquals : public std::binary_function<const WordIndex *, const WordIn struct DedupeEntry { typedef WordIndex *Key; Key GetKey() const { return key; } + void SetKey(WordIndex *to) { key = to; } Key key; static DedupeEntry Construct(WordIndex *at) { DedupeEntry ret; @@ -95,8 +129,6 @@ struct DedupeEntry { typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe; -const float kProbingMultiplier = 1.5; - class Writer { public: Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size) @@ -105,7 +137,7 @@ class Writer { dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)), buffer_(new WordIndex[order - 1]), block_size_(position.GetChain().BlockSize()) { - dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + dedupe_.Clear(); assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size); if (order == 1) { // Add special words. AdjustCounts is responsible if order != 1. @@ -149,7 +181,7 @@ class Writer { } // Block end. Need to store the context in a temporary buffer. std::copy(gram_.begin() + 1, gram_.end(), buffer_.get()); - dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + dedupe_.Clear(); block_->SetValidSize(block_size_); gram_.ReBase((++block_)->Get()); std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin()); @@ -187,18 +219,22 @@ float CorpusCount::DedupeMultiplier(std::size_t order) { return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order)); } +std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { + return VocabHandout::MemUsage(vocab_estimate); +} + CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { - token_count_ = 0; - type_count_ = 0; } void CorpusCount::Run(const util::stream::ChainPosition &position) { UTIL_TIMER("(%w s) Counted n-grams\n"); - VocabHandout vocab(vocab_write_); + VocabHandout vocab(vocab_write_, type_count_); + token_count_ = 0; + type_count_ = 0; const WordIndex end_sentence = vocab.Lookup("</s>"); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; |