summaryrefslogtreecommitdiff
path: root/klm/lm/trie_sort.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie_sort.cc')
-rw-r--r--klm/lm/trie_sort.cc47
1 files changed, 20 insertions, 27 deletions
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index 0d83221e..dc542bb3 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -22,12 +22,6 @@
namespace lm {
namespace ngram {
namespace trie {
-
-void WriteOrThrow(FILE *to, const void *data, size_t size) {
- assert(size);
- if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
-}
-
namespace {
typedef util::SizedIterator NGramIter;
@@ -71,13 +65,13 @@ class PartialViewProxy {
typedef util::ProxyIterator<PartialViewProxy> PartialIter;
-FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) {
- util::scoped_fd file(maker.Make());
+FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) {
+ util::scoped_fd file(util::MakeTemp(temp_prefix));
util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
return util::FDOpenOrThrow(file);
}
-FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) {
+FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) {
const size_t context_size = sizeof(WordIndex) * (order - 1);
// Sort just the contexts using the same memory.
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
@@ -90,17 +84,17 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make
#endif
(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
- util::scoped_FILE out(maker.MakeFile());
+ util::scoped_FILE out(util::FMakeTemp(temp_prefix));
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
if (context_begin == context_end) return out.release();
PartialIter i(context_begin);
- WriteOrThrow(out.get(), i->Data(), context_size);
+ util::WriteOrThrow(out.get(), i->Data(), context_size);
const void *previous = i->Data();
++i;
for (; i != context_end; ++i) {
if (memcmp(previous, i->Data(), context_size)) {
- WriteOrThrow(out.get(), i->Data(), context_size);
+ util::WriteOrThrow(out.get(), i->Data(), context_size);
previous = i->Data();
}
}
@@ -116,23 +110,23 @@ struct ThrowCombine {
// Useful for context files that just contain records with no value.
struct FirstCombine {
void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const {
- WriteOrThrow(out, first, entry_size);
+ util::WriteOrThrow(out, first, entry_size);
}
};
-template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) {
+template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) {
std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
RecordReader first, second;
first.Init(first_file, entry_size);
second.Init(second_file, entry_size);
- util::scoped_FILE out_file(maker.MakeFile());
+ util::scoped_FILE out_file(util::FMakeTemp(temp_prefix));
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
- WriteOrThrow(out_file.get(), first.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), first.Data(), entry_size);
++first;
} else if (less(second.Data(), first.Data())) {
- WriteOrThrow(out_file.get(), second.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), second.Data(), entry_size);
++second;
} else {
combine(entry_size, first.Data(), second.Data(), out_file.get());
@@ -140,7 +134,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
}
}
for (RecordReader &remains = (first ? first : second); remains; ++remains) {
- WriteOrThrow(out_file.get(), remains.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), remains.Data(), entry_size);
}
return out_file.release();
}
@@ -164,7 +158,7 @@ void RecordReader::Init(FILE *file, std::size_t entry_size) {
void RecordReader::Overwrite(const void *start, std::size_t amount) {
long internal = (uint8_t*)start - (uint8_t*)data_.get();
UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
- WriteOrThrow(file_, start, amount);
+ util::WriteOrThrow(file_, start, amount);
long forward = entry_size_ - internal - amount;
#if !defined(_WIN32) && !defined(_WIN64)
if (forward)
@@ -183,9 +177,8 @@ void RecordReader::Rewind() {
}
SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
- util::TempMaker maker(file_prefix);
PositiveProbWarn warn(config.positive_log_probability);
- unigram_.reset(maker.Make());
+ unigram_.reset(util::MakeTemp(file_prefix));
{
// In case <unk> appears.
size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff);
@@ -208,7 +201,7 @@ SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<u
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
for (unsigned char order = 2; order <= counts.size(); ++order) {
- ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer);
+ ConvertToSorted(f, vocab, counts, file_prefix, order, warn, mem.get(), buffer);
}
ReadEnd(f);
}
@@ -233,7 +226,7 @@ class Closer {
};
} // namespace
-void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
+void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of weights. Does it include backoff?
@@ -267,8 +260,8 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
std::sort
#endif
(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
- files.push_back(DiskFlush(begin, out_end, maker));
- contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order));
+ files.push_back(DiskFlush(begin, out_end, file_prefix));
+ contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order));
done += (out_end - begin) / entry_size;
}
@@ -276,10 +269,10 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
// All individual files created. Merge them.
while (files.size() > 1) {
- files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine()));
+ files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine()));
files_closer.PopFront();
files_closer.PopFront();
- contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine()));
+ contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine()));
contexts_closer.PopFront();
contexts_closer.PopFront();
}