diff options
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 88 |
1 files changed, 31 insertions, 57 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 4bd3f4ee..ffadfa94 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -13,6 +13,7 @@ #include "lm/weights.hh" #include "lm/word_index.hh" #include "util/ersatz_progress.hh" +#include "util/mmap.hh" #include "util/proxy_iterator.hh" #include "util/scoped.hh" #include "util/sized_iterator.hh" @@ -20,14 +21,15 @@ #include <algorithm> #include <cstring> #include <cstdio> +#include <cstdlib> #include <queue> #include <limits> #include <numeric> #include <vector> -#include <sys/mman.h> -#include <sys/types.h> -#include <sys/stat.h> +#if defined(_WIN32) || defined(_WIN64) +#include <windows.h> +#endif namespace lm { namespace ngram { @@ -195,7 +197,7 @@ class SRISucks { void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) { for (unsigned char i = 0; i < kMaxOrder - 1; ++i) { - it_[i] = &*values_[i].begin(); + it_[i] = values_[i].empty() ? NULL : &*values_[i].begin(); } messages_[0].Apply(it_, unigram_file); BackoffMessages *messages = messages_ + 1; @@ -227,8 +229,8 @@ class SRISucks { class FindBlanks { public: - FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) - : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {} + FindBlanks(unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) + : counts_(order), unigrams_(unigrams), sri_(messages) {} float UnigramProb(WordIndex index) const { return unigrams_[index].prob; @@ -248,7 +250,7 @@ class FindBlanks { } void Longest(const void * /*data*/) { - ++*longest_counts_; + ++counts_.back(); } // Unigrams wrote one past. @@ -256,8 +258,12 @@ class FindBlanks { --counts_[0]; } + const std::vector<uint64_t> &Counts() const { + return counts_; + } + private: - uint64_t *const counts_, *const longest_counts_; + std::vector<uint64_t> counts_; const ProbBackoff *unigrams_; @@ -375,7 +381,7 @@ template <class Doing> class BlankManager { template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { util::ErsatzProgress progress(progress_out, message, unigram_count + 1); - unsigned int unigram = 0; + WordIndex unigram = 0; std::priority_queue<Gram> grams; grams.push(Gram(&unigram, 1)); for (unsigned char i = 2; i <= total_order; ++i) { @@ -461,42 +467,33 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c } // namespace -template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { RecordReader inputs[kMaxOrder - 1]; RecordReader contexts[kMaxOrder - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(i) << "_merged"; - inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); - util::RemoveOrThrow(assembled.str().c_str()); - assembled << kContextSuffix; - contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex)); - util::RemoveOrThrow(assembled.str().c_str()); + inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); + contexts[i-2].Init(files.Context(i), (i-1) * sizeof(WordIndex)); } SRISucks sri; - std::vector<uint64_t> fixed_counts(counts.size()); + std::vector<uint64_t> fixed_counts; + util::scoped_FILE unigram_file; + util::scoped_fd unigram_fd(files.StealUnigram()); { - std::string temp(file_prefix); temp += "unigrams"; - util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str())); util::scoped_memory unigrams; - MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); - FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); + MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); + FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); + fixed_counts = finder.Counts(); } + unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) { if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - util::scoped_FILE unigram_file; - { - std::string name(file_prefix + "unigrams"); - unigram_file.reset(OpenOrThrow(name.c_str(), "r+")); - util::RemoveOrThrow(name.c_str()); - } sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); @@ -587,42 +584,19 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBin longest.LoadedBinary(); } -namespace { -bool IsDirectory(const char *path) { - struct stat info; - if (0 != stat(path, &info)) return false; - return S_ISDIR(info.st_mode); -} -} // namespace - template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { - std::string temporary_directory; + std::string temporary_prefix; if (config.temporary_directory_prefix) { - temporary_directory = config.temporary_directory_prefix; - if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str())) - temporary_directory += '/'; + temporary_prefix = config.temporary_directory_prefix; } else if (config.write_mmap) { - temporary_directory = config.write_mmap; + temporary_prefix = config.write_mmap; } else { - temporary_directory = file; - } - // Null on end is kludge to ensure null termination. - temporary_directory += "_trie_tmp_XXXXXX"; - temporary_directory += '\0'; - if (!mkdtemp(&temporary_directory[0])) { - UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str()); + temporary_prefix = file; } - // Chop off null kludge. - temporary_directory.resize(strlen(temporary_directory.c_str())); - // Add directory delimiter. Assumes a real operating system. - temporary_directory += '/'; // At least 1MB sorting memory. - ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); + SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab); - BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing); - if (rmdir(temporary_directory.c_str()) && config.messages) { - *config.messages << "Failed to delete " << temporary_directory << std::endl; - } + BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); } template class TrieSearch<DontQuantize, DontBhiksha>; |