summaryrefslogtreecommitdiff
path: root/klm/lm/trie_sort.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie_sort.cc')
-rw-r--r--klm/lm/trie_sort.cc24
1 files changed, 16 insertions, 8 deletions
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index b80fed02..0d83221e 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -148,13 +148,17 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
} // namespace
void RecordReader::Init(FILE *file, std::size_t entry_size) {
- rewind(file);
- file_ = file;
+ entry_size_ = entry_size;
data_.reset(malloc(entry_size));
UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer");
- remains_ = true;
- entry_size_ = entry_size;
- ++*this;
+ file_ = file;
+ if (file) {
+ rewind(file);
+ remains_ = true;
+ ++*this;
+ } else {
+ remains_ = false;
+ }
}
void RecordReader::Overwrite(const void *start, std::size_t amount) {
@@ -169,9 +173,13 @@ void RecordReader::Overwrite(const void *start, std::size_t amount) {
}
void RecordReader::Rewind() {
- rewind(file_);
- remains_ = true;
- ++*this;
+ if (file_) {
+ rewind(file_);
+ remains_ = true;
+ ++*this;
+ } else {
+ remains_ = false;
+ }
}
SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {