summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--klm/lm/search_trie.cc369
1 files changed, 272 insertions, 97 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 12294682..3aeeeca3 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -1,6 +1,7 @@
/* This is where the trie is built. It's on-disk. */
#include "lm/search_trie.hh"
+#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/read_arpa.hh"
#include "lm/trie.hh"
@@ -13,10 +14,10 @@
#include "util/scoped.hh"
#include <algorithm>
+#include <cmath>
#include <cstring>
#include <cstdio>
#include <deque>
-#include <iostream>
#include <limits>
//#include <parallel/algorithm>
#include <vector>
@@ -152,11 +153,11 @@ void ReadOrThrow(FILE *from, void *data, size_t size) {
if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size);
}
+const std::size_t kCopyBufSize = 512;
void CopyOrThrow(FILE *from, FILE *to, size_t size) {
- const size_t kBufSize = 512;
- char buf[kBufSize];
- for (size_t i = 0; i < size; i += kBufSize) {
- std::size_t amount = std::min(size - i, kBufSize);
+ char buf[std::min<size_t>(size, kCopyBufSize)];
+ for (size_t i = 0; i < size; i += kCopyBufSize) {
+ std::size_t amount = std::min(size - i, kCopyBufSize);
ReadOrThrow(from, buf, amount);
WriteOrThrow(to, buf, amount);
}
@@ -172,8 +173,10 @@ std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::str
if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing");
// Compress entries that begin with the same (order-1) words.
for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) {
- const uint8_t *group_end = group_begin;
- for (group_end += entry_size; (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); group_end += entry_size) {}
+ const uint8_t *group_end;
+ for (group_end = group_begin + entry_size;
+ (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size);
+ group_end += entry_size) {}
WriteOrThrow(out.get(), group_begin, prefix_size);
WordIndex group_size = (group_end - group_begin) / entry_size;
WriteOrThrow(out.get(), &group_size, sizeof(group_size));
@@ -188,7 +191,7 @@ std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::str
class SortedFileReader {
public:
- SortedFileReader() {}
+ SortedFileReader() : ended_(false) {}
void Init(const std::string &name, unsigned char order) {
file_.reset(fopen(name.c_str(), "r"));
@@ -206,25 +209,39 @@ class SortedFileReader {
std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); }
void NextHeader() {
- if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get()) && !Ended()) {
- UTIL_THROW(util::ErrnoException, "Short read of counts");
+ if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get())) {
+ if (feof(file_.get())) {
+ ended_ = true;
+ } else {
+ UTIL_THROW(util::ErrnoException, "Short read of counts");
+ }
}
}
- void ReadCount(WordIndex &to) {
- ReadOrThrow(file_.get(), &to, sizeof(WordIndex));
+ WordIndex ReadCount() {
+ WordIndex ret;
+ ReadOrThrow(file_.get(), &ret, sizeof(WordIndex));
+ return ret;
}
- void ReadWord(WordIndex &to) {
- ReadOrThrow(file_.get(), &to, sizeof(WordIndex));
+ WordIndex ReadWord() {
+ WordIndex ret;
+ ReadOrThrow(file_.get(), &ret, sizeof(WordIndex));
+ return ret;
}
- template <class Weights> void ReadWeights(Weights &to) {
- ReadOrThrow(file_.get(), &to, sizeof(Weights));
+ template <class Weights> void ReadWeights(Weights &weights) {
+ ReadOrThrow(file_.get(), &weights, sizeof(Weights));
}
bool Ended() {
- return feof(file_.get());
+ return ended_;
+ }
+
+ void Rewind() {
+ rewind(file_.get());
+ ended_ = false;
+ NextHeader();
}
FILE *File() { return file_.get(); }
@@ -233,12 +250,13 @@ class SortedFileReader {
util::scoped_FILE file_;
std::vector<WordIndex> header_;
+
+ bool ended_;
};
void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) {
WriteOrThrow(to, from.Header(), from.HeaderBytes());
- WordIndex count;
- from.ReadCount(count);
+ WordIndex count = from.ReadCount();
WriteOrThrow(to, &count, sizeof(WordIndex));
CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count);
@@ -263,25 +281,23 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha
}
// Merge at the entry level.
WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes());
- WordIndex first_count, second_count;
- first.ReadCount(first_count); second.ReadCount(second_count);
+ WordIndex first_count = first.ReadCount(), second_count = second.ReadCount();
WordIndex total_count = first_count + second_count;
WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex));
- WordIndex first_word, second_word;
- first.ReadWord(first_word); second.ReadWord(second_word);
+ WordIndex first_word = first.ReadWord(), second_word = second.ReadWord();
WordIndex first_index = 0, second_index = 0;
while (true) {
if (first_word < second_word) {
WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex));
CopyOrThrow(first.File(), out_file.get(), weights_size);
if (++first_index == first_count) break;
- first.ReadWord(first_word);
+ first_word = first.ReadWord();
} else {
WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex));
CopyOrThrow(second.File(), out_file.get(), weights_size);
if (++second_index == second_count) break;
- second.ReadWord(second_word);
+ second_word = second.ReadWord();
}
}
if (first_index == first_count) {
@@ -358,75 +374,219 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts,
{
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
- util::scoped_mmap unigram_mmap;
- unigram_mmap.reset(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff));
+ util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff));
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
}
util::scoped_memory mem;
- mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED);
+ mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size());
ReadEnd(f);
}
-struct RecursiveInsertParams {
- WordIndex *words;
- SortedFileReader *files;
- unsigned char max_order;
- // This is an array of size order - 2.
- BitPackedMiddle *middle;
- // This has exactly one entry.
- BitPackedLongest *longest;
-};
-
-uint64_t RecursiveInsert(RecursiveInsertParams &params, unsigned char order) {
- SortedFileReader &file = params.files[order - 2];
- const uint64_t ret = (order == params.max_order) ? params.longest->InsertIndex() : params.middle[order - 2].InsertIndex();
- if (std::memcmp(params.words, file.Header(), sizeof(WordIndex) * (order - 1)))
- return ret;
- WordIndex count;
- file.ReadCount(count);
- WordIndex key;
- if (order == params.max_order) {
- Prob value;
- for (WordIndex i = 0; i < count; ++i) {
- file.ReadWord(key);
- file.ReadWeights(value);
- params.longest->Insert(key, value.prob);
- }
- file.NextHeader();
- return ret;
- }
- ProbBackoff value;
- for (WordIndex i = 0; i < count; ++i) {
- file.ReadWord(params.words[order - 1]);
- file.ReadWeights(value);
- params.middle[order - 2].Insert(
- params.words[order - 1],
- value.prob,
- value.backoff,
- RecursiveInsert(params, order + 1));
+bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) {
+ for (; words != words_end; ++words, ++header) {
+ if (*words != *header) {
+ assert(*words <= *header);
+ return false;
+ }
}
- file.NextHeader();
- return ret;
+ return true;
}
-void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, std::ostream *messages, TrieSearch &out) {
- UnigramValue *unigrams = out.unigram.Raw();
- // Load unigrams. Leave the next pointers uninitialized.
- {
- std::string name(file_prefix + "unigrams");
- util::scoped_FILE file(fopen(name.c_str(), "r"));
- if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed");
- for (WordIndex i = 0; i < counts[0]; ++i) {
- ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
+class JustCount {
+ public:
+ JustCount(UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
+ : counts_(counts), longest_counts_(counts + order - 1) {}
+
+ void Unigrams(WordIndex begin, WordIndex end) {
+ counts_[0] += end - begin;
}
- unlink(name.c_str());
+
+ void MiddleBlank(const unsigned char mid_idx, WordIndex /* idx */) {
+ ++counts_[mid_idx + 1];
+ }
+
+ void Middle(const unsigned char mid_idx, WordIndex /*key*/, const ProbBackoff &/*weights*/) {
+ ++counts_[mid_idx + 1];
+ }
+
+ void Longest(WordIndex /*key*/, Prob /*prob*/) {
+ ++*longest_counts_;
+ }
+
+ // Unigrams wrote one past.
+ void Cleanup() {
+ --counts_[0];
+ }
+
+ private:
+ uint64_t *const counts_, *const longest_counts_;
+};
+
+class WriteEntries {
+ public:
+ WriteEntries(UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
+ unigrams_(unigrams),
+ middle_(middle),
+ longest_(longest),
+ bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)) {}
+
+ void Unigrams(WordIndex begin, WordIndex end) {
+ uint64_t next = bigram_pack_.InsertIndex();
+ for (UnigramValue *i = unigrams_ + begin; i < unigrams_ + end; ++i) {
+ i->next = next;
+ }
+ }
+
+ void MiddleBlank(const unsigned char mid_idx, WordIndex key) {
+ middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff);
+ }
+
+ void Middle(const unsigned char mid_idx, WordIndex key, const ProbBackoff &weights) {
+ middle_[mid_idx].Insert(key, weights.prob, weights.backoff);
+ }
+
+ void Longest(WordIndex key, Prob prob) {
+ longest_.Insert(key, prob.prob);
+ }
+
+ void Cleanup() {}
+
+ private:
+ UnigramValue *const unigrams_;
+ BitPackedMiddle *const middle_;
+ BitPackedLongest &longest_;
+ BitPacked &bigram_pack_;
+};
+
+template <class Doing> class RecursiveInsert {
+ public:
+ RecursiveInsert(SortedFileReader *inputs, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) :
+ doing_(unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), words_(new WordIndex[order]), order_minus_2_(order - 2) {
+ }
+
+ // Outer unigram loop.
+ void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) {
+ util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
+ for (words_[0] = 0; ; ++words_[0]) {
+ WordIndex min_continue = unigram_count;
+ for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) {
+ if (other->Ended()) continue;
+ min_continue = std::min(min_continue, other->Header()[0]);
+ }
+ // This will write at unigram_count. This is by design so that the next pointers will make sense.
+ doing_.Unigrams(words_[0], min_continue + 1);
+ if (min_continue == unigram_count) break;
+ progress += min_continue - words_[0];
+ words_[0] = min_continue;
+ Middle(0);
+ }
+ doing_.Cleanup();
+ }
+
+ private:
+ void Middle(const unsigned char mid_idx) {
+ // (mid_idx + 2)-gram.
+ if (mid_idx == order_minus_2_) {
+ Longest();
+ return;
+ }
+ // Orders [2, order)
+
+ SortedFileReader &reader = inputs_[mid_idx];
+
+ if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + mid_idx + 1, reader.Header())) {
+ // This order doesn't have a header match, but longer ones might.
+ MiddleAllBlank(mid_idx);
+ return;
+ }
+
+ // There is a header match.
+ WordIndex count = reader.ReadCount();
+ WordIndex current = reader.ReadWord();
+ while (count) {
+ WordIndex min_continue = std::numeric_limits<WordIndex>::max();
+ for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
+ if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header()))
+ min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
+ }
+ while (true) {
+ if (current > min_continue) {
+ doing_.MiddleBlank(mid_idx, min_continue);
+ words_[mid_idx + 1] = min_continue;
+ Middle(mid_idx + 1);
+ break;
+ }
+ ProbBackoff weights;
+ reader.ReadWeights(weights);
+ doing_.Middle(mid_idx, current, weights);
+ --count;
+ if (current == min_continue) {
+ words_[mid_idx + 1] = min_continue;
+ Middle(mid_idx + 1);
+ if (count) current = reader.ReadWord();
+ break;
+ }
+ if (!count) break;
+ current = reader.ReadWord();
+ }
+ }
+ // Count is now zero. Finish off remaining blanks.
+ MiddleAllBlank(mid_idx);
+ reader.NextHeader();
+ }
+
+ void MiddleAllBlank(const unsigned char mid_idx) {
+ while (true) {
+ WordIndex min_continue = std::numeric_limits<WordIndex>::max();
+ for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
+ if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header()))
+ min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
+ }
+ if (min_continue == std::numeric_limits<WordIndex>::max()) return;
+ doing_.MiddleBlank(mid_idx, min_continue);
+ words_[mid_idx + 1] = min_continue;
+ Middle(mid_idx + 1);
+ }
+ }
+
+ void Longest() {
+ SortedFileReader &reader = *(inputs_end_ - 1);
+ if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + order_minus_2_ + 1, reader.Header())) return;
+ WordIndex count = reader.ReadCount();
+ for (WordIndex i = 0; i < count; ++i) {
+ WordIndex word = reader.ReadWord();
+ Prob prob;
+ reader.ReadWeights(prob);
+ doing_.Longest(word, prob);
+ }
+ reader.NextHeader();
+ return;
+ }
+
+ Doing doing_;
+
+ SortedFileReader *inputs_;
+ SortedFileReader *inputs_end_;
+
+ util::scoped_array<WordIndex> words_;
+
+ const unsigned char order_minus_2_;
+};
+
+void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
+ if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]);
+ if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant");
+ for (unsigned char i = 0; i < initial.size(); ++i) {
+ if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen");
}
+}
- // inputs[0] is bigrams.
+void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
SortedFileReader inputs[counts.size() - 1];
+
for (unsigned char i = 2; i <= counts.size(); ++i) {
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
@@ -434,36 +594,49 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
unlink(assembled.str().c_str());
}
- // words[0] is unigrams.
- WordIndex words[counts.size()];
- RecursiveInsertParams params;
- params.words = words;
- params.files = inputs;
- params.max_order = static_cast<unsigned char>(counts.size());
- params.middle = &*out.middle.begin();
- params.longest = &out.longest;
+ std::vector<uint64_t> fixed_counts(counts.size());
+ {
+ RecursiveInsert<JustCount> counter(inputs, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
+ }
+ SanityCheckCounts(counts, fixed_counts);
+
+ out.SetupMemory(GrowForSearch(config, TrieSearch::kModelType, fixed_counts, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config);
+
+ for (unsigned char i = 2; i <= counts.size(); ++i) {
+ inputs[i-2].Rewind();
+ }
+ UnigramValue *unigrams = out.unigram.Raw();
+ // Fill entries except unigram probabilities.
{
- util::ErsatzProgress progress(messages, "Building trie", counts[0]);
- for (words[0] = 0; words[0] < counts[0]; ++words[0], ++progress) {
- unigrams[words[0]].next = RecursiveInsert(params, 2);
+ RecursiveInsert<WriteEntries> inserter(inputs, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
+ }
+
+ // Fill unigram probabilities.
+ {
+ std::string name(file_prefix + "unigrams");
+ util::scoped_FILE file(fopen(name.c_str(), "r"));
+ if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed");
+ for (WordIndex i = 0; i < counts[0]; ++i) {
+ ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
}
+ unlink(name.c_str());
}
/* Set ending offsets so the last entry will be sized properly */
+ // Last entry for unigrams was already set.
if (!out.middle.empty()) {
- unigrams[counts[0]].next = out.middle.front().InsertIndex();
for (size_t i = 0; i < out.middle.size() - 1; ++i) {
out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex());
}
out.middle.back().FinishedLoading(out.longest.InsertIndex());
- } else {
- unigrams[counts[0]].next = out.longest.InsertIndex();
}
}
} // namespace
-void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab) {
+void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {
temporary_directory = config.temporary_directory_prefix;
@@ -473,7 +646,8 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const
temporary_directory = file;
}
// Null on end is kludge to ensure null termination.
- temporary_directory += "-tmp-XXXXXX\0";
+ temporary_directory += "-tmp-XXXXXX";
+ temporary_directory += '\0';
if (!mkdtemp(&temporary_directory[0])) {
UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
}
@@ -483,9 +657,10 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const
temporary_directory += '/';
// At least 1MB sorting memory.
ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
- BuildTrie(temporary_directory.c_str(), counts, config.messages, *this);
- if (rmdir(temporary_directory.c_str())) {
- std::cerr << "Failed to delete " << temporary_directory << std::endl;
+
+ BuildTrie(temporary_directory.c_str(), counts, config, *this, backing);
+ if (rmdir(temporary_directory.c_str()) && config.messages) {
+ *config.messages << "Failed to delete " << temporary_directory << std::endl;
}
}