diff options
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 200 |
1 files changed, 146 insertions, 54 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 182e27f5..12294682 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,3 +1,4 @@ +/* This is where the trie is built. It's on-disk. */ #include "lm/search_trie.hh" #include "lm/lm_exception.hh" @@ -8,6 +9,7 @@ #include "lm/word_index.hh" #include "util/ersatz_progress.hh" #include "util/file_piece.hh" +#include "util/proxy_iterator.hh" #include "util/scoped.hh" #include <algorithm> @@ -30,43 +32,119 @@ namespace ngram { namespace trie { namespace { -template <unsigned char Order> class FullEntry { +/* An entry is a n-gram with probability. It consists of: + * WordIndex[order] + * float probability + * backoff probability (omitted for highest order n-gram) + * These are stored consecutively in memory. We want to sort them. + * + * The problem is the length depends on order (but all n-grams being compared + * have the same order). Allocating each entry on the heap (i.e. std::vector + * or std::string) then sorting pointers is the normal solution. But that's + * too memory inefficient. A lot of this code is just here to force std::sort + * to work with records where length is specified at runtime (and avoid using + * Boost for LM code). I could have used qsort, but the point is to also + * support __gnu_cxx:parallel_sort which doesn't have a qsort version. + */ + +class EntryIterator { public: - typedef ProbBackoff Weights; - static const unsigned char kOrder = Order; + EntryIterator() {} - // reverse order - WordIndex words[Order]; - Weights weights; + EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast<uint8_t*>(ptr)), size_(size) {} - bool operator<(const FullEntry<Order> &other) const { - for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { - if (*i < *j) return true; - if (*i > *j) return false; - } - return false; + bool operator==(const EntryIterator &other) const { + return ptr_ == other.ptr_; + } + bool operator<(const EntryIterator &other) const { + return ptr_ < other.ptr_; + } + EntryIterator &operator+=(std::ptrdiff_t amount) { + ptr_ += amount * size_; + return *this; + } + std::ptrdiff_t operator-(const EntryIterator &other) const { + return (ptr_ - other.ptr_) / size_; } + + const void *Data() const { return ptr_; } + void *Data() { return ptr_; } + std::size_t EntrySize() const { return size_; } + + private: + uint8_t *ptr_; + std::size_t size_; }; -template <unsigned char Order> class ProbEntry { +class EntryProxy { public: - typedef Prob Weights; - static const unsigned char kOrder = Order; + EntryProxy() {} + + EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} + + operator std::string() const { + return std::string(reinterpret_cast<const char*>(inner_.Data()), inner_.EntrySize()); + } + + EntryProxy &operator=(const EntryProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); + return *this; + } + + EntryProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), inner_.EntrySize()); + return *this; + } + + const WordIndex *Indices() const { + return static_cast<const WordIndex*>(inner_.Data()); + } + + private: + friend class util::ProxyIterator<EntryProxy>; + + typedef std::string value_type; - // reverse order - WordIndex words[Order]; - Weights weights; + typedef EntryIterator InnerIterator; + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; - bool operator<(const ProbEntry<Order> &other) const { - for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { - if (*i < *j) return true; - if (*i > *j) return false; +typedef util::ProxyIterator<EntryProxy> NGramIter; + +class CompareRecords : public std::binary_function<const EntryProxy &, const EntryProxy &, bool> { + public: + explicit CompareRecords(unsigned char order) : order_(order) {} + + bool operator()(const EntryProxy &first, const EntryProxy &second) const { + return Compare(first.Indices(), second.Indices()); + } + bool operator()(const EntryProxy &first, const std::string &second) const { + return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data())); + } + bool operator()(const std::string &first, const EntryProxy &second) const { + return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices()); + } + bool operator()(const std::string &first, const std::string &second) const { + return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(first.data())); + } + + private: + bool Compare(const WordIndex *first, const WordIndex *second) const { + const WordIndex *end = first + order_; + for (; first != end; ++first, ++second) { + if (*first < *second) return true; + if (*first > *second) return false; } return false; } + + unsigned char order_; }; void WriteOrThrow(FILE *to, const void *data, size_t size) { + assert(size); if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); } @@ -84,21 +162,24 @@ void CopyOrThrow(FILE *from, FILE *to, size_t size) { } } -template <class Entry> std::string DiskFlush(const Entry *begin, const Entry *end, const std::string &file_prefix, std::size_t batch) { +std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) { + const std::size_t entry_size = sizeof(WordIndex) * order + weights_size; + const std::size_t prefix_size = sizeof(WordIndex) * (order - 1); std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << '_' << batch; + assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch; std::string ret(assembled.str()); util::scoped_FILE out(fopen(ret.c_str(), "w")); if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); - for (const Entry *group_begin = begin; group_begin != end;) { - const Entry *group_end = group_begin; - for (++group_end; (group_end != end) && !memcmp(group_begin->words, group_end->words, sizeof(WordIndex) * (Entry::kOrder - 1)); ++group_end) {} - WriteOrThrow(out.get(), group_begin->words, sizeof(WordIndex) * (Entry::kOrder - 1)); - WordIndex group_size = group_end - group_begin; + // Compress entries that being with the same (order-1) words. + for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) { + const uint8_t *group_end = group_begin; + for (group_end += entry_size; (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); group_end += entry_size) {} + WriteOrThrow(out.get(), group_begin, prefix_size); + WordIndex group_size = (group_end - group_begin) / entry_size; WriteOrThrow(out.get(), &group_size, sizeof(group_size)); - for (const Entry *i = group_begin; i != group_end; ++i) { - WriteOrThrow(out.get(), &i->words[Entry::kOrder - 1], sizeof(WordIndex)); - WriteOrThrow(out.get(), &i->weights, sizeof(typename Entry::Weights)); + for (const uint8_t *i = group_begin; i != group_end; i += entry_size) { + WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex)); + WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size); } group_begin = group_end; } @@ -219,25 +300,37 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha } } -template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix) { - ConvertToSorted<FullEntry<Entry::kOrder - 1> >(f, vocab, counts, mem, file_prefix); - - ReadNGramHeader(f, Entry::kOrder); - const size_t count = counts[Entry::kOrder - 1]; - const size_t batch_size = std::min(count, mem.size() / sizeof(Entry)); - Entry *const begin = reinterpret_cast<Entry*>(mem.get()); +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { + if (order == 1) return; + ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1); + + ReadNGramHeader(f, order); + const size_t count = counts[order - 1]; + // Size of weights. Does it include backoff? + const size_t words_size = sizeof(WordIndex) * order; + const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); + const size_t entry_size = words_size + weights_size; + const size_t batch_size = std::min(count, mem.size() / entry_size); + uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get()); std::deque<std::string> files; for (std::size_t batch = 0, done = 0; done < count; ++batch) { - Entry *out = begin; - Entry *out_end = out + std::min(count - done, batch_size); - for (; out != out_end; ++out) { - ReadNGram(f, Entry::kOrder, vocab, out->words, out->weights); + uint8_t *out = begin; + uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; + if (order == counts.size()) { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size)); + } + } else { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size)); + } } - //__gnu_parallel::sort(begin, out_end); - std::sort(begin, out_end); + // TODO: __gnu_parallel::sort here. + EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); + std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); - files.push_back(DiskFlush(begin, out_end, file_prefix, batch)); - done += out_end - begin; + files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); + done += (out_end - begin) / entry_size; } // All individual files created. Merge them. @@ -245,9 +338,9 @@ template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVoca std::size_t merge_count = 0; while (files.size() > 1) { std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merge_" << (merge_count++); + assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); - MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), sizeof(typename Entry::Weights), Entry::kOrder); + MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order); if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); files.pop_front(); if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); @@ -255,14 +348,12 @@ template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVoca } if (!files.empty()) { std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merged"; + assembled << file_prefix << static_cast<unsigned int>(order) << "_merged"; std::string merged_name(assembled.str()); if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); } } -template <> void ConvertToSorted<FullEntry<1> >(util::FilePiece &/*f*/, const SortedVocabulary &/*vocab*/, const std::vector<uint64_t> &/*counts*/, util::scoped_memory &/*mem*/, const std::string &/*file_prefix*/) {} - void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { { std::string unigram_name = file_prefix + "unigrams"; @@ -275,7 +366,7 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, util::scoped_memory mem; mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED); if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); - ConvertToSorted<ProbEntry<5> >(f, vocab, counts, mem, file_prefix); + ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size()); ReadEnd(f); } @@ -390,7 +481,8 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const temporary_directory.resize(strlen(temporary_directory.c_str())); // Add directory delimiter. Assumes a real operating system. temporary_directory += '/'; - ARPAToSortedFiles(f, counts, config.building_memory, temporary_directory.c_str(), vocab); + // At least 1MB sorting memory. + ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); BuildTrie(temporary_directory.c_str(), counts, config.messages, *this); if (rmdir(temporary_directory.c_str())) { std::cerr << "Failed to delete " << temporary_directory << std::endl; |