summaryrefslogtreecommitdiff
path: root/klm/lm/trie_sort.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie_sort.cc')
-rw-r--r--klm/lm/trie_sort.cc217
1 files changed, 123 insertions, 94 deletions
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index bb126f18..b80fed02 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -14,6 +14,7 @@
#include <algorithm>
#include <cstring>
#include <cstdio>
+#include <cstdlib>
#include <deque>
#include <limits>
#include <vector>
@@ -22,14 +23,6 @@ namespace lm {
namespace ngram {
namespace trie {
-const char *kContextSuffix = "_contexts";
-
-FILE *OpenOrThrow(const char *name, const char *mode) {
- FILE *ret = fopen(name, mode);
- if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode);
- return ret;
-}
-
void WriteOrThrow(FILE *to, const void *data, size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
@@ -78,28 +71,29 @@ class PartialViewProxy {
typedef util::ProxyIterator<PartialViewProxy> PartialIter;
-std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
- std::string ret(assembled.str());
- util::scoped_fd out(util::CreateOrThrow(ret.c_str()));
- util::WriteOrThrow(out.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
- return ret;
+FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) {
+ util::scoped_fd file(maker.Make());
+ util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
+ return util::FDOpenOrThrow(file);
}
-void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) {
+FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) {
const size_t context_size = sizeof(WordIndex) * (order - 1);
// Sort just the contexts using the same memory.
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
- std::sort(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
+#if defined(_WIN32) || defined(_WIN64)
+ std::stable_sort
+#else
+ std::sort
+#endif
+ (context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
- std::string name(ngram_file_name + kContextSuffix);
- util::scoped_FILE out(OpenOrThrow(name.c_str(), "w"));
+ util::scoped_FILE out(maker.MakeFile());
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
- if (context_begin == context_end) return;
+ if (context_begin == context_end) return out.release();
PartialIter i(context_begin);
WriteOrThrow(out.get(), i->Data(), context_size);
const void *previous = i->Data();
@@ -110,6 +104,7 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil
previous = i->Data();
}
}
+ return out.release();
}
struct ThrowCombine {
@@ -125,14 +120,12 @@ struct FirstCombine {
}
};
-template <class Combine> void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order, const Combine &combine = ThrowCombine()) {
+template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) {
std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
RecordReader first, second;
- first.Init(first_name.c_str(), entry_size);
- util::RemoveOrThrow(first_name.c_str());
- second.Init(second_name.c_str(), entry_size);
- util::RemoveOrThrow(second_name.c_str());
- util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w"));
+ first.Init(first_file, entry_size);
+ second.Init(second_file, entry_size);
+ util::scoped_FILE out_file(maker.MakeFile());
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
@@ -149,67 +142,14 @@ template <class Combine> void MergeSortedFiles(const std::string &first_name, co
for (RecordReader &remains = (first ? first : second); remains; ++remains) {
WriteOrThrow(out_file.get(), remains.Data(), entry_size);
}
-}
-
-void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) {
- ReadNGramHeader(f, order);
- const size_t count = counts[order - 1];
- // Size of weights. Does it include backoff?
- const size_t words_size = sizeof(WordIndex) * order;
- const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float));
- const size_t entry_size = words_size + weights_size;
- const size_t batch_size = std::min(count, mem.size() / entry_size);
- uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get());
- std::deque<std::string> files;
- for (std::size_t batch = 0, done = 0; done < count; ++batch) {
- uint8_t *out = begin;
- uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
- if (order == counts.size()) {
- for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn);
- }
- } else {
- for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
- }
- }
- // Sort full records by full n-gram.
- util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
- // parallel_sort uses too much RAM
- std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
- files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order));
- WriteContextFile(begin, out_end, files.back(), entry_size, order);
-
- done += (out_end - begin) / entry_size;
- }
-
- // All individual files created. Merge them.
-
- std::size_t merge_count = 0;
- while (files.size() > 1) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
- files.push_back(assembled.str());
- MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine());
- MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine());
- files.pop_front();
- files.pop_front();
- }
- if (!files.empty()) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";
- std::string merged_name(assembled.str());
- if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
- std::string context_name = files[0] + kContextSuffix;
- merged_name += kContextSuffix;
- if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str());
- }
+ return out_file.release();
}
} // namespace
-void RecordReader::Init(const std::string &name, std::size_t entry_size) {
- file_.reset(OpenOrThrow(name.c_str(), "r+"));
+void RecordReader::Init(FILE *file, std::size_t entry_size) {
+ rewind(file);
+ file_ = file;
data_.reset(malloc(entry_size));
UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer");
remains_ = true;
@@ -219,20 +159,29 @@ void RecordReader::Init(const std::string &name, std::size_t entry_size) {
void RecordReader::Overwrite(const void *start, std::size_t amount) {
long internal = (uint8_t*)start - (uint8_t*)data_.get();
- UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
- WriteOrThrow(file_.get(), start, amount);
+ UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
+ WriteOrThrow(file_, start, amount);
long forward = entry_size_ - internal - amount;
- if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision");
+#if !defined(_WIN32) && !defined(_WIN64)
+ if (forward)
+#endif
+ UTIL_THROW_IF(fseek(file_, forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision");
}
-void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+void RecordReader::Rewind() {
+ rewind(file_);
+ remains_ = true;
+ ++*this;
+}
+
+SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+ util::TempMaker maker(file_prefix);
PositiveProbWarn warn(config.positive_log_probability);
+ unigram_.reset(maker.Make());
{
- std::string unigram_name = file_prefix + "unigrams";
- util::scoped_fd unigram_file;
// In case <unk> appears.
- size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff);
- util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out);
+ size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff);
+ util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_.get(), size_out), size_out);
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
CheckSpecials(config, vocab);
if (!vocab.SawUnk()) ++counts[0];
@@ -246,16 +195,96 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin
buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
buffer = std::min<size_t>(buffer, buffer_use);
- util::scoped_memory mem;
- mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
+ util::scoped_malloc mem;
+ mem.reset(malloc(buffer));
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
for (unsigned char order = 2; order <= counts.size(); ++order) {
- ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn);
+ ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer);
}
ReadEnd(f);
}
+namespace {
+class Closer {
+ public:
+ explicit Closer(std::deque<FILE*> &files) : files_(files) {}
+
+ ~Closer() {
+ for (std::deque<FILE*>::iterator i = files_.begin(); i != files_.end(); ++i) {
+ util::scoped_FILE deleter(*i);
+ }
+ }
+
+ void PopFront() {
+ util::scoped_FILE deleter(files_.front());
+ files_.pop_front();
+ }
+ private:
+ std::deque<FILE*> &files_;
+};
+} // namespace
+
+void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
+ ReadNGramHeader(f, order);
+ const size_t count = counts[order - 1];
+ // Size of weights. Does it include backoff?
+ const size_t words_size = sizeof(WordIndex) * order;
+ const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float));
+ const size_t entry_size = words_size + weights_size;
+ const size_t batch_size = std::min(count, mem_size / entry_size);
+ uint8_t *const begin = reinterpret_cast<uint8_t*>(mem);
+
+ std::deque<FILE*> files, contexts;
+ Closer files_closer(files), contexts_closer(contexts);
+
+ for (std::size_t batch = 0, done = 0; done < count; ++batch) {
+ uint8_t *out = begin;
+ uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
+ if (order == counts.size()) {
+ for (; out != out_end; out += entry_size) {
+ ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn);
+ }
+ } else {
+ for (; out != out_end; out += entry_size) {
+ ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
+ }
+ }
+ // Sort full records by full n-gram.
+ util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
+ // parallel_sort uses too much RAM. TODO: figure out why windows sort doesn't like my proxies.
+#if defined(_WIN32) || defined(_WIN64)
+ std::stable_sort
+#else
+ std::sort
+#endif
+ (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
+ files.push_back(DiskFlush(begin, out_end, maker));
+ contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order));
+
+ done += (out_end - begin) / entry_size;
+ }
+
+ // All individual files created. Merge them.
+
+ while (files.size() > 1) {
+ files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine()));
+ files_closer.PopFront();
+ files_closer.PopFront();
+ contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine()));
+ contexts_closer.PopFront();
+ contexts_closer.PopFront();
+ }
+
+ if (!files.empty()) {
+ // Steal from closers.
+ full_[order - 2].reset(files.front());
+ files.pop_front();
+ context_[order - 2].reset(contexts.front());
+ contexts.pop_front();
+ }
+}
+
} // namespace trie
} // namespace ngram
} // namespace lm