diff options
Diffstat (limited to 'klm/lm/trie_sort.cc')
-rw-r--r-- | klm/lm/trie_sort.cc | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 8663e94e..dc542bb3 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -65,13 +65,13 @@ class PartialViewProxy { typedef util::ProxyIterator<PartialViewProxy> PartialIter; -FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) { - util::scoped_fd file(maker.Make()); +FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) { + util::scoped_fd file(util::MakeTemp(temp_prefix)); util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); return util::FDOpenOrThrow(file); } -FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) { +FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) { const size_t context_size = sizeof(WordIndex) * (order - 1); // Sort just the contexts using the same memory. PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); @@ -84,7 +84,7 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make #endif (context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1))); - util::scoped_FILE out(maker.MakeFile()); + util::scoped_FILE out(util::FMakeTemp(temp_prefix)); // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. if (context_begin == context_end) return out.release(); @@ -114,12 +114,12 @@ struct FirstCombine { } }; -template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) { +template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) { std::size_t entry_size = sizeof(WordIndex) * order + weights_size; RecordReader first, second; first.Init(first_file, entry_size); second.Init(second_file, entry_size); - util::scoped_FILE out_file(maker.MakeFile()); + util::scoped_FILE out_file(util::FMakeTemp(temp_prefix)); EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { @@ -177,9 +177,8 @@ void RecordReader::Rewind() { } SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { - util::TempMaker maker(file_prefix); PositiveProbWarn warn(config.positive_log_probability); - unigram_.reset(maker.Make()); + unigram_.reset(util::MakeTemp(file_prefix)); { // In case <unk> appears. size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff); @@ -202,7 +201,7 @@ SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<u if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer); + ConvertToSorted(f, vocab, counts, file_prefix, order, warn, mem.get(), buffer); } ReadEnd(f); } @@ -227,7 +226,7 @@ class Closer { }; } // namespace -void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { +void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { ReadNGramHeader(f, order); const size_t count = counts[order - 1]; // Size of weights. Does it include backoff? @@ -261,8 +260,8 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo std::sort #endif (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order))); - files.push_back(DiskFlush(begin, out_end, maker)); - contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order)); + files.push_back(DiskFlush(begin, out_end, file_prefix)); + contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order)); done += (out_end - begin) / entry_size; } @@ -270,10 +269,10 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo // All individual files created. Merge them. while (files.size() > 1) { - files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine())); + files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine())); files_closer.PopFront(); files_closer.PopFront(); - contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine())); + contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine())); contexts_closer.PopFront(); contexts_closer.PopFront(); } |