diff options
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 369 |
1 files changed, 272 insertions, 97 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 12294682..3aeeeca3 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,6 +1,7 @@ /* This is where the trie is built. It's on-disk. */ #include "lm/search_trie.hh" +#include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/read_arpa.hh" #include "lm/trie.hh" @@ -13,10 +14,10 @@ #include "util/scoped.hh" #include <algorithm> +#include <cmath> #include <cstring> #include <cstdio> #include <deque> -#include <iostream> #include <limits> //#include <parallel/algorithm> #include <vector> @@ -152,11 +153,11 @@ void ReadOrThrow(FILE *from, void *data, size_t size) { if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size); } +const std::size_t kCopyBufSize = 512; void CopyOrThrow(FILE *from, FILE *to, size_t size) { - const size_t kBufSize = 512; - char buf[kBufSize]; - for (size_t i = 0; i < size; i += kBufSize) { - std::size_t amount = std::min(size - i, kBufSize); + char buf[std::min<size_t>(size, kCopyBufSize)]; + for (size_t i = 0; i < size; i += kCopyBufSize) { + std::size_t amount = std::min(size - i, kCopyBufSize); ReadOrThrow(from, buf, amount); WriteOrThrow(to, buf, amount); } @@ -172,8 +173,10 @@ std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::str if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); // Compress entries that being with the same (order-1) words. for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) { - const uint8_t *group_end = group_begin; - for (group_end += entry_size; (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); group_end += entry_size) {} + const uint8_t *group_end; + for (group_end = group_begin + entry_size; + (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); + group_end += entry_size) {} WriteOrThrow(out.get(), group_begin, prefix_size); WordIndex group_size = (group_end - group_begin) / entry_size; WriteOrThrow(out.get(), &group_size, sizeof(group_size)); @@ -188,7 +191,7 @@ std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::str class SortedFileReader { public: - SortedFileReader() {} + SortedFileReader() : ended_(false) {} void Init(const std::string &name, unsigned char order) { file_.reset(fopen(name.c_str(), "r")); @@ -206,25 +209,39 @@ class SortedFileReader { std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); } void NextHeader() { - if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get()) && !Ended()) { - UTIL_THROW(util::ErrnoException, "Short read of counts"); + if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get())) { + if (feof(file_.get())) { + ended_ = true; + } else { + UTIL_THROW(util::ErrnoException, "Short read of counts"); + } } } - void ReadCount(WordIndex &to) { - ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); + WordIndex ReadCount() { + WordIndex ret; + ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); + return ret; } - void ReadWord(WordIndex &to) { - ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); + WordIndex ReadWord() { + WordIndex ret; + ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); + return ret; } - template <class Weights> void ReadWeights(Weights &to) { - ReadOrThrow(file_.get(), &to, sizeof(Weights)); + template <class Weights> void ReadWeights(Weights &weights) { + ReadOrThrow(file_.get(), &weights, sizeof(Weights)); } bool Ended() { - return feof(file_.get()); + return ended_; + } + + void Rewind() { + rewind(file_.get()); + ended_ = false; + NextHeader(); } FILE *File() { return file_.get(); } @@ -233,12 +250,13 @@ class SortedFileReader { util::scoped_FILE file_; std::vector<WordIndex> header_; + + bool ended_; }; void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) { WriteOrThrow(to, from.Header(), from.HeaderBytes()); - WordIndex count; - from.ReadCount(count); + WordIndex count = from.ReadCount(); WriteOrThrow(to, &count, sizeof(WordIndex)); CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); @@ -263,25 +281,23 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha } // Merge at the entry level. WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes()); - WordIndex first_count, second_count; - first.ReadCount(first_count); second.ReadCount(second_count); + WordIndex first_count = first.ReadCount(), second_count = second.ReadCount(); WordIndex total_count = first_count + second_count; WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex)); - WordIndex first_word, second_word; - first.ReadWord(first_word); second.ReadWord(second_word); + WordIndex first_word = first.ReadWord(), second_word = second.ReadWord(); WordIndex first_index = 0, second_index = 0; while (true) { if (first_word < second_word) { WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); CopyOrThrow(first.File(), out_file.get(), weights_size); if (++first_index == first_count) break; - first.ReadWord(first_word); + first_word = first.ReadWord(); } else { WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); CopyOrThrow(second.File(), out_file.get(), weights_size); if (++second_index == second_count) break; - second.ReadWord(second_word); + second_word = second.ReadWord(); } } if (first_index == first_count) { @@ -358,75 +374,219 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, { std::string unigram_name = file_prefix + "unigrams"; util::scoped_fd unigram_file; - util::scoped_mmap unigram_mmap; - unigram_mmap.reset(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); } util::scoped_memory mem; - mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED); + mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size()); ReadEnd(f); } -struct RecursiveInsertParams { - WordIndex *words; - SortedFileReader *files; - unsigned char max_order; - // This is an array of size order - 2. - BitPackedMiddle *middle; - // This has exactly one entry. - BitPackedLongest *longest; -}; - -uint64_t RecursiveInsert(RecursiveInsertParams ¶ms, unsigned char order) { - SortedFileReader &file = params.files[order - 2]; - const uint64_t ret = (order == params.max_order) ? params.longest->InsertIndex() : params.middle[order - 2].InsertIndex(); - if (std::memcmp(params.words, file.Header(), sizeof(WordIndex) * (order - 1))) - return ret; - WordIndex count; - file.ReadCount(count); - WordIndex key; - if (order == params.max_order) { - Prob value; - for (WordIndex i = 0; i < count; ++i) { - file.ReadWord(key); - file.ReadWeights(value); - params.longest->Insert(key, value.prob); - } - file.NextHeader(); - return ret; - } - ProbBackoff value; - for (WordIndex i = 0; i < count; ++i) { - file.ReadWord(params.words[order - 1]); - file.ReadWeights(value); - params.middle[order - 2].Insert( - params.words[order - 1], - value.prob, - value.backoff, - RecursiveInsert(params, order + 1)); +bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) { + for (; words != words_end; ++words, ++header) { + if (*words != *header) { + assert(*words <= *header); + return false; + } } - file.NextHeader(); - return ret; + return true; } -void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, std::ostream *messages, TrieSearch &out) { - UnigramValue *unigrams = out.unigram.Raw(); - // Load unigrams. Leave the next pointers uninitialized. - { - std::string name(file_prefix + "unigrams"); - util::scoped_FILE file(fopen(name.c_str(), "r")); - if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed"); - for (WordIndex i = 0; i < counts[0]; ++i) { - ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); +class JustCount { + public: + JustCount(UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) + : counts_(counts), longest_counts_(counts + order - 1) {} + + void Unigrams(WordIndex begin, WordIndex end) { + counts_[0] += end - begin; } - unlink(name.c_str()); + + void MiddleBlank(const unsigned char mid_idx, WordIndex /* idx */) { + ++counts_[mid_idx + 1]; + } + + void Middle(const unsigned char mid_idx, WordIndex /*key*/, const ProbBackoff &/*weights*/) { + ++counts_[mid_idx + 1]; + } + + void Longest(WordIndex /*key*/, Prob /*prob*/) { + ++*longest_counts_; + } + + // Unigrams wrote one past. + void Cleanup() { + --counts_[0]; + } + + private: + uint64_t *const counts_, *const longest_counts_; +}; + +class WriteEntries { + public: + WriteEntries(UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + unigrams_(unigrams), + middle_(middle), + longest_(longest), + bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)) {} + + void Unigrams(WordIndex begin, WordIndex end) { + uint64_t next = bigram_pack_.InsertIndex(); + for (UnigramValue *i = unigrams_ + begin; i < unigrams_ + end; ++i) { + i->next = next; + } + } + + void MiddleBlank(const unsigned char mid_idx, WordIndex key) { + middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff); + } + + void Middle(const unsigned char mid_idx, WordIndex key, const ProbBackoff &weights) { + middle_[mid_idx].Insert(key, weights.prob, weights.backoff); + } + + void Longest(WordIndex key, Prob prob) { + longest_.Insert(key, prob.prob); + } + + void Cleanup() {} + + private: + UnigramValue *const unigrams_; + BitPackedMiddle *const middle_; + BitPackedLongest &longest_; + BitPacked &bigram_pack_; +}; + +template <class Doing> class RecursiveInsert { + public: + RecursiveInsert(SortedFileReader *inputs, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : + doing_(unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), words_(new WordIndex[order]), order_minus_2_(order - 2) { + } + + // Outer unigram loop. + void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) { + util::ErsatzProgress progress(progress_out, message, unigram_count + 1); + for (words_[0] = 0; ; ++words_[0]) { + WordIndex min_continue = unigram_count; + for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) { + if (other->Ended()) continue; + min_continue = std::min(min_continue, other->Header()[0]); + } + // This will write at unigram_count. This is by design so that the next pointers will make sense. + doing_.Unigrams(words_[0], min_continue + 1); + if (min_continue == unigram_count) break; + progress += min_continue - words_[0]; + words_[0] = min_continue; + Middle(0); + } + doing_.Cleanup(); + } + + private: + void Middle(const unsigned char mid_idx) { + // (mid_idx + 2)-gram. + if (mid_idx == order_minus_2_) { + Longest(); + return; + } + // Orders [2, order) + + SortedFileReader &reader = inputs_[mid_idx]; + + if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + mid_idx + 1, reader.Header())) { + // This order doesn't have a header match, but longer ones might. + MiddleAllBlank(mid_idx); + return; + } + + // There is a header match. + WordIndex count = reader.ReadCount(); + WordIndex current = reader.ReadWord(); + while (count) { + WordIndex min_continue = std::numeric_limits<WordIndex>::max(); + for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { + if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header())) + min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); + } + while (true) { + if (current > min_continue) { + doing_.MiddleBlank(mid_idx, min_continue); + words_[mid_idx + 1] = min_continue; + Middle(mid_idx + 1); + break; + } + ProbBackoff weights; + reader.ReadWeights(weights); + doing_.Middle(mid_idx, current, weights); + --count; + if (current == min_continue) { + words_[mid_idx + 1] = min_continue; + Middle(mid_idx + 1); + if (count) current = reader.ReadWord(); + break; + } + if (!count) break; + current = reader.ReadWord(); + } + } + // Count is now zero. Finish off remaining blanks. + MiddleAllBlank(mid_idx); + reader.NextHeader(); + } + + void MiddleAllBlank(const unsigned char mid_idx) { + while (true) { + WordIndex min_continue = std::numeric_limits<WordIndex>::max(); + for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { + if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header())) + min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); + } + if (min_continue == std::numeric_limits<WordIndex>::max()) return; + doing_.MiddleBlank(mid_idx, min_continue); + words_[mid_idx + 1] = min_continue; + Middle(mid_idx + 1); + } + } + + void Longest() { + SortedFileReader &reader = *(inputs_end_ - 1); + if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + order_minus_2_ + 1, reader.Header())) return; + WordIndex count = reader.ReadCount(); + for (WordIndex i = 0; i < count; ++i) { + WordIndex word = reader.ReadWord(); + Prob prob; + reader.ReadWeights(prob); + doing_.Longest(word, prob); + } + reader.NextHeader(); + return; + } + + Doing doing_; + + SortedFileReader *inputs_; + SortedFileReader *inputs_end_; + + util::scoped_array<WordIndex> words_; + + const unsigned char order_minus_2_; +}; + +void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) { + if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]); + if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant"); + for (unsigned char i = 0; i < initial.size(); ++i) { + if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen"); } +} - // inputs[0] is bigrams. +void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) { SortedFileReader inputs[counts.size() - 1]; + for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; assembled << file_prefix << static_cast<unsigned int>(i) << "_merged"; @@ -434,36 +594,49 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun unlink(assembled.str().c_str()); } - // words[0] is unigrams. - WordIndex words[counts.size()]; - RecursiveInsertParams params; - params.words = words; - params.files = inputs; - params.max_order = static_cast<unsigned char>(counts.size()); - params.middle = &*out.middle.begin(); - params.longest = &out.longest; + std::vector<uint64_t> fixed_counts(counts.size()); + { + RecursiveInsert<JustCount> counter(inputs, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); + } + SanityCheckCounts(counts, fixed_counts); + + out.SetupMemory(GrowForSearch(config, TrieSearch::kModelType, fixed_counts, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + + for (unsigned char i = 2; i <= counts.size(); ++i) { + inputs[i-2].Rewind(); + } + UnigramValue *unigrams = out.unigram.Raw(); + // Fill entries except unigram probabilities. { - util::ErsatzProgress progress(messages, "Building trie", counts[0]); - for (words[0] = 0; words[0] < counts[0]; ++words[0], ++progress) { - unigrams[words[0]].next = RecursiveInsert(params, 2); + RecursiveInsert<WriteEntries> inserter(inputs, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + inserter.Apply(config.messages, "Building trie", fixed_counts[0]); + } + + // Fill unigram probabilities. + { + std::string name(file_prefix + "unigrams"); + util::scoped_FILE file(fopen(name.c_str(), "r")); + if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed"); + for (WordIndex i = 0; i < counts[0]; ++i) { + ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); } + unlink(name.c_str()); } /* Set ending offsets so the last entry will be sized properly */ + // Last entry for unigrams was already set. if (!out.middle.empty()) { - unigrams[counts[0]].next = out.middle.front().InsertIndex(); for (size_t i = 0; i < out.middle.size() - 1; ++i) { out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); } out.middle.back().FinishedLoading(out.longest.InsertIndex()); - } else { - unigrams[counts[0]].next = out.longest.InsertIndex(); } } } // namespace -void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab) { +void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -473,7 +646,8 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const temporary_directory = file; } // Null on end is kludge to ensure null termination. - temporary_directory += "-tmp-XXXXXX\0"; + temporary_directory += "-tmp-XXXXXX"; + temporary_directory += '\0'; if (!mkdtemp(&temporary_directory[0])) { UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str()); } @@ -483,9 +657,10 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const temporary_directory += '/'; // At least 1MB sorting memory. ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory.c_str(), counts, config.messages, *this); - if (rmdir(temporary_directory.c_str())) { - std::cerr << "Failed to delete " << temporary_directory << std::endl; + + BuildTrie(temporary_directory.c_str(), counts, config, *this, backing); + if (rmdir(temporary_directory.c_str()) && config.messages) { + *config.messages << "Failed to delete " << temporary_directory << std::endl; } } |