diff options
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 50 |
1 files changed, 27 insertions, 23 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1060ddef..63631223 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -11,6 +11,7 @@ #include "lm/word_index.hh" #include "util/ersatz_progress.hh" #include "util/file_piece.hh" +#include "util/have.hh" #include "util/proxy_iterator.hh" #include "util/scoped.hh" @@ -20,7 +21,6 @@ #include <cstdio> #include <deque> #include <limits> -//#include <parallel/algorithm> #include <vector> #include <sys/mman.h> @@ -170,7 +170,7 @@ template <class Proxy> class CompareRecords : public std::binary_function<const return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices()); } bool operator()(const std::string &first, const std::string &second) const { - return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(first.data())); + return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data())); } private: @@ -384,7 +384,6 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); - // TODO: __gnu_parallel::sort here. std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1)); std::string name(ngram_file_name + kContextSuffix); @@ -406,16 +405,16 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil class ContextReader { public: - ContextReader() : length_(0) {} + ContextReader() : valid_(false) {} - ContextReader(const char *name, size_t length) : file_(OpenOrThrow(name, "r")), length_(length), words_(length), valid_(true) { - ++*this; + ContextReader(const char *name, unsigned char order) { + Reset(name, order); } - void Reset(const char *name, size_t length) { + void Reset(const char *name, unsigned char order) { file_.reset(OpenOrThrow(name, "r")); - length_ = length; - words_.resize(length); + length_ = sizeof(WordIndex) * static_cast<size_t>(order); + words_.resize(order); valid_ = true; ++*this; } @@ -449,14 +448,14 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_ const size_t context_size = sizeof(WordIndex) * (order - 1); std::string first_name(first_base + kContextSuffix); std::string second_name(second_base + kContextSuffix); - ContextReader first(first_name.c_str(), context_size), second(second_name.c_str(), context_size); + ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1); RemoveOrThrow(first_name.c_str()); RemoveOrThrow(second_name.c_str()); std::string out_name(out_base + kContextSuffix); util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w")); while (first && second) { for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) { - if (f == *first + order) { + if (f == *first + order - 1) { // Equal. WriteOrThrow(out.get(), *first, context_size); ++first; @@ -475,7 +474,10 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_ } } } - CopyRestOrThrow((first ? first : second).GetFile(), out.get()); + ContextReader &remaining = first ? first : second; + if (!remaining) return; + WriteOrThrow(out.get(), *remaining, context_size); + CopyRestOrThrow(remaining.GetFile(), out.get()); } void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { @@ -502,7 +504,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st } // Sort full records by full n-gram. EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - // TODO: __gnu_parallel::sort here. + // parallel_sort uses too much RAM std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order)); files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); WriteContextFile(begin, out_end, files.back(), entry_size, order); @@ -533,21 +535,22 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st } } -void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, const std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { { std::string unigram_name = file_prefix + "unigrams"; util::scoped_fd unigram_file; util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); + CheckSpecials(config, vocab); } // Only use as much buffer as we need. size_t buffer_use = 0; for (unsigned int order = 2; order < counts.size(); ++order) { - buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); + buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); } - buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); - buffer = std::min(buffer, buffer_use); + buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); + buffer = std::min<size_t>(buffer, buffer_use); util::scoped_memory mem; mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); @@ -767,7 +770,7 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u } } -void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) { +void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) { SortedFileReader inputs[counts.size() - 1]; ContextReader contexts[counts.size() - 1]; @@ -777,7 +780,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun inputs[i-2].Init(assembled.str(), i); RemoveOrThrow(assembled.str().c_str()); assembled << kContextSuffix; - contexts[i-2].Reset(assembled.str().c_str(), (i-1) * sizeof(WordIndex)); + contexts[i-2].Reset(assembled.str().c_str(), i-1); RemoveOrThrow(assembled.str().c_str()); } @@ -787,8 +790,9 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } SanityCheckCounts(counts, fixed_counts); + counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::kModelType, fixed_counts, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -811,7 +815,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun ++contexts[0]; } } - unlink(name.c_str()); + RemoveOrThrow(name.c_str()); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -823,7 +827,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun for (const WordIndex *i = *context; i != *context + order - 1; ++i) { e << ' ' << *i; } - e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not."; + e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not"; throw e; } } @@ -868,7 +872,7 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v // Add directory delimiter. Assumes a real operating system. temporary_directory += '/'; // At least 1MB sorting memory. - ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); + ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); BuildTrie(temporary_directory, counts, config, *this, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { |