diff options
Diffstat (limited to 'klm/lm/trie_sort.cc')
-rw-r--r-- | klm/lm/trie_sort.cc | 44 |
1 files changed, 23 insertions, 21 deletions
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index b80fed02..8663e94e 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -22,12 +22,6 @@ namespace lm { namespace ngram { namespace trie { - -void WriteOrThrow(FILE *to, const void *data, size_t size) { - assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); -} - namespace { typedef util::SizedIterator NGramIter; @@ -95,12 +89,12 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. if (context_begin == context_end) return out.release(); PartialIter i(context_begin); - WriteOrThrow(out.get(), i->Data(), context_size); + util::WriteOrThrow(out.get(), i->Data(), context_size); const void *previous = i->Data(); ++i; for (; i != context_end; ++i) { if (memcmp(previous, i->Data(), context_size)) { - WriteOrThrow(out.get(), i->Data(), context_size); + util::WriteOrThrow(out.get(), i->Data(), context_size); previous = i->Data(); } } @@ -116,7 +110,7 @@ struct ThrowCombine { // Useful for context files that just contain records with no value. struct FirstCombine { void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { - WriteOrThrow(out, first, entry_size); + util::WriteOrThrow(out, first, entry_size); } }; @@ -129,10 +123,10 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { - WriteOrThrow(out_file.get(), first.Data(), entry_size); + util::WriteOrThrow(out_file.get(), first.Data(), entry_size); ++first; } else if (less(second.Data(), first.Data())) { - WriteOrThrow(out_file.get(), second.Data(), entry_size); + util::WriteOrThrow(out_file.get(), second.Data(), entry_size); ++second; } else { combine(entry_size, first.Data(), second.Data(), out_file.get()); @@ -140,7 +134,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f } } for (RecordReader &remains = (first ? first : second); remains; ++remains) { - WriteOrThrow(out_file.get(), remains.Data(), entry_size); + util::WriteOrThrow(out_file.get(), remains.Data(), entry_size); } return out_file.release(); } @@ -148,19 +142,23 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f } // namespace void RecordReader::Init(FILE *file, std::size_t entry_size) { - rewind(file); - file_ = file; + entry_size_ = entry_size; data_.reset(malloc(entry_size)); UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer"); - remains_ = true; - entry_size_ = entry_size; - ++*this; + file_ = file; + if (file) { + rewind(file); + remains_ = true; + ++*this; + } else { + remains_ = false; + } } void RecordReader::Overwrite(const void *start, std::size_t amount) { long internal = (uint8_t*)start - (uint8_t*)data_.get(); UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); - WriteOrThrow(file_, start, amount); + util::WriteOrThrow(file_, start, amount); long forward = entry_size_ - internal - amount; #if !defined(_WIN32) && !defined(_WIN64) if (forward) @@ -169,9 +167,13 @@ void RecordReader::Overwrite(const void *start, std::size_t amount) { } void RecordReader::Rewind() { - rewind(file_); - remains_ = true; - ++*this; + if (file_) { + rewind(file_); + remains_ = true; + ++*this; + } else { + remains_ = false; + } } SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { |