summaryrefslogtreecommitdiff
path: root/klm/lm/trie_sort.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie_sort.cc')
-rw-r--r--klm/lm/trie_sort.cc27
1 files changed, 13 insertions, 14 deletions
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index 8663e94e..dc542bb3 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -65,13 +65,13 @@ class PartialViewProxy {
typedef util::ProxyIterator<PartialViewProxy> PartialIter;
-FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) {
- util::scoped_fd file(maker.Make());
+FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) {
+ util::scoped_fd file(util::MakeTemp(temp_prefix));
util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
return util::FDOpenOrThrow(file);
}
-FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) {
+FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) {
const size_t context_size = sizeof(WordIndex) * (order - 1);
// Sort just the contexts using the same memory.
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
@@ -84,7 +84,7 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make
#endif
(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
- util::scoped_FILE out(maker.MakeFile());
+ util::scoped_FILE out(util::FMakeTemp(temp_prefix));
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
if (context_begin == context_end) return out.release();
@@ -114,12 +114,12 @@ struct FirstCombine {
}
};
-template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) {
+template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) {
std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
RecordReader first, second;
first.Init(first_file, entry_size);
second.Init(second_file, entry_size);
- util::scoped_FILE out_file(maker.MakeFile());
+ util::scoped_FILE out_file(util::FMakeTemp(temp_prefix));
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
@@ -177,9 +177,8 @@ void RecordReader::Rewind() {
}
SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
- util::TempMaker maker(file_prefix);
PositiveProbWarn warn(config.positive_log_probability);
- unigram_.reset(maker.Make());
+ unigram_.reset(util::MakeTemp(file_prefix));
{
// In case <unk> appears.
size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff);
@@ -202,7 +201,7 @@ SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<u
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
for (unsigned char order = 2; order <= counts.size(); ++order) {
- ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer);
+ ConvertToSorted(f, vocab, counts, file_prefix, order, warn, mem.get(), buffer);
}
ReadEnd(f);
}
@@ -227,7 +226,7 @@ class Closer {
};
} // namespace
-void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
+void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of weights. Does it include backoff?
@@ -261,8 +260,8 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
std::sort
#endif
(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
- files.push_back(DiskFlush(begin, out_end, maker));
- contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order));
+ files.push_back(DiskFlush(begin, out_end, file_prefix));
+ contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order));
done += (out_end - begin) / entry_size;
}
@@ -270,10 +269,10 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
// All individual files created. Merge them.
while (files.size() > 1) {
- files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine()));
+ files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine()));
files_closer.PopFront();
files_closer.PopFront();
- contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine()));
+ contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine()));
contexts_closer.PopFront();
contexts_closer.PopFront();
}