Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--  klm/lm/search_trie.cc  1052
1 file changed, 348 insertions, 704 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 05059ffb..6479813b 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -2,26 +2,25 @@
#include "lm/search_trie.hh"
#include "lm/bhiksha.hh"
+#include "lm/binary_format.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/max_order.hh"
#include "lm/quantize.hh"
-#include "lm/read_arpa.hh"
#include "lm/trie.hh"
+#include "lm/trie_sort.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"
#include "lm/word_index.hh"
#include "util/ersatz_progress.hh"
-#include "util/file_piece.hh"
-#include "util/have.hh"
#include "util/proxy_iterator.hh"
#include "util/scoped.hh"
+#include "util/sized_iterator.hh"
#include <algorithm>
-#include <cmath>
#include <cstring>
#include <cstdio>
-#include <deque>
+#include <queue>
#include <limits>
#include <numeric>
#include <vector>
@@ -29,575 +28,221 @@
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
-#include <fcntl.h>
-#include <stdlib.h>
-#include <unistd.h>
namespace lm {
namespace ngram {
namespace trie {
namespace {
-/* An entry is an n-gram with probability. It consists of:
- * WordIndex[order]
- * float probability
- * backoff probability (omitted for highest order n-gram)
- * These are stored consecutively in memory. We want to sort them.
- *
- * The problem is the length depends on order (but all n-grams being compared
- * have the same order). Allocating each entry on the heap (i.e. std::vector
- * or std::string) then sorting pointers is the normal solution. But that's
- * too memory inefficient. A lot of this code is just here to force std::sort
- * to work with records where length is specified at runtime (and avoid using
- * Boost for LM code). I could have used qsort, but the point is to also
- * support __gnu_cxx::parallel_sort, which doesn't have a qsort version.
- */
-
-class EntryIterator {
- public:
- EntryIterator() {}
-
- EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast<uint8_t*>(ptr)), size_(size) {}
-
- bool operator==(const EntryIterator &other) const {
- return ptr_ == other.ptr_;
- }
- bool operator<(const EntryIterator &other) const {
- return ptr_ < other.ptr_;
- }
- EntryIterator &operator+=(std::ptrdiff_t amount) {
- ptr_ += amount * size_;
- return *this;
- }
- std::ptrdiff_t operator-(const EntryIterator &other) const {
- return (ptr_ - other.ptr_) / size_;
- }
-
- const void *Data() const { return ptr_; }
- void *Data() { return ptr_; }
- std::size_t EntrySize() const { return size_; }
-
- private:
- uint8_t *ptr_;
- std::size_t size_;
-};
-
-class EntryProxy {
- public:
- EntryProxy() {}
-
- EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {}
-
- operator std::string() const {
- return std::string(reinterpret_cast<const char*>(inner_.Data()), inner_.EntrySize());
- }
-
- EntryProxy &operator=(const EntryProxy &from) {
- memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize());
- return *this;
- }
-
- EntryProxy &operator=(const std::string &from) {
- memcpy(inner_.Data(), from.data(), inner_.EntrySize());
- return *this;
- }
-
- const WordIndex *Indices() const {
- return reinterpret_cast<const WordIndex*>(inner_.Data());
- }
-
- private:
- friend class util::ProxyIterator<EntryProxy>;
-
- typedef std::string value_type;
-
- typedef EntryIterator InnerIterator;
- InnerIterator &Inner() { return inner_; }
- const InnerIterator &Inner() const { return inner_; }
- InnerIterator inner_;
-};
-
-typedef util::ProxyIterator<EntryProxy> NGramIter;
-
-// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams.
-class PartialViewProxy {
- public:
- PartialViewProxy() : attention_size_(0), inner_() {}
-
- PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {}
-
- operator std::string() const {
- return std::string(reinterpret_cast<const char*>(inner_.Data()), attention_size_);
- }
-
- PartialViewProxy &operator=(const PartialViewProxy &from) {
- memcpy(inner_.Data(), from.inner_.Data(), attention_size_);
- return *this;
- }
-
- PartialViewProxy &operator=(const std::string &from) {
- memcpy(inner_.Data(), from.data(), attention_size_);
- return *this;
- }
-
- const WordIndex *Indices() const {
- return reinterpret_cast<const WordIndex*>(inner_.Data());
- }
-
- private:
- friend class util::ProxyIterator<PartialViewProxy>;
-
- typedef std::string value_type;
-
- const std::size_t attention_size_;
-
- typedef EntryIterator InnerIterator;
- InnerIterator &Inner() { return inner_; }
- const InnerIterator &Inner() const { return inner_; }
- InnerIterator inner_;
-};
-
-typedef util::ProxyIterator<PartialViewProxy> PartialIter;
-
-template <class Proxy> class CompareRecords : public std::binary_function<const Proxy &, const Proxy &, bool> {
- public:
- explicit CompareRecords(unsigned char order) : order_(order) {}
-
- bool operator()(const Proxy &first, const Proxy &second) const {
- return Compare(first.Indices(), second.Indices());
- }
- bool operator()(const Proxy &first, const std::string &second) const {
- return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data()));
- }
- bool operator()(const std::string &first, const Proxy &second) const {
- return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices());
- }
- bool operator()(const std::string &first, const std::string &second) const {
- return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data()));
- }
-
- private:
- bool Compare(const WordIndex *first, const WordIndex *second) const {
- const WordIndex *end = first + order_;
- for (; first != end; ++first, ++second) {
- if (*first < *second) return true;
- if (*first > *second) return false;
- }
- return false;
- }
-
- unsigned char order_;
-};
-
-FILE *OpenOrThrow(const char *name, const char *mode) {
- FILE *ret = fopen(name, mode);
- if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode);
- return ret;
-}
-
-void WriteOrThrow(FILE *to, const void *data, size_t size) {
- assert(size);
- if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
-}
-
void ReadOrThrow(FILE *from, void *data, size_t size) {
- if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size);
-}
-
-const std::size_t kCopyBufSize = 512;
-void CopyOrThrow(FILE *from, FILE *to, size_t size) {
- char buf[std::min<size_t>(size, kCopyBufSize)];
- for (size_t i = 0; i < size; i += kCopyBufSize) {
- std::size_t amount = std::min(size - i, kCopyBufSize);
- ReadOrThrow(from, buf, amount);
- WriteOrThrow(to, buf, amount);
- }
+ UTIL_THROW_IF(1 != std::fread(data, size, 1, from), util::ErrnoException, "Short read");
}
-void CopyRestOrThrow(FILE *from, FILE *to) {
- char buf[kCopyBufSize];
- size_t amount;
- while ((amount = fread(buf, 1, kCopyBufSize, from))) {
- WriteOrThrow(to, buf, amount);
+int Compare(unsigned char order, const void *first_void, const void *second_void) {
+ const WordIndex *first = reinterpret_cast<const WordIndex*>(first_void), *second = reinterpret_cast<const WordIndex*>(second_void);
+ const WordIndex *end = first + order;
+ for (; first != end; ++first, ++second) {
+ if (*first < *second) return -1;
+ if (*first > *second) return 1;
}
- if (!feof(from)) UTIL_THROW(util::ErrnoException, "Short read");
-}
-
-void RemoveOrThrow(const char *name) {
- if (std::remove(name)) UTIL_THROW(util::ErrnoException, "Could not remove " << name);
+ return 0;
}
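A side note on the new three-way Compare: it walks WordIndex values one at a time rather than calling memcmp, because byte-wise comparison of little-endian integers disagrees with numeric order. A minimal standalone illustration (hypothetical values, not part of this diff):

#include <cassert>
#include <cstring>
#include <stdint.h>

int main() {
  // On a little-endian machine the low byte is stored first, so
  // memcmp sees 0x01 vs 0x00 and ranks 1 *above* 256.
  uint32_t a = 1, b = 256;
  assert(a < b);                         // numeric order
  assert(memcmp(&a, &b, sizeof(a)) > 0); // byte order disagrees (little-endian)
  return 0;
}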
-std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) {
- const std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
- const std::size_t prefix_size = sizeof(WordIndex) * (order - 1);
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
- std::string ret(assembled.str());
- util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w"));
- // Compress entries that begin with the same (order-1) words.
- for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) {
- const uint8_t *group_end;
- for (group_end = group_begin + entry_size;
- (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size);
- group_end += entry_size) {}
- WriteOrThrow(out.get(), group_begin, prefix_size);
- WordIndex group_size = (group_end - group_begin) / entry_size;
- WriteOrThrow(out.get(), &group_size, sizeof(group_size));
- for (const uint8_t *i = group_begin; i != group_end; i += entry_size) {
- WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex));
- WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size);
- }
- group_begin = group_end;
- }
- return ret;
-}
+struct ProbPointer {
+ unsigned char array;
+ uint64_t index;
+};
-class SortedFileReader {
+// Array of n-grams and float indices.
+class BackoffMessages {
public:
- SortedFileReader() : ended_(false) {}
-
- void Init(const std::string &name, unsigned char order) {
- file_.reset(OpenOrThrow(name.c_str(), "r"));
- header_.resize(order - 1);
- NextHeader();
+ void Init(std::size_t entry_size) {
+ current_ = NULL;
+ allocated_ = NULL;
+ entry_size_ = entry_size;
}
- // Preceding words.
- const WordIndex *Header() const {
- return &*header_.begin();
- }
- const std::vector<WordIndex> &HeaderVector() const { return header_;}
-
- std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); }
-
- void NextHeader() {
- if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get())) {
- if (feof(file_.get())) {
- ended_ = true;
- } else {
- UTIL_THROW(util::ErrnoException, "Short read of counts");
+ void Add(const WordIndex *to, ProbPointer index) {
+ while (current_ + entry_size_ > allocated_) {
+ std::size_t allocated_size = allocated_ - (uint8_t*)backing_.get();
+ Resize(std::max<std::size_t>(allocated_size * 2, entry_size_));
+ }
+ memcpy(current_, to, entry_size_ - sizeof(ProbPointer));
+ *reinterpret_cast<ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer)) = index;
+ current_ += entry_size_;
+ }
+
+ void Apply(float *const *const base, FILE *unigrams) {
+ FinishedAdding();
+ if (current_ == allocated_) return;
+ rewind(unigrams);
+ ProbBackoff weights;
+ WordIndex unigram = 0;
+ ReadOrThrow(unigrams, &weights, sizeof(weights));
+ for (; current_ != allocated_; current_ += entry_size_) {
+ const WordIndex &cur_word = *reinterpret_cast<const WordIndex*>(current_);
+ for (; unigram < cur_word; ++unigram) {
+ ReadOrThrow(unigrams, &weights, sizeof(weights));
}
+ if (!HasExtension(weights.backoff)) {
+ weights.backoff = kExtensionBackoff;
+ UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
+ WriteOrThrow(unigrams, &weights, sizeof(weights));
+ }
+ const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
+ base[write_to.array][write_to.index] += weights.backoff;
}
+ backing_.reset();
+ }
+
+ void Apply(float *const *const base, RecordReader &reader) {
+ FinishedAdding();
+ if (current_ == allocated_) return;
+ // We'll also use the same buffer to record messages to blanks that they extend.
+ WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);
+ const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);
+ for (reader.Rewind(); reader && (current_ != allocated_); ) {
+ switch (Compare(order, reader.Data(), current_)) {
+ case -1:
+ ++reader;
+ break;
+ case 1:
+ // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
+ for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;
+ current_ += entry_size_;
+ break;
+ case 0:
+ float &backoff = reinterpret_cast<ProbBackoff*>((uint8_t*)reader.Data() + order * sizeof(WordIndex))->backoff;
+ if (!HasExtension(backoff)) {
+ backoff = kExtensionBackoff;
+ reader.Overwrite(&backoff, sizeof(float));
+ } else {
+ const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer));
+ base[write_to.array][write_to.index] += backoff;
+ }
+ current_ += entry_size_;
+ break;
+ }
+ }
+ // Now this is a list of blanks that extend right.
+ entry_size_ = sizeof(WordIndex) * order;
+ Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));
+ current_ = (uint8_t*)backing_.get();
}
- WordIndex ReadCount() {
- WordIndex ret;
- ReadOrThrow(file_.get(), &ret, sizeof(WordIndex));
- return ret;
- }
-
- WordIndex ReadWord() {
- WordIndex ret;
- ReadOrThrow(file_.get(), &ret, sizeof(WordIndex));
- return ret;
- }
-
- template <class Weights> void ReadWeights(Weights &weights) {
- ReadOrThrow(file_.get(), &weights, sizeof(Weights));
+ // Call after Apply
+ bool Extends(unsigned char order, const WordIndex *words) {
+ if (current_ == allocated_) return false;
+ assert(order * sizeof(WordIndex) == entry_size_);
+ while (true) {
+ switch(Compare(order, words, current_)) {
+ case 1:
+ current_ += entry_size_;
+ if (current_ == allocated_) return false;
+ break;
+ case -1:
+ return false;
+ case 0:
+ return true;
+ }
+ }
}
- bool Ended() const {
- return ended_;
+ private:
+ void FinishedAdding() {
+ Resize(current_ - (uint8_t*)backing_.get());
+ current_ = (uint8_t*)backing_.get();
}
- void Rewind() {
- rewind(file_.get());
- ended_ = false;
- NextHeader();
+ void Resize(std::size_t to) {
+ std::size_t current = current_ - (uint8_t*)backing_.get();
+ backing_.call_realloc(to);
+ current_ = (uint8_t*)backing_.get() + current;
+ allocated_ = (uint8_t*)backing_.get() + to;
}
- FILE *File() { return file_.get(); }
-
- private:
- util::scoped_FILE file_;
+ util::scoped_malloc backing_;
- std::vector<WordIndex> header_;
+ uint8_t *current_, *allocated_;
- bool ended_;
+ std::size_t entry_size_;
};
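BackoffMessages above grows a realloc-backed byte buffer of fixed-size records, doubling on overflow (the current_/allocated_ bookkeeping). A freestanding sketch of that pattern, with hypothetical names in place of util::scoped_malloc:

#include <cstdlib>
#include <cstring>
#include <stdint.h>

// Growable array of fixed-size binary records, doubling on overflow.
class RecordBuffer {
 public:
  explicit RecordBuffer(std::size_t entry_size)
    : base_(NULL), used_(0), allocated_(0), entry_size_(entry_size) {}
  ~RecordBuffer() { free(base_); }
  void Add(const void *entry) {
    if (used_ + entry_size_ > allocated_) {
      allocated_ = allocated_ ? allocated_ * 2 : entry_size_;
      base_ = static_cast<uint8_t*>(realloc(base_, allocated_));
      // Real code should check for realloc failure; omitted for brevity.
    }
    memcpy(base_ + used_, entry, entry_size_);
    used_ += entry_size_;
  }
 private:
  uint8_t *base_;
  std::size_t used_, allocated_, entry_size_;
};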
-void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) {
- WriteOrThrow(to, from.Header(), from.HeaderBytes());
- WordIndex count = from.ReadCount();
- WriteOrThrow(to, &count, sizeof(WordIndex));
-
- CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count);
-}
-
-void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) {
- SortedFileReader first, second;
- first.Init(first_name.c_str(), order);
- RemoveOrThrow(first_name.c_str());
- second.Init(second_name.c_str(), order);
- RemoveOrThrow(second_name.c_str());
- util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w"));
- while (!first.Ended() && !second.Ended()) {
- if (first.HeaderVector() < second.HeaderVector()) {
- CopyFullRecord(first, out_file.get(), weights_size);
- first.NextHeader();
- continue;
- }
- if (first.HeaderVector() > second.HeaderVector()) {
- CopyFullRecord(second, out_file.get(), weights_size);
- second.NextHeader();
- continue;
- }
- // Merge at the entry level.
- WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes());
- WordIndex first_count = first.ReadCount(), second_count = second.ReadCount();
- WordIndex total_count = first_count + second_count;
- WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex));
-
- WordIndex first_word = first.ReadWord(), second_word = second.ReadWord();
- WordIndex first_index = 0, second_index = 0;
- while (true) {
- if (first_word < second_word) {
- WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex));
- CopyOrThrow(first.File(), out_file.get(), weights_size);
- if (++first_index == first_count) break;
- first_word = first.ReadWord();
- } else {
- WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex));
- CopyOrThrow(second.File(), out_file.get(), weights_size);
- if (++second_index == second_count) break;
- second_word = second.ReadWord();
- }
- }
- if (first_index == first_count) {
- WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex));
- CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex));
- } else {
- WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex));
- CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex));
- }
- first.NextHeader();
- second.NextHeader();
- }
-
- for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) {
- CopyFullRecord(remaining, out_file.get(), weights_size);
- }
-}
-
-const char *kContextSuffix = "_contexts";
-
-void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) {
- const size_t context_size = sizeof(WordIndex) * (order - 1);
- // Sort just the contexts using the same memory.
- PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
- PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
-
- std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1));
-
- std::string name(ngram_file_name + kContextSuffix);
- util::scoped_FILE out(OpenOrThrow(name.c_str(), "w"));
-
- // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
- if (context_begin == context_end) return;
- PartialIter i(context_begin);
- WriteOrThrow(out.get(), i->Indices(), context_size);
- const WordIndex *previous = i->Indices();
- ++i;
- for (; i != context_end; ++i) {
- if (memcmp(previous, i->Indices(), context_size)) {
- WriteOrThrow(out.get(), i->Indices(), context_size);
- previous = i->Indices();
- }
- }
-}
+const float kBadProb = std::numeric_limits<float>::infinity();
-class ContextReader {
+class SRISucks {
public:
- ContextReader() : valid_(false) {}
-
- ContextReader(const char *name, unsigned char order) {
- Reset(name, order);
- }
-
- void Reset(const char *name, unsigned char order) {
- file_.reset(OpenOrThrow(name, "r"));
- length_ = sizeof(WordIndex) * static_cast<size_t>(order);
- words_.resize(order);
- valid_ = true;
- ++*this;
- }
-
- ContextReader &operator++() {
- if (1 != fread(&*words_.begin(), length_, 1, file_.get())) {
- if (!feof(file_.get()))
- UTIL_THROW(util::ErrnoException, "Short read");
- valid_ = false;
+ SRISucks() {
+ for (BackoffMessages *i = messages_; i != messages_ + kMaxOrder - 1; ++i)
+ i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1));
+ }
+
+ void Send(unsigned char begin, unsigned char order, const WordIndex *to, float prob_basis) {
+ assert(prob_basis != kBadProb);
+ ProbPointer pointer;
+ pointer.array = order - 1;
+ pointer.index = values_[order - 1].size();
+ for (unsigned char i = begin; i < order; ++i) {
+ messages_[i - 1].Add(to, pointer);
}
- return *this;
+ values_[order - 1].push_back(prob_basis);
}
- const WordIndex *operator*() const { return &*words_.begin(); }
-
- operator bool() const { return valid_; }
-
- FILE *GetFile() { return file_.get(); }
-
- private:
- util::scoped_FILE file_;
-
- size_t length_;
-
- std::vector<WordIndex> words_;
-
- bool valid_;
-};
-
-void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) {
- const size_t context_size = sizeof(WordIndex) * (order - 1);
- std::string first_name(first_base + kContextSuffix);
- std::string second_name(second_base + kContextSuffix);
- ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1);
- RemoveOrThrow(first_name.c_str());
- RemoveOrThrow(second_name.c_str());
- std::string out_name(out_base + kContextSuffix);
- util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w"));
- while (first && second) {
- for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) {
- if (f == *first + order - 1) {
- // Equal.
- WriteOrThrow(out.get(), *first, context_size);
- ++first;
- ++second;
- break;
+ void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) {
+ for (unsigned char i = 0; i < kMaxOrder - 1; ++i) {
+ it_[i] = &*values_[i].begin();
}
- if (*f < *s) {
- // First lower
- WriteOrThrow(out.get(), *first, context_size);
- ++first;
- break;
- } else if (*f > *s) {
- WriteOrThrow(out.get(), *second, context_size);
- ++second;
- break;
+ messages_[0].Apply(it_, unigram_file);
+ BackoffMessages *messages = messages_ + 1;
+ const RecordReader *end = reader + total_order - 2 /* exclude unigrams and longest order */;
+ for (; reader != end; ++messages, ++reader) {
+ messages->Apply(it_, *reader);
}
}
- }
- ContextReader &remaining = first ? first : second;
- if (!remaining) return;
- WriteOrThrow(out.get(), *remaining, context_size);
- CopyRestOrThrow(remaining.GetFile(), out.get());
-}
-void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) {
- ReadNGramHeader(f, order);
- const size_t count = counts[order - 1];
- // Size of weights. Does it include backoff?
- const size_t words_size = sizeof(WordIndex) * order;
- const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float));
- const size_t entry_size = words_size + weights_size;
- const size_t batch_size = std::min(count, mem.size() / entry_size);
- uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get());
- std::deque<std::string> files;
- for (std::size_t batch = 0, done = 0; done < count; ++batch) {
- uint8_t *out = begin;
- uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
- if (order == counts.size()) {
- for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn);
- }
- } else {
- for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
- }
+ ProbBackoff GetBlank(unsigned char total_order, unsigned char order, const WordIndex *indices) {
+ assert(order > 1);
+ ProbBackoff ret;
+ ret.prob = *(it_[order - 1]++);
+ ret.backoff = ((order != total_order - 1) && messages_[order - 1].Extends(order, indices)) ? kExtensionBackoff : kNoExtensionBackoff;
+ return ret;
}
- // Sort full records by full n-gram.
- EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
- // parallel_sort uses too much RAM
- std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order));
- files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size));
- WriteContextFile(begin, out_end, files.back(), entry_size, order);
-
- done += (out_end - begin) / entry_size;
- }
- // All individual files created. Merge them.
-
- std::size_t merge_count = 0;
- while (files.size() > 1) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
- files.push_back(assembled.str());
- MergeSortedFiles(files[0], files[1], files.back(), weights_size, order);
- MergeContextFiles(files[0], files[1], files.back(), order);
- files.pop_front();
- files.pop_front();
- }
- if (!files.empty()) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";
- std::string merged_name(assembled.str());
- if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
- std::string context_name = files[0] + kContextSuffix;
- merged_name += kContextSuffix;
- if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str());
- }
-}
-
-void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
- PositiveProbWarn warn(config.positive_log_probability);
- {
- std::string unigram_name = file_prefix + "unigrams";
- util::scoped_fd unigram_file;
- // In case <unk> appears.
- size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff);
- util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out);
- Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
- CheckSpecials(config, vocab);
- if (!vocab.SawUnk()) ++counts[0];
- }
+ const std::vector<float> &Values(unsigned char order) const {
+ return values_[order - 1];
+ }
- // Only use as much buffer as we need.
- size_t buffer_use = 0;
- for (unsigned int order = 2; order < counts.size(); ++order) {
- buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
- }
- buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
- buffer = std::min<size_t>(buffer, buffer_use);
+ private:
+ // This used to be one array. Then I needed to separate it by order for quantization to work.
+ std::vector<float> values_[kMaxOrder - 1];
+ BackoffMessages messages_[kMaxOrder - 1];
- util::scoped_memory mem;
- mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
- if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
+ float *it_[kMaxOrder - 1];
+};
- for (unsigned char order = 2; order <= counts.size(); ++order) {
- ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn);
- }
- ReadEnd(f);
-}
+class FindBlanks {
+ public:
+ FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
+ : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {}
-bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) {
- for (; words != words_end; ++words, ++header) {
- if (*words != *header) {
- //assert(*words <= *header);
- return false;
+ float UnigramProb(WordIndex index) const {
+ return unigrams_[index].prob;
}
- }
- return true;
-}
-// Phase to count n-grams, including blanks inserted because they were pruned but have extensions
-class JustCount {
- public:
- template <class Middle, class Longest> JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order)
- : counts_(counts), longest_counts_(counts + order - 1) {}
-
- void Unigrams(WordIndex begin, WordIndex end) {
- counts_[0] += end - begin;
+ void Unigram(WordIndex /*index*/) {
+ ++counts_[0];
}
- void MiddleBlank(const unsigned char mid_idx, WordIndex /* idx */) {
- ++counts_[mid_idx + 1];
+ void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) {
+ sri_.Send(lower, order, indices + 1, prob_basis);
+ ++counts_[order - 1];
}
- void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) {
- ++counts_[mid_idx + 1];
+ void Middle(const unsigned char order, const void * /*data*/) {
+ ++counts_[order - 1];
}
- void Longest(WordIndex /*key*/, Prob /*prob*/) {
+ void Longest(const void * /*data*/) {
++*longest_counts_;
}
@@ -608,167 +253,156 @@ class JustCount {
private:
uint64_t *const counts_, *const longest_counts_;
+
+ const ProbBackoff *unigrams_;
+
+ SRISucks &sri_;
};
// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
+ WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order, SRISucks &sri) :
contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
longest_(longest),
- bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)) {}
+ bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),
+ order_(order),
+ sri_(sri) {}
- void Unigrams(WordIndex begin, WordIndex end) {
- uint64_t next = bigram_pack_.InsertIndex();
- for (UnigramValue *i = unigrams_ + begin; i < unigrams_ + end; ++i) {
- i->next = next;
- }
+ float UnigramProb(WordIndex index) const { return unigrams_[index].weights.prob; }
+
+ void Unigram(WordIndex word) {
+ unigrams_[word].next = bigram_pack_.InsertIndex();
}
- void MiddleBlank(const unsigned char mid_idx, WordIndex key) {
- middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff);
+ void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) {
+ ProbBackoff weights = sri_.GetBlank(order_, order, indices);
+ middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff);
}
- void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) {
- // Order (mid_idx+2).
- ContextReader &context = contexts_[mid_idx + 1];
- if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) {
+ void Middle(const unsigned char order, const void *data) {
+ RecordReader &context = contexts_[order - 1];
+ const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
+ ProbBackoff weights = *reinterpret_cast<const ProbBackoff*>(words + order);
+ if (context && !memcmp(data, context.Data(), sizeof(WordIndex) * order)) {
SetExtension(weights.backoff);
++context;
}
- middle_[mid_idx].Insert(key, weights.prob, weights.backoff);
+ middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff);
}
- void Longest(WordIndex key, Prob prob) {
- longest_.Insert(key, prob.prob);
+ void Longest(const void *data) {
+ const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
+ longest_.Insert(words[order_ - 1], reinterpret_cast<const Prob*>(words + order_)->prob);
}
void Cleanup() {}
private:
- ContextReader *contexts_;
+ RecordReader *contexts_;
UnigramValue *const unigrams_;
BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_;
BitPackedLongest<typename Quant::Longest> &longest_;
BitPacked &bigram_pack_;
+ const unsigned char order_;
+ SRISucks &sri_;
};
-template <class Doing> class RecursiveInsert {
- public:
- template <class MiddleT, class LongestT> RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) :
- doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) {
- }
+struct Gram {
+ Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {}
- // Outer unigram loop.
- void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) {
- util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
- for (words_[0] = 0; ; ++words_[0]) {
- progress.Set(words_[0]);
- WordIndex min_continue = unigram_count;
- for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) {
- if (other->Ended()) continue;
- min_continue = std::min(min_continue, other->Header()[0]);
- }
- // This will write at unigram_count. This is by design so that the next pointers will make sense.
- doing_.Unigrams(words_[0], min_continue + 1);
- if (min_continue == unigram_count) break;
- words_[0] = min_continue;
- Middle(0);
- }
- doing_.Cleanup();
- }
+ const WordIndex *begin, *end;
- private:
- void Middle(const unsigned char mid_idx) {
- // (mid_idx + 2)-gram.
- if (mid_idx == order_minus_2_) {
- Longest();
- return;
- }
- // Orders [2, order)
+ // For the queue this is the direction we want: reversed, so the max-heap pops the smallest n-gram first.
+ bool operator<(const Gram &other) const {
+ return std::lexicographical_compare(other.begin, other.end, begin, end);
+ }
+};
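The reversed operands in Gram's operator< are what turn std::priority_queue, a max-heap, into the min-heap that RecursiveInsert needs. The same trick in miniature, with toy int records (hypothetical, not from the library):

#include <cassert>
#include <queue>

struct Rev {
  int v;
  // Swapped arguments: "less" now means "greater", so top() is the minimum.
  bool operator<(const Rev &other) const { return other.v < v; }
};

int main() {
  std::priority_queue<Rev> q;
  Rev a = {3}, b = {1}, c = {2};
  q.push(a); q.push(b); q.push(c);
  assert(q.top().v == 1); // smallest first
  return 0;
}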
- SortedFileReader &reader = inputs_[mid_idx];
+template <class Doing> class BlankManager {
+ public:
+ BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) {
+ for (float *i = basis_; i != basis_ + kMaxOrder - 1; ++i) *i = kBadProb;
+ }
- if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) {
- // This order doesn't have a header match, but longer ones might.
- MiddleAllBlank(mid_idx);
- return;
+ void Visit(const WordIndex *to, unsigned char length, float prob) {
+ basis_[length - 1] = prob;
+ unsigned char overlap = std::min<unsigned char>(length - 1, been_length_);
+ const WordIndex *cur;
+ WordIndex *pre;
+ for (cur = to, pre = been_; cur != to + overlap; ++cur, ++pre) {
+ if (*pre != *cur) break;
}
-
- // There is a header match.
- WordIndex count = reader.ReadCount();
- WordIndex current = reader.ReadWord();
- while (count) {
- WordIndex min_continue = std::numeric_limits<WordIndex>::max();
- for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
- if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
- min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
- }
- while (true) {
- if (current > min_continue) {
- doing_.MiddleBlank(mid_idx, min_continue);
- words_[mid_idx + 1] = min_continue;
- Middle(mid_idx + 1);
- break;
- }
- ProbBackoff weights;
- reader.ReadWeights(weights);
- doing_.Middle(mid_idx, words_, current, weights);
- --count;
- if (current == min_continue) {
- words_[mid_idx + 1] = min_continue;
- Middle(mid_idx + 1);
- if (count) current = reader.ReadWord();
- break;
- }
- if (!count) break;
- current = reader.ReadWord();
- }
+ if (cur == to + length - 1) {
+ *pre = *cur;
+ been_length_ = length;
+ return;
}
- // Count is now zero. Finish off remaining blanks.
- MiddleAllBlank(mid_idx);
- reader.NextHeader();
- }
-
- void MiddleAllBlank(const unsigned char mid_idx) {
- while (true) {
- WordIndex min_continue = std::numeric_limits<WordIndex>::max();
- for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
- if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
- min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
- }
- if (min_continue == std::numeric_limits<WordIndex>::max()) return;
- doing_.MiddleBlank(mid_idx, min_continue);
- words_[mid_idx + 1] = min_continue;
- Middle(mid_idx + 1);
+ // There are blanks to insert, starting at order blank.
+ unsigned char blank = cur - to + 1;
+ UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");
+ const float *lower_basis;
+ for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {}
+ unsigned char based_on = lower_basis - basis_ + 1;
+ for (; cur != to + length - 1; ++blank, ++cur, ++pre) {
+ assert(*lower_basis != kBadProb);
+ doing_.MiddleBlank(blank, to, based_on, *lower_basis);
+ *pre = *cur;
+ // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
+ basis_[blank - 1] = kBadProb;
}
+ been_length_ = length;
}
- void Longest() {
- SortedFileReader &reader = *(inputs_end_ - 1);
- if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return;
- WordIndex count = reader.ReadCount();
- for (WordIndex i = 0; i < count; ++i) {
- WordIndex word = reader.ReadWord();
- Prob prob;
- reader.ReadWeights(prob);
- doing_.Longest(word, prob);
- }
- reader.NextHeader();
- return;
- }
+ private:
+ const unsigned char total_order_;
- Doing doing_;
+ WordIndex been_[kMaxOrder];
+ unsigned char been_length_;
- SortedFileReader *inputs_;
- SortedFileReader *inputs_end_;
+ float basis_[kMaxOrder];
+
+ Doing &doing_;
+};
- WordIndex words_[kMaxOrder];
+template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) {
+ util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
+ unsigned int unigram = 0;
+ std::priority_queue<Gram> grams;
+ grams.push(Gram(&unigram, 1));
+ for (unsigned char i = 2; i <= total_order; ++i) {
+ if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i));
+ }
- const unsigned char order_minus_2_;
-};
+ BlankManager<Doing> blank(total_order, doing);
+
+ while (true) {
+ Gram top = grams.top();
+ grams.pop();
+ unsigned char order = top.end - top.begin;
+ if (order == 1) {
+ blank.Visit(&unigram, 1, doing.UnigramProb(unigram));
+ doing.Unigram(unigram);
+ progress.Set(unigram);
+ if (++unigram == unigram_count + 1) break;
+ grams.push(top);
+ } else {
+ if (order == total_order) {
+ blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob);
+ doing.Longest(top.begin);
+ } else {
+ blank.Visit(top.begin, order, reinterpret_cast<const ProbBackoff*>(top.end)->prob);
+ doing.Middle(order, top.begin);
+ }
+ RecordReader &reader = input[order - 2];
+ if (++reader) grams.push(top);
+ }
+ }
+ assert(grams.empty());
+ doing.Cleanup();
+}
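RecursiveInsert is a k-way merge: seed the heap with the head record of each order's sorted stream, pop the smallest, visit it, advance that stream, and re-push until everything drains. Stripped of the n-gram specifics, the skeleton looks roughly like this (a sketch with toy int streams, not the library's API):

#include <cstdio>
#include <queue>
#include <vector>

struct Head {
  int value;          // stands in for the n-gram key
  std::size_t stream; // which sorted input it came from
  bool operator<(const Head &other) const { return other.value < value; } // min-heap
};

void KWayMerge(const std::vector<std::vector<int> > &streams) {
  std::vector<std::size_t> pos(streams.size(), 0);
  std::priority_queue<Head> heap;
  for (std::size_t s = 0; s < streams.size(); ++s) {
    if (!streams[s].empty()) { Head h = {streams[s][0], s}; heap.push(h); }
  }
  while (!heap.empty()) {
    Head top = heap.top(); heap.pop();
    std::printf("%d\n", top.value); // visit records in globally sorted order
    if (++pos[top.stream] < streams[top.stream].size()) {
      Head next = {streams[top.stream][pos[top.stream]], top.stream};
      heap.push(next);
    }
  }
}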
void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]);
@@ -778,120 +412,122 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
}
}
-bool IsDirectory(const char *path) {
- struct stat info;
- if (0 != stat(path, &info)) return false;
- return S_ISDIR(info.st_mode);
-}
-
-template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) {
- ProbBackoff weights;
- std::vector<float> probs, backoffs;
- probs.reserve(count);
+template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const std::vector<float> &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
+ std::vector<float> probs(additional), backoffs;
+ probs.reserve(count + additional.size());
backoffs.reserve(count);
- for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) {
- uint64_t entries = reader.ReadCount();
- for (uint64_t c = 0; c < entries; ++c) {
- reader.ReadWord();
- reader.ReadWeights(weights);
- // kBlankProb isn't added yet.
- probs.push_back(weights.prob);
- if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
- ++progress;
- }
+ for (reader.Rewind(); reader; ++reader) {
+ const ProbBackoff &weights = *reinterpret_cast<const ProbBackoff*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
+ probs.push_back(weights.prob);
+ if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
+ ++progress;
}
quant.Train(order, probs, backoffs);
}
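TrainQuantizer only streams prob/backoff floats into vectors; the actual binning lives in quantize.hh. Assuming a quantile-style codebook (typical for binning quantizers), a toy stand-in for Train might look like the following; all names here are hypothetical:

#include <algorithm>
#include <vector>

// Toy equal-frequency binning: sort the values and take the median of
// each of the 2^bits equal-population slices as its representative.
std::vector<float> BinCenters(std::vector<float> values, unsigned char bits) {
  std::sort(values.begin(), values.end());
  std::size_t buckets = static_cast<std::size_t>(1) << bits;
  std::vector<float> centers;
  centers.reserve(buckets);
  for (std::size_t i = 0; i < buckets; ++i) {
    std::size_t begin = i * values.size() / buckets;
    std::size_t end = (i + 1) * values.size() / buckets;
    if (begin == end) continue; // fewer distinct values than buckets
    centers.push_back(values[(begin + end) / 2]);
  }
  return centers;
}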
-template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) {
- Prob weights;
+template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
std::vector<float> probs, backoffs;
probs.reserve(count);
- for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) {
- uint64_t entries = reader.ReadCount();
- for (uint64_t c = 0; c < entries; ++c) {
- reader.ReadWord();
- reader.ReadWeights(weights);
- // kBlankProb isn't added yet.
- probs.push_back(weights.prob);
- ++progress;
- }
+ for (reader.Rewind(); reader; ++reader) {
+ const Prob &weights = *reinterpret_cast<const Prob*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
+ probs.push_back(weights.prob);
+ ++progress;
}
quant.TrainProb(order, probs);
}
+void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
+ // Fill unigram probabilities.
+ try {
+ rewind(file);
+ for (WordIndex i = 0; i < unigram_count; ++i) {
+ ReadOrThrow(file, &unigrams[i].weights, sizeof(ProbBackoff));
+ if (contexts && *reinterpret_cast<const WordIndex*>(contexts.Data()) == i) {
+ SetExtension(unigrams[i].weights.backoff);
+ ++contexts;
+ }
+ }
+ } catch (util::Exception &e) {
+ e << " while re-reading unigram probabilities";
+ throw;
+ }
+}
+
} // namespace
template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
- std::vector<SortedFileReader> inputs(counts.size() - 1);
- std::vector<ContextReader> contexts(counts.size() - 1);
+ RecordReader inputs[kMaxOrder - 1];
+ RecordReader contexts[kMaxOrder - 1];
for (unsigned char i = 2; i <= counts.size(); ++i) {
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
- inputs[i-2].Init(assembled.str(), i);
- RemoveOrThrow(assembled.str().c_str());
+ inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff)));
+ util::RemoveOrThrow(assembled.str().c_str());
assembled << kContextSuffix;
- contexts[i-2].Reset(assembled.str().c_str(), i-1);
- RemoveOrThrow(assembled.str().c_str());
+ contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex));
+ util::RemoveOrThrow(assembled.str().c_str());
}
+ SRISucks sri;
std::vector<uint64_t> fixed_counts(counts.size());
{
- RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
- counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
+ std::string temp(file_prefix); temp += "unigrams";
+ util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str()));
+ util::scoped_memory unigrams;
+ MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
+ FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
}
- for (std::vector<SortedFileReader>::const_iterator i = inputs.begin(); i != inputs.end(); ++i) {
- if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading");
+ for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) {
+ if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
}
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
+ util::scoped_FILE unigram_file;
+ {
+ std::string name(file_prefix + "unigrams");
+ unigram_file.reset(OpenOrThrow(name.c_str(), "r"));
+ util::RemoveOrThrow(name.c_str());
+ }
+ sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
+
out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+ for (unsigned char i = 2; i <= counts.size(); ++i) {
+ inputs[i-2].Rewind();
+ }
if (Quant::kTrain) {
util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
for (unsigned char i = 2; i < counts.size(); ++i) {
- TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant);
+ TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
}
TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant);
quant.FinishedLoading(config);
}
+ UnigramValue *unigrams = out.unigram.Raw();
+ PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams);
+ unigram_file.reset();
+
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
}
- UnigramValue *unigrams = out.unigram.Raw();
// Fill entries except unigram probabilities.
{
- RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
- inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
- }
-
- // Fill unigram probabilities.
- try {
- std::string name(file_prefix + "unigrams");
- util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));
- for (WordIndex i = 0; i < counts[0]; ++i) {
- ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
- if (contexts[0] && **contexts[0] == i) {
- SetExtension(unigrams[i].weights.backoff);
- ++contexts[0];
- }
- }
- RemoveOrThrow(name.c_str());
- } catch (util::Exception &e) {
- e << " while re-reading unigram probabilities";
- throw;
+ WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
for (unsigned char order = 2; order <= counts.size(); ++order) {
- const ContextReader &context = contexts[order - 2];
+ const RecordReader &context = contexts[order - 2];
if (context) {
FormatLoadException e;
- e << "An " << static_cast<unsigned int>(order) << "-gram has the context (i.e. all but the last word):";
- for (const WordIndex *i = *context; i != *context + order - 1; ++i) {
+ e << "An " << static_cast<unsigned int>(order) << "-gram has context";
+ const WordIndex *ctx = reinterpret_cast<const WordIndex*>(context.Data());
+ for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) {
e << ' ' << *i;
}
e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not";
@@ -945,6 +581,14 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
longest.LoadedBinary();
}
+namespace {
+bool IsDirectory(const char *path) {
+ struct stat info;
+ if (0 != stat(path, &info)) return false;
+ return S_ISDIR(info.st_mode);
+}
+} // namespace
+
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {