#include "lm/builder/corpus_count.hh"

#include "lm/builder/ngram.hh"
#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
#include "util/scoped.hh"
#include "util/stream/chain.hh"
#include "util/stream/timer.hh"
#include "util/tokenize_piece.hh"

#include <boost/scoped_array.hpp>
#include <boost/unordered_set.hpp>
#include <boost/unordered_map.hpp>

#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <limits>
#include <vector>

#include <stdint.h>

namespace lm {
namespace builder {
namespace {

// Vocabulary hash table entry: the 64-bit hash of a word's bytes mapped to the
// word's id.  Packed to 4-byte alignment so a 64-bit key followed by a
// WordIndex is not padded out to 8-byte alignment (presumably WordIndex is
// 32 bits by default -- see the "Change WordIndex to uint64_t" error below).
#pragma pack(push)
#pragma pack(4)
struct VocabEntry {
  typedef uint64_t Key;

  uint64_t GetKey() const { return key; }
  void SetKey(uint64_t to) { key = to; }

  uint64_t key;
  lm::WordIndex value;
};
#pragma pack(pop)

// Space multiplier passed to ProbingHashTable::Size for both the vocabulary
// table and the per-block n-gram dedupe table.
const float kProbingMultiplier = 1.5;

// Hands out consecutive WordIndex values in order of first appearance and
// appends every newly seen word, NUL-terminated, to the vocabulary stream
// opened on `fd`.  The constructor inserts <unk>, <s>, </s> first so they get
// ids 0, 1 and 2.  Words are identified only by their 64-bit MurmurHash, so two
// distinct words with the same hash would share one id.
class VocabHandout {
  public:
    // Bytes of table memory for `initial_guess` words (clamped to at least 2).
    static std::size_t MemUsage(WordIndex initial_guess) {
      if (initial_guess < 2) initial_guess = 2;
      return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
    }

    explicit VocabHandout(int fd, WordIndex initial_guess) :
        table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
        table_(table_backing_.get(), MemUsage(initial_guess)),
        double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)),
        word_list_(fd) {
      Lookup("<unk>"); // Force 0
      Lookup("<s>"); // Force 1
      Lookup("</s>"); // Force 2
    }

    // Returns the id of `word`.  An unseen word gets id Size() (the count
    // before insertion) and is written to the vocabulary stream.
    // Throws VocabLoadException once the id space of WordIndex is exhausted.
    WordIndex Lookup(const StringPiece &word) {
      VocabEntry entry;
      entry.key = util::MurmurHashNative(word.data(), word.size());
      entry.value = table_.SizeNoSerialization();

      Table::MutableIterator it;
      if (table_.FindOrInsert(entry, it))
        return it->value;
      word_list_ << word << '\0';
      UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words.  Change WordIndex to uint64_t in lm/word_index.hh.");
      // Grow the table before it gets too full: reallocate the backing memory
      // to the doubled size and let the table rehash into it.
      if (Size() >= double_cutoff_) {
        table_backing_.call_realloc(table_.DoubleTo());
        table_.Double(table_backing_.get());
        double_cutoff_ *= 2;
      }
      return entry.value;
    }

    // Number of distinct words seen so far, including the three specials.
    WordIndex Size() const {
      return table_.SizeNoSerialization();
    }

  private:
    // TODO: factor out a resizable probing hash table.
    // TODO: use mremap on linux to get all zeros on resizes.
    util::scoped_malloc table_backing_;

    typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
    Table table_;

    // Size() at which the table is next doubled.
    std::size_t double_cutoff_;

    util::FakeOFStream word_list_;
};

// Hashes an n-gram, given a pointer to its first word, as
// order * sizeof(WordIndex) raw bytes.
// (Formerly derived from std::unary_function, which C++17 removed; the hash
// table only invokes operator(), so the base added nothing.)
class DedupeHash {
  public:
    explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}

    std::size_t operator()(const WordIndex *start) const {
      return util::MurmurHashNative(start, size_);
    }

  private:
    const std::size_t size_;
};

// Byte-wise equality of two n-grams of the configured order.
// (Formerly derived from std::binary_function, which C++17 removed.)
class DedupeEquals {
  public:
    explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}

    bool operator()(const WordIndex *first, const WordIndex *second) const {
      return !std::memcmp(first, second, size_);
    }

  private:
    const std::size_t size_;
};

// Dedupe table entry: a pointer to an n-gram that already lives in the current
// output block.  The table stores only the pointer, never a copy of the words.
struct DedupeEntry {
  typedef WordIndex *Key;
  Key GetKey() const { return key; }
  void SetKey(WordIndex *to) { key = to; }
  Key key;
  static DedupeEntry Construct(WordIndex *at) {
    DedupeEntry ret;
    ret.key = at;
    return ret;
  }
};

typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;

