diff options
Diffstat (limited to 'klm/lm/trie_sort.cc')
-rw-r--r-- | klm/lm/trie_sort.cc | 217 |
1 files changed, 123 insertions, 94 deletions
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index bb126f18..b80fed02 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -14,6 +14,7 @@ #include <algorithm> #include <cstring> #include <cstdio> +#include <cstdlib> #include <deque> #include <limits> #include <vector> @@ -22,14 +23,6 @@ namespace lm { namespace ngram { namespace trie { -const char *kContextSuffix = "_contexts"; - -FILE *OpenOrThrow(const char *name, const char *mode) { - FILE *ret = fopen(name, mode); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); - return ret; -} - void WriteOrThrow(FILE *to, const void *data, size_t size) { assert(size); if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); @@ -78,28 +71,29 @@ class PartialViewProxy { typedef util::ProxyIterator<PartialViewProxy> PartialIter; -std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch; - std::string ret(assembled.str()); - util::scoped_fd out(util::CreateOrThrow(ret.c_str())); - util::WriteOrThrow(out.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); - return ret; +FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) { + util::scoped_fd file(maker.Make()); + util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); + return util::FDOpenOrThrow(file); } -void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { +FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) { const size_t context_size = sizeof(WordIndex) * (order - 1); // Sort just the contexts using the same memory. PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); - std::sort(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1))); +#if defined(_WIN32) || defined(_WIN64) + std::stable_sort +#else + std::sort +#endif + (context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1))); - std::string name(ngram_file_name + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); + util::scoped_FILE out(maker.MakeFile()); // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. - if (context_begin == context_end) return; + if (context_begin == context_end) return out.release(); PartialIter i(context_begin); WriteOrThrow(out.get(), i->Data(), context_size); const void *previous = i->Data(); @@ -110,6 +104,7 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil previous = i->Data(); } } + return out.release(); } struct ThrowCombine { @@ -125,14 +120,12 @@ struct FirstCombine { } }; -template <class Combine> void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order, const Combine &combine = ThrowCombine()) { +template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) { std::size_t entry_size = sizeof(WordIndex) * order + weights_size; RecordReader first, second; - first.Init(first_name.c_str(), entry_size); - util::RemoveOrThrow(first_name.c_str()); - second.Init(second_name.c_str(), entry_size); - util::RemoveOrThrow(second_name.c_str()); - util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); + first.Init(first_file, entry_size); + second.Init(second_file, entry_size); + util::scoped_FILE out_file(maker.MakeFile()); EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { @@ -149,67 +142,14 @@ template <class Combine> void MergeSortedFiles(const std::string &first_name, co for (RecordReader &remains = (first ? first : second); remains; ++remains) { WriteOrThrow(out_file.get(), remains.Data(), entry_size); } -} - -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { - ReadNGramHeader(f, order); - const size_t count = counts[order - 1]; - // Size of weights. Does it include backoff? - const size_t words_size = sizeof(WordIndex) * order; - const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); - const size_t entry_size = words_size + weights_size; - const size_t batch_size = std::min(count, mem.size() / entry_size); - uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get()); - std::deque<std::string> files; - for (std::size_t batch = 0, done = 0; done < count; ++batch) { - uint8_t *out = begin; - uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; - if (order == counts.size()) { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); - } - } else { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); - } - } - // Sort full records by full n-gram. - util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - // parallel_sort uses too much RAM - std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order))); - files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order)); - WriteContextFile(begin, out_end, files.back(), entry_size, order); - - done += (out_end - begin) / entry_size; - } - - // All individual files created. Merge them. - - std::size_t merge_count = 0; - while (files.size() > 1) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++); - files.push_back(assembled.str()); - MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); - MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine()); - files.pop_front(); - files.pop_front(); - } - if (!files.empty()) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(order) << "_merged"; - std::string merged_name(assembled.str()); - if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); - std::string context_name = files[0] + kContextSuffix; - merged_name += kContextSuffix; - if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); - } + return out_file.release(); } } // namespace -void RecordReader::Init(const std::string &name, std::size_t entry_size) { - file_.reset(OpenOrThrow(name.c_str(), "r+")); +void RecordReader::Init(FILE *file, std::size_t entry_size) { + rewind(file); + file_ = file; data_.reset(malloc(entry_size)); UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer"); remains_ = true; @@ -219,20 +159,29 @@ void RecordReader::Init(const std::string &name, std::size_t entry_size) { void RecordReader::Overwrite(const void *start, std::size_t amount) { long internal = (uint8_t*)start - (uint8_t*)data_.get(); - UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); - WriteOrThrow(file_.get(), start, amount); + UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); + WriteOrThrow(file_, start, amount); long forward = entry_size_ - internal - amount; - if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); +#if !defined(_WIN32) && !defined(_WIN64) + if (forward) +#endif + UTIL_THROW_IF(fseek(file_, forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); } -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { +void RecordReader::Rewind() { + rewind(file_); + remains_ = true; + ++*this; +} + +SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + util::TempMaker maker(file_prefix); PositiveProbWarn warn(config.positive_log_probability); + unigram_.reset(maker.Make()); { - std::string unigram_name = file_prefix + "unigrams"; - util::scoped_fd unigram_file; // In case <unk> appears. - size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); + size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_.get(), size_out), size_out); Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; @@ -246,16 +195,96 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); buffer = std::min<size_t>(buffer, buffer_use); - util::scoped_memory mem; - mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); + util::scoped_malloc mem; + mem.reset(malloc(buffer)); if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); + ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer); } ReadEnd(f); } +namespace { +class Closer { + public: + explicit Closer(std::deque<FILE*> &files) : files_(files) {} + + ~Closer() { + for (std::deque<FILE*>::iterator i = files_.begin(); i != files_.end(); ++i) { + util::scoped_FILE deleter(*i); + } + } + + void PopFront() { + util::scoped_FILE deleter(files_.front()); + files_.pop_front(); + } + private: + std::deque<FILE*> &files_; +}; +} // namespace + +void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { + ReadNGramHeader(f, order); + const size_t count = counts[order - 1]; + // Size of weights. Does it include backoff? + const size_t words_size = sizeof(WordIndex) * order; + const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); + const size_t entry_size = words_size + weights_size; + const size_t batch_size = std::min(count, mem_size / entry_size); + uint8_t *const begin = reinterpret_cast<uint8_t*>(mem); + + std::deque<FILE*> files, contexts; + Closer files_closer(files), contexts_closer(contexts); + + for (std::size_t batch = 0, done = 0; done < count; ++batch) { + uint8_t *out = begin; + uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; + if (order == counts.size()) { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); + } + } else { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); + } + } + // Sort full records by full n-gram. + util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); + // parallel_sort uses too much RAM. TODO: figure out why windows sort doesn't like my proxies. +#if defined(_WIN32) || defined(_WIN64) + std::stable_sort +#else + std::sort +#endif + (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order))); + files.push_back(DiskFlush(begin, out_end, maker)); + contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order)); + + done += (out_end - begin) / entry_size; + } + + // All individual files created. Merge them. + + while (files.size() > 1) { + files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine())); + files_closer.PopFront(); + files_closer.PopFront(); + contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine())); + contexts_closer.PopFront(); + contexts_closer.PopFront(); + } + + if (!files.empty()) { + // Steal from closers. + full_[order - 2].reset(files.front()); + files.pop_front(); + context_[order - 2].reset(contexts.front()); + contexts.pop_front(); + } +} + } // namespace trie } // namespace ngram } // namespace lm |