summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--klm/lm/search_trie.cc88
1 files changed, 31 insertions, 57 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 4bd3f4ee..ffadfa94 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -13,6 +13,7 @@
#include "lm/weights.hh"
#include "lm/word_index.hh"
#include "util/ersatz_progress.hh"
+#include "util/mmap.hh"
#include "util/proxy_iterator.hh"
#include "util/scoped.hh"
#include "util/sized_iterator.hh"
@@ -20,14 +21,15 @@
#include <algorithm>
#include <cstring>
#include <cstdio>
+#include <cstdlib>
#include <queue>
#include <limits>
#include <numeric>
#include <vector>
-#include <sys/mman.h>
-#include <sys/types.h>
-#include <sys/stat.h>
+#if defined(_WIN32) || defined(_WIN64)
+#include <windows.h>
+#endif
namespace lm {
namespace ngram {
@@ -195,7 +197,7 @@ class SRISucks {
void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) {
for (unsigned char i = 0; i < kMaxOrder - 1; ++i) {
- it_[i] = &*values_[i].begin();
+ it_[i] = values_[i].empty() ? NULL : &*values_[i].begin();
}
messages_[0].Apply(it_, unigram_file);
BackoffMessages *messages = messages_ + 1;
@@ -227,8 +229,8 @@ class SRISucks {
class FindBlanks {
public:
- FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
- : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {}
+ FindBlanks(unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
+ : counts_(order), unigrams_(unigrams), sri_(messages) {}
float UnigramProb(WordIndex index) const {
return unigrams_[index].prob;
@@ -248,7 +250,7 @@ class FindBlanks {
}
void Longest(const void * /*data*/) {
- ++*longest_counts_;
+ ++counts_.back();
}
// Unigrams wrote one past.
@@ -256,8 +258,12 @@ class FindBlanks {
--counts_[0];
}
+ const std::vector<uint64_t> &Counts() const {
+ return counts_;
+ }
+
private:
- uint64_t *const counts_, *const longest_counts_;
+ std::vector<uint64_t> counts_;
const ProbBackoff *unigrams_;
@@ -375,7 +381,7 @@ template <class Doing> class BlankManager {
template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) {
util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
- unsigned int unigram = 0;
+ WordIndex unigram = 0;
std::priority_queue<Gram> grams;
grams.push(Gram(&unigram, 1));
for (unsigned char i = 2; i <= total_order; ++i) {
@@ -461,42 +467,33 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c
} // namespace
-template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
RecordReader inputs[kMaxOrder - 1];
RecordReader contexts[kMaxOrder - 1];
for (unsigned char i = 2; i <= counts.size(); ++i) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
- inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff)));
- util::RemoveOrThrow(assembled.str().c_str());
- assembled << kContextSuffix;
- contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex));
- util::RemoveOrThrow(assembled.str().c_str());
+ inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff)));
+ contexts[i-2].Init(files.Context(i), (i-1) * sizeof(WordIndex));
}
SRISucks sri;
- std::vector<uint64_t> fixed_counts(counts.size());
+ std::vector<uint64_t> fixed_counts;
+ util::scoped_FILE unigram_file;
+ util::scoped_fd unigram_fd(files.StealUnigram());
{
- std::string temp(file_prefix); temp += "unigrams";
- util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str()));
util::scoped_memory unigrams;
- MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
- FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
+ MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
+ FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
+ fixed_counts = finder.Counts();
}
+ unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) {
if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
}
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
- util::scoped_FILE unigram_file;
- {
- std::string name(file_prefix + "unigrams");
- unigram_file.reset(OpenOrThrow(name.c_str(), "r+"));
- util::RemoveOrThrow(name.c_str());
- }
sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
@@ -587,42 +584,19 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBin
longest.LoadedBinary();
}
-namespace {
-bool IsDirectory(const char *path) {
- struct stat info;
- if (0 != stat(path, &info)) return false;
- return S_ISDIR(info.st_mode);
-}
-} // namespace
-
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
- std::string temporary_directory;
+ std::string temporary_prefix;
if (config.temporary_directory_prefix) {
- temporary_directory = config.temporary_directory_prefix;
- if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str()))
- temporary_directory += '/';
+ temporary_prefix = config.temporary_directory_prefix;
} else if (config.write_mmap) {
- temporary_directory = config.write_mmap;
+ temporary_prefix = config.write_mmap;
} else {
- temporary_directory = file;
- }
- // Null on end is kludge to ensure null termination.
- temporary_directory += "_trie_tmp_XXXXXX";
- temporary_directory += '\0';
- if (!mkdtemp(&temporary_directory[0])) {
- UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
+ temporary_prefix = file;
}
- // Chop off null kludge.
- temporary_directory.resize(strlen(temporary_directory.c_str()));
- // Add directory delimiter. Assumes a real operating system.
- temporary_directory += '/';
// At least 1MB sorting memory.
- ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
+ SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);
- BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing);
- if (rmdir(temporary_directory.c_str()) && config.messages) {
- *config.messages << "Failed to delete " << temporary_directory << std::endl;
- }
+ BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);
}
template class TrieSearch<DontQuantize, DontBhiksha>;