author     Chris Dyer <cdyer@cs.cmu.edu>   2011-01-25 22:30:48 +0200
committer  Chris Dyer <cdyer@cs.cmu.edu>   2011-01-25 22:30:48 +0200
commit     c4ade3091b812ca135ae6520fa7173e1bbf28754 (patch)
tree       2528af208f6dafd0c27dcbec0d2da291a9c93ca2 /klm/lm/search_trie.cc
parent     d04c0ca2d9df0e147239b18e90650ca8bd51d594 (diff)
update kenlm
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--   klm/lm/search_trie.cc   302
1 file changed, 257 insertions, 45 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 3aeeeca3..1060ddef 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -3,6 +3,7 @@
 #include "lm/blank.hh"
 #include "lm/lm_exception.hh"
+#include "lm/max_order.hh"
 #include "lm/read_arpa.hh"
 #include "lm/trie.hh"
 #include "lm/vocab.hh"
@@ -27,6 +28,7 @@
 #include <sys/stat.h>
 #include <fcntl.h>
 #include <stdlib.h>
+#include <unistd.h>
 
 namespace lm {
 namespace ngram {
@@ -98,7 +100,7 @@ class EntryProxy {
     }
 
     const WordIndex *Indices() const {
-      return static_cast<const WordIndex*>(inner_.Data());
+      return reinterpret_cast<const WordIndex*>(inner_.Data());
     }
 
   private:
@@ -114,17 +116,57 @@ class EntryProxy {
 
 typedef util::ProxyIterator<EntryProxy> NGramIter;
 
-class CompareRecords : public std::binary_function<const EntryProxy &, const EntryProxy &, bool> {
+// Proxy for an entry except there is some extra cruft between the entries.  This is used to sort (n-1)-grams using the same memory as the sorted n-grams.
+class PartialViewProxy {
+  public:
+    PartialViewProxy() : attention_size_(0), inner_() {}
+
+    PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {}
+
+    operator std::string() const {
+      return std::string(reinterpret_cast<const char*>(inner_.Data()), attention_size_);
+    }
+
+    PartialViewProxy &operator=(const PartialViewProxy &from) {
+      memcpy(inner_.Data(), from.inner_.Data(), attention_size_);
+      return *this;
+    }
+
+    PartialViewProxy &operator=(const std::string &from) {
+      memcpy(inner_.Data(), from.data(), attention_size_);
+      return *this;
+    }
+
+    const WordIndex *Indices() const {
+      return reinterpret_cast<const WordIndex*>(inner_.Data());
+    }
+
+  private:
+    friend class util::ProxyIterator<PartialViewProxy>;
+
+    typedef std::string value_type;
+
+    const std::size_t attention_size_;
+
+    typedef EntryIterator InnerIterator;
+    InnerIterator &Inner() { return inner_; }
+    const InnerIterator &Inner() const { return inner_; }
+    InnerIterator inner_;
+};
+
+typedef util::ProxyIterator<PartialViewProxy> PartialIter;
+
+template <class Proxy> class CompareRecords : public std::binary_function<const Proxy &, const Proxy &, bool> {
   public:
     explicit CompareRecords(unsigned char order) : order_(order) {}
 
-    bool operator()(const EntryProxy &first, const EntryProxy &second) const {
+    bool operator()(const Proxy &first, const Proxy &second) const {
       return Compare(first.Indices(), second.Indices());
     }
-    bool operator()(const EntryProxy &first, const std::string &second) const {
+    bool operator()(const Proxy &first, const std::string &second) const {
       return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data()));
     }
-    bool operator()(const std::string &first, const EntryProxy &second) const {
+    bool operator()(const std::string &first, const Proxy &second) const {
      return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices());
     }
     bool operator()(const std::string &first, const std::string &second) const {
@@ -144,6 +186,12 @@ class CompareRecords : public std::binary_function<const EntryProxy &, const Ent
     unsigned char order_;
 };
 
+FILE *OpenOrThrow(const char *name, const char *mode) {
+  FILE *ret = fopen(name, mode);
+  if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode);
+  return ret;
+}
+
 void WriteOrThrow(FILE *to, const void *data, size_t size) {
   assert(size);
   if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
@@ -163,14 +211,26 @@ void CopyOrThrow(FILE *from, FILE *to, size_t size) {
   }
 }
 
+void CopyRestOrThrow(FILE *from, FILE *to) {
+  char buf[kCopyBufSize];
+  size_t amount;
+  while ((amount = fread(buf, 1, kCopyBufSize, from))) {
+    WriteOrThrow(to, buf, amount);
+  }
+  if (!feof(from)) UTIL_THROW(util::ErrnoException, "Short read");
+}
+
+void RemoveOrThrow(const char *name) {
+  if (std::remove(name)) UTIL_THROW(util::ErrnoException, "Could not remove " << name);
+}
+
 std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) {
   const std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
   const std::size_t prefix_size = sizeof(WordIndex) * (order - 1);
   std::stringstream assembled;
   assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
   std::string ret(assembled.str());
-  util::scoped_FILE out(fopen(ret.c_str(), "w"));
-  if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing");
+  util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w"));
   // Compress entries that being with the same (order-1) words.
   for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) {
     const uint8_t *group_end;
@@ -194,8 +254,7 @@ class SortedFileReader {
     SortedFileReader() : ended_(false) {}
 
     void Init(const std::string &name, unsigned char order) {
-      file_.reset(fopen(name.c_str(), "r"));
-      if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read");
+      file_.reset(OpenOrThrow(name.c_str(), "r"));
       header_.resize(order - 1);
       NextHeader();
     }
@@ -262,12 +321,13 @@ void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size)
   CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count);
 }
 
-void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) {
+void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) {
   SortedFileReader first, second;
-  first.Init(first_name, order);
-  second.Init(second_name, order);
-  util::scoped_FILE out_file(fopen(out, "w"));
-  if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write");
+  first.Init(first_name.c_str(), order);
+  RemoveOrThrow(first_name.c_str());
+  second.Init(second_name.c_str(), order);
+  RemoveOrThrow(second_name.c_str());
+  util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w"));
   while (!first.Ended() && !second.Ended()) {
     if (first.HeaderVector() < second.HeaderVector()) {
       CopyFullRecord(first, out_file.get(), weights_size);
@@ -316,10 +376,109 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha
   }
 }
 
-void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
-  if (order == 1) return;
-  ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1);
+const char *kContextSuffix = "_contexts";
+
+void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) {
+  const size_t context_size = sizeof(WordIndex) * (order - 1);
+  // Sort just the contexts using the same memory.
+  PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
+  PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
+
+  // TODO: __gnu_parallel::sort here.
+  std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1));
+
+  std::string name(ngram_file_name + kContextSuffix);
+  util::scoped_FILE out(OpenOrThrow(name.c_str(), "w"));
+
+  // Write out to file and uniqueify at the same time.  Could have used unique_copy if there was an appropriate OutputIterator.
+  if (context_begin == context_end) return;
+  PartialIter i(context_begin);
+  WriteOrThrow(out.get(), i->Indices(), context_size);
+  const WordIndex *previous = i->Indices();
+  ++i;
+  for (; i != context_end; ++i) {
+    if (memcmp(previous, i->Indices(), context_size)) {
+      WriteOrThrow(out.get(), i->Indices(), context_size);
+      previous = i->Indices();
+    }
+  }
+}
+
+class ContextReader {
+  public:
+    ContextReader() : length_(0) {}
+
+    ContextReader(const char *name, size_t length) : file_(OpenOrThrow(name, "r")), length_(length), words_(length), valid_(true) {
+      ++*this;
+    }
+
+    void Reset(const char *name, size_t length) {
+      file_.reset(OpenOrThrow(name, "r"));
+      length_ = length;
+      words_.resize(length);
+      valid_ = true;
+      ++*this;
+    }
+
+    ContextReader &operator++() {
+      if (1 != fread(&*words_.begin(), length_, 1, file_.get())) {
+        if (!feof(file_.get())) UTIL_THROW(util::ErrnoException, "Short read");
+        valid_ = false;
+      }
+      return *this;
+    }
+
+    const WordIndex *operator*() const { return &*words_.begin(); }
+
+    operator bool() const { return valid_; }
+
+    FILE *GetFile() { return file_.get(); }
+
+  private:
+    util::scoped_FILE file_;
+
+    size_t length_;
+
+    std::vector<WordIndex> words_;
+
+    bool valid_;
+};
+
+void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) {
+  const size_t context_size = sizeof(WordIndex) * (order - 1);
+  std::string first_name(first_base + kContextSuffix);
+  std::string second_name(second_base + kContextSuffix);
+  ContextReader first(first_name.c_str(), context_size), second(second_name.c_str(), context_size);
+  RemoveOrThrow(first_name.c_str());
+  RemoveOrThrow(second_name.c_str());
+  std::string out_name(out_base + kContextSuffix);
+  util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w"));
+  while (first && second) {
+    for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) {
+      if (f == *first + order) {
+        // Equal.
+        WriteOrThrow(out.get(), *first, context_size);
+        ++first;
+        ++second;
+        break;
+      }
+      if (*f < *s) {
+        // First lower
+        WriteOrThrow(out.get(), *first, context_size);
+        ++first;
+        break;
+      } else if (*f > *s) {
+        WriteOrThrow(out.get(), *second, context_size);
+        ++second;
+        break;
+      }
+    }
+  }
+  CopyRestOrThrow((first ? first : second).GetFile(), out.get());
+}
+
+void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
   ReadNGramHeader(f, order);
   const size_t count = counts[order - 1];
   // Size of weights.  Does it include backoff?
@@ -341,11 +500,13 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
       ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size));
     }
   }
-  // TODO: __gnu_parallel::sort here.
+  // Sort full records by full n-gram.
   EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
-  std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order));
-
+  // TODO: __gnu_parallel::sort here.
+  std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order));
   files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size));
+  WriteContextFile(begin, out_end, files.back(), entry_size, order);
+
   done += (out_end - begin) / entry_size;
 }
@@ -356,10 +517,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
     std::stringstream assembled;
     assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
     files.push_back(assembled.str());
-    MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order);
-    if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
+    MergeSortedFiles(files[0], files[1], files.back(), weights_size, order);
+    MergeContextFiles(files[0], files[1], files.back(), order);
     files.pop_front();
-    if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
     files.pop_front();
   }
   if (!files.empty()) {
@@ -367,6 +527,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
     assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";
     std::string merged_name(assembled.str());
     if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
+    std::string context_name = files[0] + kContextSuffix;
+    merged_name += kContextSuffix;
+    if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str());
   }
 }
 
@@ -378,26 +541,38 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts,
     Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
   }
 
+  // Only use as much buffer as we need.
+  size_t buffer_use = 0;
+  for (unsigned int order = 2; order < counts.size(); ++order) {
+    buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
+  }
+  buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
+  buffer = std::min(buffer, buffer_use);
+
   util::scoped_memory mem;
   mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
   if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
-  ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size());
+
+  for (unsigned char order = 2; order <= counts.size(); ++order) {
+    ConvertToSorted(f, vocab, counts, mem, file_prefix, order);
+  }
   ReadEnd(f);
 }
 
 bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) {
   for (; words != words_end; ++words, ++header) {
     if (*words != *header) {
-      assert(*words <= *header);
+      //assert(*words <= *header);
       return false;
     }
   }
   return true;
 }
 
+// Counting phrase
 class JustCount {
   public:
-    JustCount(UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
+    JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
      : counts_(counts), longest_counts_(counts + order - 1) {}
 
     void Unigrams(WordIndex begin, WordIndex end) {
@@ -408,7 +583,7 @@ class JustCount {
       ++counts_[mid_idx + 1];
     }
 
-    void Middle(const unsigned char mid_idx, WordIndex /*key*/, const ProbBackoff &/*weights*/) {
+    void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) {
       ++counts_[mid_idx + 1];
     }
 
@@ -427,7 +602,8 @@ class JustCount {
 
 class WriteEntries {
   public:
-    WriteEntries(UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
+    WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
+      contexts_(contexts),
       unigrams_(unigrams),
       middle_(middle),
       longest_(longest),
@@ -444,7 +620,13 @@ class WriteEntries {
       middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff);
     }
 
-    void Middle(const unsigned char mid_idx, WordIndex key, const ProbBackoff &weights) {
+    void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) {
+      // Order (mid_idx+2).
+      ContextReader &context = contexts_[mid_idx + 1];
+      if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) {
+        SetExtension(weights.backoff);
+        ++context;
+      }
       middle_[mid_idx].Insert(key, weights.prob, weights.backoff);
     }
 
@@ -455,6 +637,7 @@ class WriteEntries {
     void Cleanup() {}
 
   private:
+    ContextReader *contexts_;
     UnigramValue *const unigrams_;
     BitPackedMiddle *const middle_;
    BitPackedLongest &longest_;
@@ -463,14 +646,15 @@ class WriteEntries {
 
 template <class Doing> class RecursiveInsert {
   public:
-    RecursiveInsert(SortedFileReader *inputs, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) :
-      doing_(unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), words_(new WordIndex[order]), order_minus_2_(order - 2) {
+    RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) :
+      doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) {
     }
 
     // Outer unigram loop.
    void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) {
       util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
       for (words_[0] = 0; ; ++words_[0]) {
+        progress.Set(words_[0]);
         WordIndex min_continue = unigram_count;
         for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) {
           if (other->Ended()) continue;
@@ -479,7 +663,6 @@ template <class Doing> class RecursiveInsert {
         // This will write at unigram_count.  This is by design so that the next pointers will make sense.
         doing_.Unigrams(words_[0], min_continue + 1);
         if (min_continue == unigram_count) break;
-        progress += min_continue - words_[0];
         words_[0] = min_continue;
         Middle(0);
       }
@@ -497,7 +680,7 @@ template <class Doing> class RecursiveInsert {
 
       SortedFileReader &reader = inputs_[mid_idx];
 
-      if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + mid_idx + 1, reader.Header())) {
+      if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) {
         // This order doesn't have a header match, but longer ones might.
         MiddleAllBlank(mid_idx);
         return;
@@ -509,7 +692,7 @@ template <class Doing> class RecursiveInsert {
       while (count) {
         WordIndex min_continue = std::numeric_limits<WordIndex>::max();
         for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
-          if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header()))
+          if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
             min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
         }
         while (true) {
@@ -521,7 +704,7 @@ template <class Doing> class RecursiveInsert {
           }
           ProbBackoff weights;
           reader.ReadWeights(weights);
-          doing_.Middle(mid_idx, current, weights);
+          doing_.Middle(mid_idx, words_, current, weights);
           --count;
           if (current == min_continue) {
             words_[mid_idx + 1] = min_continue;
@@ -542,7 +725,7 @@ template <class Doing> class RecursiveInsert {
       while (true) {
         WordIndex min_continue = std::numeric_limits<WordIndex>::max();
         for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
-          if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header()))
+          if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
             min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
         }
         if (min_continue == std::numeric_limits<WordIndex>::max()) return;
@@ -554,7 +737,7 @@ template <class Doing> class RecursiveInsert {
 
     void Longest() {
       SortedFileReader &reader = *(inputs_end_ - 1);
-      if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + order_minus_2_ + 1, reader.Header())) return;
+      if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return;
       WordIndex count = reader.ReadCount();
       for (WordIndex i = 0; i < count; ++i) {
         WordIndex word = reader.ReadWord();
@@ -571,7 +754,7 @@ template <class Doing> class RecursiveInsert {
     SortedFileReader *inputs_;
     SortedFileReader *inputs_end_;
 
-    util::scoped_array<WordIndex> words_;
+    WordIndex words_[kMaxOrder];
 
    const unsigned char order_minus_2_;
 };
@@ -586,17 +769,21 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
 
 void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
   SortedFileReader inputs[counts.size() - 1];
+  ContextReader contexts[counts.size() - 1];
 
   for (unsigned char i = 2; i <= counts.size(); ++i) {
     std::stringstream assembled;
     assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
     inputs[i-2].Init(assembled.str(), i);
-    unlink(assembled.str().c_str());
+    RemoveOrThrow(assembled.str().c_str());
+    assembled << kContextSuffix;
+    contexts[i-2].Reset(assembled.str().c_str(), (i-1) * sizeof(WordIndex));
+    RemoveOrThrow(assembled.str().c_str());
   }
 
   std::vector<uint64_t> fixed_counts(counts.size());
   {
-    RecursiveInsert<JustCount> counter(inputs, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+    RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
     counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
   }
   SanityCheckCounts(counts, fixed_counts);
@@ -609,21 +796,38 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
   UnigramValue *unigrams = out.unigram.Raw();
   // Fill entries except unigram probabilities.
   {
-    RecursiveInsert<WriteEntries> inserter(inputs, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+    RecursiveInsert<WriteEntries> inserter(inputs, contexts, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
     inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
   }
 
   // Fill unigram probabilities.
   {
     std::string name(file_prefix + "unigrams");
-    util::scoped_FILE file(fopen(name.c_str(), "r"));
-    if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed");
+    util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));
     for (WordIndex i = 0; i < counts[0]; ++i) {
       ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
+      if (contexts[0] && **contexts[0] == i) {
+        SetExtension(unigrams[i].weights.backoff);
+        ++contexts[0];
+      }
     }
     unlink(name.c_str());
   }
 
+  // Do not disable this error message or else too little state will be returned.  Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
+  for (unsigned char order = 2; order <= counts.size(); ++order) {
+    const ContextReader &context = contexts[order - 2];
+    if (context) {
+      FormatLoadException e;
+      e << "An " << static_cast<unsigned int>(order) << "-gram has the context (i.e. all but the last word):";
+      for (const WordIndex *i = *context; i != *context + order - 1; ++i) {
+        e << ' ' << *i;
+      }
+      e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not.";
+      throw e;
+    }
+  }
+
   /* Set ending offsets so the last entry will be sized properly */
   // Last entry for unigrams was already set.
   if (!out.middle.empty()) {
@@ -634,19 +838,27 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
   }
 }
 
+bool IsDirectory(const char *path) {
+  struct stat info;
+  if (0 != stat(path, &info)) return false;
+  return S_ISDIR(info.st_mode);
+}
+
 } // namespace
 
 void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
   std::string temporary_directory;
   if (config.temporary_directory_prefix) {
     temporary_directory = config.temporary_directory_prefix;
+    if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str()))
+      temporary_directory += '/';
   } else if (config.write_mmap) {
     temporary_directory = config.write_mmap;
   } else {
     temporary_directory = file;
   }
   // Null on end is kludge to ensure null termination.
-  temporary_directory += "-tmp-XXXXXX";
+  temporary_directory += "_trie_tmp_XXXXXX";
   temporary_directory += '\0';
   if (!mkdtemp(&temporary_directory[0])) {
     UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
@@ -658,7 +870,7 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v
 
   // At least 1MB sorting memory.
   ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
 
-  BuildTrie(temporary_directory.c_str(), counts, config, *this, backing);
+  BuildTrie(temporary_directory, counts, config, *this, backing);
   if (rmdir(temporary_directory.c_str()) && config.messages) {
     *config.messages << "Failed to delete " << temporary_directory << std::endl;
   }
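
The write-and-uniquify loop in WriteContextFile above avoids a separate std::unique pass: since the contexts have just been sorted, writing only the records that differ from the previously written one deduplicates on the fly. Below is a minimal standalone sketch of that idea over contiguous fixed-size records (hypothetical names; the patch itself walks strided records through PartialViewProxy and checks writes via WriteOrThrow):

    #include <cstddef>
    #include <cstdio>
    #include <cstring>

    typedef unsigned int WordIndex;

    // Write each distinct record exactly once, assuming [begin, end) holds
    // sorted records of record_size bytes laid out back to back.
    void WriteUnique(const unsigned char *begin, const unsigned char *end,
                     std::size_t record_size, std::FILE *out) {
      if (begin == end) return;
      // The first record always goes out.
      std::fwrite(begin, record_size, 1, out);
      const unsigned char *previous = begin;
      for (const unsigned char *i = begin + record_size; i != end; i += record_size) {
        // Emit only records that differ from the one just written.
        if (std::memcmp(previous, i, record_size)) {
          std::fwrite(i, record_size, 1, out);
          previous = i;
        }
      }
    }

    int main() {
      // Sorted two-word contexts from a 3-gram model, with one duplicate.
      WordIndex contexts[] = {1, 2,  1, 2,  1, 3,  4, 5};
      std::FILE *out = std::fopen("contexts_demo", "wb");
      if (!out) return 1;
      WriteUnique(reinterpret_cast<const unsigned char*>(contexts),
                  reinterpret_cast<const unsigned char*>(contexts + 8),
                  2 * sizeof(WordIndex), out);
      std::fclose(out);  // The file now holds {1,2}, {1,3}, {4,5}.
      return 0;
    }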
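MergeContextFiles above is a two-way merge of sorted, already-deduplicated inputs that emits a record appearing in both inputs only once, then drains whichever input survives with CopyRestOrThrow. A self-contained sketch of the same merge over in-memory vectors (hypothetical names; the patch streams fixed-size records from FILE* through ContextReader instead):

    #include <cstddef>
    #include <vector>

    typedef unsigned int WordIndex;

    // Compare two records of `length` words numerically, word by word,
    // mirroring the patch's inner pointer loop.
    int CompareWords(const WordIndex *f, const WordIndex *s, std::size_t length) {
      for (std::size_t i = 0; i < length; ++i) {
        if (f[i] < s[i]) return -1;
        if (f[i] > s[i]) return 1;
      }
      return 0;
    }

    // Merge two sorted streams of fixed-length records (length words each).
    // A record present in both inputs is written once.
    std::vector<WordIndex> MergeUnique(const std::vector<WordIndex> &first,
                                       const std::vector<WordIndex> &second,
                                       std::size_t length) {
      std::vector<WordIndex> out;
      std::size_t f = 0, s = 0;
      while (f < first.size() && s < second.size()) {
        int cmp = CompareWords(&first[f], &second[s], length);
        if (cmp < 0) {         // First lower.
          out.insert(out.end(), first.begin() + f, first.begin() + f + length);
          f += length;
        } else if (cmp > 0) {  // Second lower.
          out.insert(out.end(), second.begin() + s, second.begin() + s + length);
          s += length;
        } else {               // Equal: emit once, advance both.
          out.insert(out.end(), first.begin() + f, first.begin() + f + length);
          f += length;
          s += length;
        }
      }
      // Drain whichever input has records left, like CopyRestOrThrow.
      out.insert(out.end(), first.begin() + f, first.end());
      out.insert(out.end(), second.begin() + s, second.end());
      return out;
    }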
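The buffer-sizing code added to ARPAToSortedFiles caps the sort buffer at the largest batch any order can actually use: middle orders store a probability and a backoff (two floats) per entry, while the highest order stores only a probability. A sketch of the same computation with a worked example (assuming 4-byte WordIndex and float, as on typical platforms):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    typedef unsigned int WordIndex;

    // Largest number of bytes any single order's entries can occupy.
    std::size_t SortBufferNeeded(const std::vector<unsigned long long> &counts) {
      std::size_t need = 0;
      for (std::size_t order = 2; order < counts.size(); ++order) {
        // Middle orders: order words plus prob and backoff.
        need = std::max(need, std::size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
      }
      // Highest order: counts.size() words plus prob only.
      need = std::max(need, std::size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
      return need;
    }

    // Example: a trigram model with counts {100000, 500000, 1000000} needs
    // max((2*4 + 2*4) * 500000, (3*4 + 4) * 1000000) = max(8 MB, 16 MB) = 16 MB,
    // so a larger configured building_memory would be clamped to 16 MB.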