// Writes n-grams with counts into consecutive blocks of a stream chain.
// Repeats of an n-gram already in the current block increment its count rather
// than being written again.  The dedupe table is cleared at every block
// boundary, so the same n-gram can appear (with partial counts) in several
// blocks.
class Writer {
  public:
    Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
      : block_(position), gram_(block_->Get(), order),
        dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
        dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
        buffer_(new WordIndex[order - 1]),
        block_size_(position.GetChain().BlockSize()) {
      dedupe_.Clear();
      assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
      if (order == 1) {
        // Add special words.  AdjustCounts is responsible if order != 1.
        AddUnigramWord(kUNK);
        AddUnigramWord(kBOS);
      }
    }

    // Records how much of the final block was filled, then calls Poison() on
    // the following block (presumably marking end of stream).
    ~Writer() {
      block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
      (++block_).Poison();
    }

    // Write context with a bunch of <s>
    void StartSentence() {
      for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
        *i = kBOS;
      }
    }

    // Completes the pending n-gram with `word` and either commits it (count 1)
    // or bumps the count of the identical n-gram already in this block.
    // Afterwards the last order-1 words become the context of the next n-gram.
    void Append(WordIndex word) {
      *(gram_.end() - 1) = word;
      Dedupe::MutableIterator at;
      bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
      if (found) {
        // Already present.
        NGram already(at->key, gram_.Order());
        ++(already.Count());
        // Shift left by one so the scratch slot now holds the next context.
        std::memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
        return;
      }
      // Complete the write.
      gram_.Count() = 1;
      // Prepare the next n-gram.
      if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
        NGram last(gram_);
        gram_.NextInMemory();
        std::copy(last.begin() + 1, last.end(), gram_.begin());
        return;
      }
      // Block end.  Need to store the context in a temporary buffer, because
      // the current block is handed off before the next one is available.
      std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
      dedupe_.Clear();
      block_->SetValidSize(block_size_);
      gram_.ReBase((++block_)->Get());
      std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
    }

  private:
    // Writes a unigram with count 0, moving to a fresh block if that filled
    // the current one.
    void AddUnigramWord(WordIndex index) {
      *gram_.begin() = index;
      gram_.Count() = 0;
      gram_.NextInMemory();
      if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
        block_->SetValidSize(block_size_);
        gram_.ReBase((++block_)->Get());
      }
    }

    util::stream::Link block_;

    // Scratch n-gram: the next write position in the current block.
    NGram gram_;

    // This is the memory behind the invalid value in dedupe_.
    std::vector<WordIndex> dedupe_invalid_;
    // Hash table combiner implementation.
    Dedupe dedupe_;

    // Small buffer to hold existing ngrams when shifting across a block boundary.
    boost::scoped_array<WordIndex> buffer_;

    const std::size_t block_size_;
};

} // namespace

// Bytes of dedupe table per byte of n-gram storage for the given order;
// presumably used by the caller to budget dedupe memory alongside block memory.
float CorpusCount::DedupeMultiplier(std::size_t order) {
  return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
}

// Initial vocabulary table memory for an expected vocabulary size.
std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
  return VocabHandout::MemUsage(vocab_estimate);
}

CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
  : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
    dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
    dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
}

// Tokenizes every input line, assigns vocabulary ids and streams counted
// n-grams into `position`.  On entry type_count_ is read as the vocabulary
// size estimate.  On exit token_count_ holds the number of corpus tokens (the
// </s> appended per line is not counted) and type_count_ holds the vocabulary
// size including <unk>, <s> and </s>.
void CorpusCount::Run(const util::stream::ChainPosition &position) {
  UTIL_TIMER("(%w s) Counted n-grams\n");

  VocabHandout vocab(vocab_write_, type_count_);
  token_count_ = 0;
  type_count_ = 0;
  const WordIndex end_sentence = vocab.Lookup("</s>");
  Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
  uint64_t count = 0;

  // Token separators: NUL, tab, newline, carriage return and space.  The
  // explicit "\0" and the literal's terminating NUL both map to byte 0.
  bool delimiters[256];
  std::memset(delimiters, 0, sizeof(delimiters));
  const char kDelimiterSet[] = "\0\t\n\r ";
  for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) {
    delimiters[static_cast<unsigned char>(*i)] = true;
  }

  try {
    while (true) {
      StringPiece line(from_.ReadLine());
      writer.StartSentence();
      for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
        WordIndex word = vocab.Lookup(*w);
        // Ids 0-2 are <unk>, <s>, </s>; the corpus may not contain them literally.
        UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus.  I plan to support models containing <unk> in the future.");
        writer.Append(word);
        ++count;
      }
      writer.Append(end_sentence);
    }
  } catch (const util::EndOfFileException &) {
    // Normal termination: the read loop ends when EndOfFileException is thrown
    // (presumably by ReadLine at end of input).
  }
  token_count_ = count;
  type_count_ = vocab.Size();
}

} // namespace builder
} // namespace lm