summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/bhiksha.cc15
-rw-r--r--klm/lm/bhiksha.hh9
-rw-r--r--klm/lm/binary_format.cc259
-rw-r--r--klm/lm/binary_format.hh112
-rw-r--r--klm/lm/build_binary_main.cc8
-rw-r--r--klm/lm/builder/corpus_count.cc9
-rw-r--r--klm/lm/builder/lmplz_main.cc33
-rw-r--r--klm/lm/builder/pipeline.cc1
-rw-r--r--klm/lm/facade.hh23
-rw-r--r--klm/lm/filter/arpa_io.hh1
-rw-r--r--klm/lm/filter/count_io.hh12
-rw-r--r--klm/lm/filter/filter_main.cc155
-rw-r--r--klm/lm/filter/format.hh2
-rw-r--r--klm/lm/filter/phrase.cc59
-rw-r--r--klm/lm/filter/phrase.hh42
-rw-r--r--klm/lm/filter/phrase_table_vocab_main.cc165
-rw-r--r--klm/lm/filter/vocab.cc1
-rw-r--r--klm/lm/filter/wrapper.hh10
-rw-r--r--klm/lm/model.cc84
-rw-r--r--klm/lm/model.hh14
-rw-r--r--klm/lm/model_test.cc8
-rw-r--r--klm/lm/ngram_query.hh17
-rw-r--r--klm/lm/quantize.cc12
-rw-r--r--klm/lm/quantize.hh5
-rw-r--r--klm/lm/query_main.cc51
-rw-r--r--klm/lm/read_arpa.cc14
-rw-r--r--klm/lm/search_hashed.cc33
-rw-r--r--klm/lm/search_hashed.hh12
-rw-r--r--klm/lm/search_trie.cc35
-rw-r--r--klm/lm/search_trie.hh23
-rw-r--r--klm/lm/state.hh2
-rw-r--r--klm/lm/trie.hh9
-rw-r--r--klm/lm/trie_sort.cc4
-rw-r--r--klm/lm/value_build.cc1
-rw-r--r--klm/lm/virtual_interface.hh7
-rw-r--r--klm/lm/vocab.cc28
-rw-r--r--klm/lm/vocab.hh12
37 files changed, 800 insertions, 487 deletions
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc
index 088ea98d..c8a18dfd 100644
--- a/klm/lm/bhiksha.cc
+++ b/klm/lm/bhiksha.cc
@@ -1,4 +1,6 @@
#include "lm/bhiksha.hh"
+
+#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "util/file.hh"
#include "util/exception.hh"
@@ -15,11 +17,11 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_
const uint8_t kArrayBhikshaVersion = 0;
// TODO: put this in binary file header instead when I change the binary file format again.
-void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) {
- uint8_t version;
- uint8_t configured_bits;
- util::ReadOrThrow(fd, &version, 1);
- util::ReadOrThrow(fd, &configured_bits, 1);
+void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
+ uint8_t buffer[2];
+ file.ReadForConfig(buffer, 2, offset);
+ uint8_t version = buffer[0];
+ uint8_t configured_bits = buffer[1];
if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);
config.pointer_bhiksha_bits = configured_bits;
}
@@ -87,9 +89,6 @@ void ArrayBhiksha::FinishedLoading(const Config &config) {
*(head_write++) = config.pointer_bhiksha_bits;
}
-void ArrayBhiksha::LoadedBinary() {
-}
-
} // namespace trie
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
index 8ff88654..350571a6 100644
--- a/klm/lm/bhiksha.hh
+++ b/klm/lm/bhiksha.hh
@@ -24,6 +24,7 @@
namespace lm {
namespace ngram {
struct Config;
+class BinaryFormat;
namespace trie {
@@ -31,7 +32,7 @@ class DontBhiksha {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
- static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {}
static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
@@ -53,8 +54,6 @@ class DontBhiksha {
void FinishedLoading(const Config &/*config*/) {}
- void LoadedBinary() {}
-
uint8_t InlineBits() const { return next_.bits; }
private:
@@ -65,7 +64,7 @@ class ArrayBhiksha {
public:
static const ModelType kModelTypeAdd = kArrayAdd;
- static void UpdateConfigFromBinary(int fd, Config &config);
+ static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
@@ -93,8 +92,6 @@ class ArrayBhiksha {
void FinishedLoading(const Config &config);
- void LoadedBinary();
-
uint8_t InlineBits() const { return next_inline_.bits; }
private:
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 39c4a9b6..9c744b13 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -8,11 +8,15 @@
#include <cstring>
#include <limits>
#include <string>
+#include <cstdlib>
#include <stdint.h>
namespace lm {
namespace ngram {
+
+const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
+
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
@@ -57,8 +61,6 @@ struct Sanity {
}
};
-const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
-
std::size_t TotalHeaderSize(unsigned char order) {
return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
@@ -80,83 +82,6 @@ void WriteHeader(void *to, const Parameters &params) {
} // namespace
-uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
- if (config.write_mmap) {
- std::size_t total = TotalHeaderSize(order) + memory_size;
- backing.file.reset(util::CreateOrThrow(config.write_mmap));
- if (config.write_method == Config::WRITE_MMAP) {
- backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
- } else {
- util::ResizeOrThrow(backing.file.get(), 0);
- util::MapAnonymous(total, backing.vocab);
- }
- strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
- return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
- } else {
- util::MapAnonymous(memory_size, backing.vocab);
- return reinterpret_cast<uint8_t*>(backing.vocab.get());
- }
-}
-
-uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
- std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
- if (config.write_mmap) {
- // Grow the file to accomodate the search, using zeros.
- try {
- util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
- } catch (util::ErrnoException &e) {
- e << " for file " << config.write_mmap;
- throw e;
- }
-
- if (config.write_method == Config::WRITE_AFTER) {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
- // mmap it now.
- // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
- std::size_t page_size = util::SizePage();
- std::size_t alignment_cruft = adjusted_vocab % page_size;
- backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
- } else {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
-}
-
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
- if (!config.write_mmap) return;
- switch (config.write_method) {
- case Config::WRITE_MMAP:
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
- util::SyncOrThrow(backing.search.get(), backing.search.size());
- break;
- case Config::WRITE_AFTER:
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
- util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
- util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
- util::FSyncOrThrow(backing.file.get());
- break;
- }
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params = Parameters();
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- params.fixed.search_version = search_version;
- WriteHeader(backing.vocab.get(), params);
- if (config.write_method == Config::WRITE_AFTER) {
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
- }
-}
-
-namespace detail {
-
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
@@ -169,21 +94,21 @@ bool IsBinaryFormat(int fd) {
}
Sanity reference_header = Sanity();
reference_header.SetToReference();
- if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
- if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
+ if (!std::memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
+ if (!std::memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
UTIL_THROW(FormatLoadException, "This binary file did not finish building");
}
- if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
+ if (!std::memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
char *end_ptr;
const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
- long int version = strtol(begin_version, &end_ptr, 10);
+ long int version = std::strtol(begin_version, &end_ptr, 10);
if ((end_ptr != begin_version) && version != kMagicVersion) {
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
OldSanity old_sanity = OldSanity();
old_sanity.SetToReference();
- UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
+ UTIL_THROW_IF(!std::memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;
@@ -208,44 +133,164 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);
}
-void SeekPastHeader(int fd, const Parameters &params) {
- util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
+const std::size_t kInvalidSize = static_cast<std::size_t>(-1);
+
+BinaryFormat::BinaryFormat(const Config &config)
+ : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method),
+ header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {}
+
+void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params) {
+ file_.reset(fd);
+ write_mmap_ = NULL; // Ignore write requests; this is already in binary format.
+ ReadHeader(fd, params);
+ MatchCheck(model_type, search_version, params);
+ header_size_ = TotalHeaderSize(params.counts.size());
+}
+
+void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const {
+ assert(header_size_ != kInvalidSize);
+ util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_);
}
-uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
- const uint64_t file_size = util::SizeFile(backing.file.get());
+void *BinaryFormat::LoadBinary(std::size_t size) {
+ assert(header_size_ != kInvalidSize);
+ const uint64_t file_size = util::SizeFile(file_.get());
// The header is smaller than a page, so we have to map the whole header as well.
- std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
- if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
- UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
+ uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size);
+ UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
- util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search);
+ util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_);
- if (config.enumerate_vocab && !params.fixed.has_vocabulary)
- UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+ vocab_string_offset_ = total_map;
+ return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
+}
+
+void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
+ vocab_size_ = memory_size;
+ if (!write_mmap_) {
+ header_size_ = 0;
+ util::MapAnonymous(memory_size, memory_vocab_);
+ return reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ }
+ header_size_ = TotalHeaderSize(order);
+ std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size));
+ file_.reset(util::CreateOrThrow(write_mmap_));
+ // some gccs complain about uninitialized variables even though all enum values are covered.
+ void *vocab_base = NULL;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = mapping_.get();
+ break;
+ case Config::WRITE_AFTER:
+ util::ResizeOrThrow(file_.get(), 0);
+ util::MapAnonymous(total, memory_vocab_);
+ vocab_base = memory_vocab_.get();
+ break;
+ }
+ strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_);
+ return reinterpret_cast<uint8_t*>(vocab_base) + header_size_;
+}
- // Seek to vocabulary words
- util::SeekOrThrow(backing.file.get(), total_map);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
+void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) {
+ assert(vocab_size_ != kInvalidSize);
+ vocab_pad_ = vocab_pad;
+ std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size;
+ vocab_string_offset_ = new_size;
+ if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) {
+ util::MapAnonymous(memory_size, memory_search_);
+ assert(header_size_ == 0 || write_mmap_);
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ return reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+
+ assert(write_method_ == Config::WRITE_MMAP);
+ // Also known as total size without vocab words.
+ // Grow the file to accomodate the search, using zeros.
+ // According to man mmap, behavior is undefined when the file is resized
+ // underneath a mmap that is not a multiple of the page size. So to be
+ // safe, we'll unmap it and map it again.
+ mapping_.reset();
+ util::ResizeOrThrow(file_.get(), new_size);
+ void *ret;
+ MapFile(vocab_base, ret);
+ return ret;
}
-void ComplainAboutARPA(const Config &config, ModelType model_type) {
- if (config.write_mmap || !config.messages) return;
- if (config.arpa_complain == Config::ALL) {
- *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
- } else if (config.arpa_complain == Config::EXPENSIVE &&
- (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
- *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) {
+ // Checking Config's include_vocab is the responsibility of the caller.
+ assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize);
+ if (!write_mmap_) {
+ // Unchanged base.
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ return;
+ }
+ if (write_method_ == Config::WRITE_MMAP) {
+ mapping_.reset();
+ }
+ util::SeekOrThrow(file_.get(), VocabStringReadingOffset());
+ util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+ if (write_method_ == Config::WRITE_MMAP) {
+ MapFile(vocab_base, search_base);
+ } else {
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+}
+
+void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) {
+ if (!write_mmap_) return;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ util::SyncOrThrow(mapping_.get(), mapping_.size());
+ break;
+ case Config::WRITE_AFTER:
+ util::SeekOrThrow(file_.get(), 0);
+ util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size());
+ util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_);
+ util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size());
+ util::FSyncOrThrow(file_.get());
+ break;
+ }
+ // header and vocab share the same mmap.
+ Parameters params = Parameters();
+ memset(&params, 0, sizeof(Parameters));
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ params.fixed.search_version = search_version;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ WriteHeader(mapping_.get(), params);
+ util::SyncOrThrow(mapping_.get(), mapping_.size());
+ break;
+ case Config::WRITE_AFTER:
+ {
+ std::vector<uint8_t> buffer(TotalHeaderSize(counts.size()));
+ WriteHeader(&buffer[0], params);
+ util::SeekOrThrow(file_.get(), 0);
+ util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+ }
+ break;
}
}
-} // namespace detail
+void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) {
+ mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
+ search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_;
+}
bool RecognizeBinary(const char *file, ModelType &recognized) {
util::scoped_fd fd(util::OpenReadOrThrow(file));
- if (!detail::IsBinaryFormat(fd.get())) return false;
+ if (!IsBinaryFormat(fd.get())) {
+ return false;
+ }
Parameters params;
- detail::ReadHeader(fd.get(), params);
+ ReadHeader(fd.get(), params);
recognized = params.fixed.model_type;
return true;
}
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index bf699d5f..f33f88d7 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -17,6 +17,8 @@
namespace lm {
namespace ngram {
+extern const char *kModelNames[6];
+
/*Inspect a file to determine if it is a binary lm. If not, return false.
* If so, return true and set recognized to the type. This is the only API in
* this header designed for use by decoder authors.
@@ -42,67 +44,63 @@ struct Parameters {
std::vector<uint64_t> counts;
};
-struct Backing {
- // File behind memory, if any.
- util::scoped_fd file;
- // Vocabulary lookup table. Not to be confused with the vocab words themselves.
- util::scoped_memory vocab;
- // Raw block of memory backing the language model data structures
- util::scoped_memory search;
-};
-
-// Create just enough of a binary file to write vocabulary to it.
-uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
-// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
-uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing);
-
-// Write header to binary file. This is done last to prevent incomplete files
-// from loading.
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing);
+class BinaryFormat {
+ public:
+ explicit BinaryFormat(const Config &config);
+
+ // Reading a binary file:
+ // Takes ownership of fd
+ void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params);
+ // Used to read parts of the file to update the config object before figuring out full size.
+ void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const;
+ // Actually load the binary file and return a pointer to the beginning of the search area.
+ void *LoadBinary(std::size_t size);
+
+ uint64_t VocabStringReadingOffset() const {
+ assert(vocab_string_offset_ != kInvalidOffset);
+ return vocab_string_offset_;
+ }
-namespace detail {
+ // Writing a binary file or initializing in RAM from ARPA:
+ // Size for vocabulary.
+ void *SetupJustVocab(std::size_t memory_size, uint8_t order);
+ // Warning: can change the vocaulary base pointer.
+ void *GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base);
+ // Warning: can change vocabulary and search base addresses.
+ void WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base);
+ // Write the header at the beginning of the file.
+ void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts);
+
+ private:
+ void MapFile(void *&vocab_base, void *&search_base);
+
+ // Copied from configuration.
+ const Config::WriteMethod write_method_;
+ const char *write_mmap_;
+ util::LoadMethod load_method_;
+
+ // File behind memory, if any.
+ util::scoped_fd file_;
+
+ // If there is a file involved, a single mapping.
+ util::scoped_memory mapping_;
+
+ // If the data is only in memory, separately allocate each because the trie
+ // knows vocab's size before it knows search's size (because SRILM might
+ // have pruned).
+ util::scoped_memory memory_vocab_, memory_search_;
+
+ // Memory ranges. Note that these may not be contiguous and may not all
+ // exist.
+ std::size_t header_size_, vocab_size_, vocab_pad_;
+ // aka end of search.
+ uint64_t vocab_string_offset_;
+
+ static const uint64_t kInvalidOffset = (uint64_t)-1;
+};
bool IsBinaryFormat(int fd);
-void ReadHeader(int fd, Parameters &params);
-
-void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params);
-
-void SeekPastHeader(int fd, const Parameters &params);
-
-uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing);
-
-void ComplainAboutARPA(const Config &config, ModelType model_type);
-
-} // namespace detail
-
-template <class To> void LoadLM(const char *file, const Config &config, To &to) {
- Backing &backing = to.MutableBacking();
- backing.file.reset(util::OpenReadOrThrow(file));
-
- try {
- if (detail::IsBinaryFormat(backing.file.get())) {
- Parameters params;
- detail::ReadHeader(backing.file.get(), params);
- detail::MatchCheck(To::kModelType, To::kVersion, params);
- // Replace the run-time configured probing_multiplier with the one in the file.
- Config new_config(config);
- new_config.probing_multiplier = params.fixed.probing_multiplier;
- detail::SeekPastHeader(backing.file.get(), params);
- To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config);
- uint64_t memory_size = To::Size(params.counts, new_config);
- uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
- to.InitializeFromBinary(start, params, new_config, backing.file.get());
- } else {
- detail::ComplainAboutARPA(config, To::kModelType);
- to.InitializeFromARPA(file, config);
- }
- } catch (util::Exception &e) {
- e << " File: " << file;
- throw;
- }
-}
-
} // namespace ngram
} // namespace lm
#endif // LM_BINARY_FORMAT__
diff --git a/klm/lm/build_binary_main.cc b/klm/lm/build_binary_main.cc
index ab2c0c32..15b421e9 100644
--- a/klm/lm/build_binary_main.cc
+++ b/klm/lm/build_binary_main.cc
@@ -52,6 +52,7 @@ void Usage(const char *name, const char *default_mem) {
"-a compresses pointers using an array of offsets. The parameter is the\n"
" maximum number of bits encoded by the array. Memory is minimized subject\n"
" to the maximum, so pick 255 to minimize memory.\n\n"
+"-h print this help message.\n\n"
"Get a memory estimate by passing an ARPA file without an output file name.\n";
exit(1);
}
@@ -104,12 +105,15 @@ int main(int argc, char *argv[]) {
const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G";
+ if (argc == 2 && !strcmp(argv[1], "--help"))
+ Usage(argv[0], default_mem);
+
try {
bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;
lm::ngram::Config config;
config.building_memory = util::ParseSize(default_mem);
int opt;
- while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) {
+ while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:h")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -161,6 +165,7 @@ int main(int argc, char *argv[]) {
ParseFileList(optarg, config.rest_lower_files);
config.rest_function = Config::REST_LOWER;
break;
+ case 'h': // help
default:
Usage(argv[0], default_mem);
}
@@ -186,6 +191,7 @@ int main(int argc, char *argv[]) {
config.write_mmap = argv[optind + 2];
} else {
Usage(argv[0], default_mem);
+ return 1;
}
if (!strcmp(model_type, "probing")) {
if (!set_write_method) config.write_method = Config::WRITE_AFTER;
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
index aea93ad1..ccc06efc 100644
--- a/klm/lm/builder/corpus_count.cc
+++ b/klm/lm/builder/corpus_count.cc
@@ -238,12 +238,17 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
const WordIndex end_sentence = vocab.Lookup("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
- StringPiece delimiters("\0\t\r ", 4);
+ bool delimiters[256];
+ memset(delimiters, 0, sizeof(delimiters));
+ const char kDelimiterSet[] = "\0\t\n\r ";
+ for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) {
+ delimiters[static_cast<unsigned char>(*i)] = true;
+ }
try {
while(true) {
StringPiece line(from_.ReadLine());
writer.StartSentence();
- for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) {
+ for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
WordIndex word = vocab.Lookup(*w);
UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
writer.Append(word);
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
index c87abdb8..2563deed 100644
--- a/klm/lm/builder/lmplz_main.cc
+++ b/klm/lm/builder/lmplz_main.cc
@@ -33,7 +33,10 @@ int main(int argc, char *argv[]) {
po::options_description options("Language model building options");
lm::builder::PipelineConfig pipeline;
+ std::string text, arpa;
+
options.add_options()
+ ("help", po::bool_switch(), "Show this help message")
("order,o", po::value<std::size_t>(&pipeline.order)
#if BOOST_VERSION >= 104200
->required()
@@ -47,8 +50,13 @@ int main(int argc, char *argv[]) {
("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
- ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
- if (argc == 1) {
+ ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
+ ("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
+ ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout");
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, options), vm);
+
+ if (argc == 1 || vm["help"].as<bool>()) {
std::cerr <<
"Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
"Please cite:\n"
@@ -66,12 +74,17 @@ int main(int argc, char *argv[]) {
"setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
"Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
"Valid units are \% for percentage of memory (supported platforms only) and (in\n"
- "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n";
+ "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n";
+ uint64_t mem = util::GuessPhysicalMemory();
+ if (mem) {
+ std::cerr << "This machine has " << mem << " bytes of memory.\n\n";
+ } else {
+ std::cerr << "Unable to determine the amount of memory on this machine.\n\n";
+ }
std::cerr << options << std::endl;
return 1;
}
- po::variables_map vm;
- po::store(po::parse_command_line(argc, argv, options), vm);
+
po::notify(vm);
// required() appeared in Boost 1.42.0.
@@ -92,9 +105,17 @@ int main(int argc, char *argv[]) {
initial.adder_out.block_count = 2;
pipeline.read_backoffs = initial.adder_out;
+ util::scoped_fd in(0), out(1);
+ if (vm.count("text")) {
+ in.reset(util::OpenReadOrThrow(text.c_str()));
+ }
+ if (vm.count("arpa")) {
+ out.reset(util::CreateOrThrow(arpa.c_str()));
+ }
+
// Read from stdin
try {
- lm::builder::Pipeline(pipeline, 0, 1);
+ lm::builder::Pipeline(pipeline, in.release(), out.release());
} catch (const util::MallocException &e) {
std::cerr << e.what() << std::endl;
std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc
index b89ea6ba..44a2313c 100644
--- a/klm/lm/builder/pipeline.cc
+++ b/klm/lm/builder/pipeline.cc
@@ -226,6 +226,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
chain.Wait(true);
+ std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;
std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
master.InitForAdjust(sorter, type_count);
}
diff --git a/klm/lm/facade.hh b/klm/lm/facade.hh
index 8b186017..de1551f1 100644
--- a/klm/lm/facade.hh
+++ b/klm/lm/facade.hh
@@ -16,19 +16,28 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
typedef StateT State;
typedef VocabularyT Vocabulary;
- // Default Score function calls FullScore. Model can override this.
- float Score(const State &in_state, const WordIndex new_word, State &out_state) const {
- return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob;
- }
-
/* Translate from void* to State */
- FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const {
+ FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->FullScore(
*reinterpret_cast<const State*>(in_state),
new_word,
*reinterpret_cast<State*>(out_state));
}
- float Score(const void *in_state, const WordIndex new_word, void *out_state) const {
+
+ FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const {
+ return static_cast<const Child*>(this)->FullScoreForgotState(
+ context_rbegin,
+ context_rend,
+ new_word,
+ *reinterpret_cast<State*>(out_state));
+ }
+
+ // Default Score function calls FullScore. Model can override this.
+ float Score(const State &in_state, const WordIndex new_word, State &out_state) const {
+ return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob;
+ }
+
+ float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->Score(
*reinterpret_cast<const State*>(in_state),
new_word,
diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh
index 5b31620b..602b5b31 100644
--- a/klm/lm/filter/arpa_io.hh
+++ b/klm/lm/filter/arpa_io.hh
@@ -14,7 +14,6 @@
#include <string>
#include <vector>
-#include <err.h>
#include <string.h>
#include <stdint.h>
diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh
index 97c0fa25..d992026f 100644
--- a/klm/lm/filter/count_io.hh
+++ b/klm/lm/filter/count_io.hh
@@ -5,20 +5,18 @@
#include <iostream>
#include <string>
-#include <err.h>
-
+#include "util/fake_ofstream.hh"
+#include "util/file.hh"
#include "util/file_piece.hh"
namespace lm {
class CountOutput : boost::noncopyable {
public:
- explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
+ explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {}
void AddNGram(const StringPiece &line) {
- if (!(file_ << line << '\n')) {
- err(3, "Writing counts file failed");
- }
+ file_ << line << '\n';
}
template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
@@ -30,7 +28,7 @@ class CountOutput : boost::noncopyable {
}
private:
- std::fstream file_;
+ util::FakeOFStream file_;
};
class CountBatch {
diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc
index 1736bc40..82fdc1ef 100644
--- a/klm/lm/filter/filter_main.cc
+++ b/klm/lm/filter/filter_main.cc
@@ -6,6 +6,7 @@
#endif
#include "lm/filter/vocab.hh"
#include "lm/filter/wrapper.hh"
+#include "util/exception.hh"
#include "util/file_piece.hh"
#include <boost/ptr_container/ptr_vector.hpp>
@@ -157,92 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr
} // namespace lm
int main(int argc, char *argv[]) {
- if (argc < 4) {
- lm::DisplayHelp(argv[0]);
- return 1;
- }
+ try {
+ if (argc < 4) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
- // I used to have boost::program_options, but some users didn't want to compile boost.
- lm::Config config;
- config.mode = lm::MODE_UNSET;
- for (int i = 1; i < argc - 2; ++i) {
- const char *str = argv[i];
- if (!std::strcmp(str, "copy")) {
- config.mode = lm::MODE_COPY;
- } else if (!std::strcmp(str, "single")) {
- config.mode = lm::MODE_SINGLE;
- } else if (!std::strcmp(str, "multiple")) {
- config.mode = lm::MODE_MULTIPLE;
- } else if (!std::strcmp(str, "union")) {
- config.mode = lm::MODE_UNION;
- } else if (!std::strcmp(str, "phrase")) {
- config.phrase = true;
- } else if (!std::strcmp(str, "context")) {
- config.context = true;
- } else if (!std::strcmp(str, "arpa")) {
- config.format = lm::FORMAT_ARPA;
- } else if (!std::strcmp(str, "raw")) {
- config.format = lm::FORMAT_COUNT;
+ // I used to have boost::program_options, but some users didn't want to compile boost.
+ lm::Config config;
+ config.mode = lm::MODE_UNSET;
+ for (int i = 1; i < argc - 2; ++i) {
+ const char *str = argv[i];
+ if (!std::strcmp(str, "copy")) {
+ config.mode = lm::MODE_COPY;
+ } else if (!std::strcmp(str, "single")) {
+ config.mode = lm::MODE_SINGLE;
+ } else if (!std::strcmp(str, "multiple")) {
+ config.mode = lm::MODE_MULTIPLE;
+ } else if (!std::strcmp(str, "union")) {
+ config.mode = lm::MODE_UNION;
+ } else if (!std::strcmp(str, "phrase")) {
+ config.phrase = true;
+ } else if (!std::strcmp(str, "context")) {
+ config.context = true;
+ } else if (!std::strcmp(str, "arpa")) {
+ config.format = lm::FORMAT_ARPA;
+ } else if (!std::strcmp(str, "raw")) {
+ config.format = lm::FORMAT_COUNT;
#ifndef NTHREAD
- } else if (!std::strncmp(str, "threads:", 8)) {
- config.threads = boost::lexical_cast<size_t>(str + 8);
- if (!config.threads) {
- std::cerr << "Specify at least one thread." << std::endl;
+ } else if (!std::strncmp(str, "threads:", 8)) {
+ config.threads = boost::lexical_cast<size_t>(str + 8);
+ if (!config.threads) {
+ std::cerr << "Specify at least one thread." << std::endl;
+ return 1;
+ }
+ } else if (!std::strncmp(str, "batch_size:", 11)) {
+ config.batch_size = boost::lexical_cast<size_t>(str + 11);
+ if (config.batch_size < 5000) {
+ std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
+ if (!config.batch_size) return 1;
+ }
+#endif
+ } else {
+ lm::DisplayHelp(argv[0]);
return 1;
}
- } else if (!std::strncmp(str, "batch_size:", 11)) {
- config.batch_size = boost::lexical_cast<size_t>(str + 11);
- if (config.batch_size < 5000) {
- std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
- if (!config.batch_size) return 1;
- }
-#endif
- } else {
+ }
+
+ if (config.mode == lm::MODE_UNSET) {
lm::DisplayHelp(argv[0]);
return 1;
}
- }
-
- if (config.mode == lm::MODE_UNSET) {
- lm::DisplayHelp(argv[0]);
- return 1;
- }
- if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
- std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
- return 1;
- }
+ if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
+ std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
+ return 1;
+ }
- bool cmd_is_model = true;
- const char *cmd_input = argv[argc - 2];
- if (!strncmp(cmd_input, "vocab:", 6)) {
- cmd_is_model = false;
- cmd_input += 6;
- } else if (!strncmp(cmd_input, "model:", 6)) {
- cmd_input += 6;
- } else if (strchr(cmd_input, ':')) {
- errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input);
- } else {
- std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
- }
- std::ifstream cmd_file;
- std::istream *vocab;
- if (cmd_is_model) {
- vocab = &std::cin;
- } else {
- cmd_file.open(cmd_input, std::ios::in);
- if (!cmd_file) {
- err(2, "Could not open input file %s", cmd_input);
+ bool cmd_is_model = true;
+ const char *cmd_input = argv[argc - 2];
+ if (!strncmp(cmd_input, "vocab:", 6)) {
+ cmd_is_model = false;
+ cmd_input += 6;
+ } else if (!strncmp(cmd_input, "model:", 6)) {
+ cmd_input += 6;
+ } else if (strchr(cmd_input, ':')) {
+ std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl;
+ return 1;
+ } else {
+ std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
+ }
+ std::ifstream cmd_file;
+ std::istream *vocab;
+ if (cmd_is_model) {
+ vocab = &std::cin;
+ } else {
+ cmd_file.open(cmd_input, std::ios::in);
+ UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input);
+ vocab = &cmd_file;
}
- vocab = &cmd_file;
- }
- util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
+ util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
- if (config.format == lm::FORMAT_ARPA) {
- lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
- } else if (config.format == lm::FORMAT_COUNT) {
- lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ if (config.format == lm::FORMAT_ARPA) {
+ lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
+ } else if (config.format == lm::FORMAT_COUNT) {
+ lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ }
+ return 0;
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
}
- return 0;
}
diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh
index 7f945b0d..7d8c28db 100644
--- a/klm/lm/filter/format.hh
+++ b/klm/lm/filter/format.hh
@@ -1,5 +1,5 @@
#ifndef LM_FILTER_FORMAT_H__
-#define LM_FITLER_FORMAT_H__
+#define LM_FILTER_FORMAT_H__
#include "lm/filter/arpa_io.hh"
#include "lm/filter/count_io.hh"
diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc
index 1bef2a3f..e2946b14 100644
--- a/klm/lm/filter/phrase.cc
+++ b/klm/lm/filter/phrase.cc
@@ -48,21 +48,21 @@ unsigned int ReadMultiple(std::istream &in, Substrings &out) {
return sentence_id + sentence_content;
}
-namespace detail { const StringPiece kEndSentence("</s>"); }
-
namespace {
-
typedef unsigned int Sentence;
typedef std::vector<Sentence> Sentences;
+} // namespace
-class Vertex;
+namespace detail {
+
+const StringPiece kEndSentence("</s>");
class Arc {
public:
Arc() {}
// For arcs from one vertex to another.
- void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) {
+ void SetPhrase(detail::Vertex &from, detail::Vertex &to, const Sentences &intersect) {
Set(to, intersect);
from_ = &from;
}
@@ -71,7 +71,7 @@ class Arc {
* aligned). These have no from_ vertex; it implictly matches every
* sentence. This also handles when the n-gram is a substring of a phrase.
*/
- void SetRight(Vertex &to, const Sentences &complete) {
+ void SetRight(detail::Vertex &to, const Sentences &complete) {
Set(to, complete);
from_ = NULL;
}
@@ -97,11 +97,11 @@ class Arc {
void LowerBound(const Sentence to);
private:
- void Set(Vertex &to, const Sentences &sentences);
+ void Set(detail::Vertex &to, const Sentences &sentences);
const Sentence *current_;
const Sentence *last_;
- Vertex *from_;
+ detail::Vertex *from_;
};
struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> {
@@ -183,7 +183,13 @@ void Vertex::LowerBound(const Sentence to) {
}
}
-void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) {
+} // namespace detail
+
+namespace {
+
+void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, detail::Vertex *const vertices, detail::Arc *free_arc) {
+ using detail::Vertex;
+ using detail::Arc;
assert(!hashes.empty());
const Hash *const first_word = &*hashes.begin();
@@ -231,17 +237,29 @@ void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Verte
namespace detail {
-} // namespace detail
+// Here instead of header due to forward declaration.
+ConditionCommon::ConditionCommon(const Substrings &substrings) : substrings_(substrings) {}
-bool Union::Evaluate() {
+// Rest of the variables are temporaries anyway
+ConditionCommon::ConditionCommon(const ConditionCommon &from) : substrings_(from.substrings_) {}
+
+ConditionCommon::~ConditionCommon() {}
+
+detail::Vertex &ConditionCommon::MakeGraph() {
assert(!hashes_.empty());
- // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
- Vertex vertices[hashes_.size()];
+ vertices_.clear();
+ vertices_.resize(hashes_.size());
+ arcs_.clear();
// One for every substring.
- Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
- BuildGraph(substrings_, hashes_, vertices, arcs);
- Vertex &last_vertex = vertices[hashes_.size() - 1];
+ arcs_.resize(((hashes_.size() + 1) * hashes_.size()) / 2);
+ BuildGraph(substrings_, hashes_, &*vertices_.begin(), &*arcs_.begin());
+ return vertices_[hashes_.size() - 1];
+}
+
+} // namespace detail
+bool Union::Evaluate() {
+ detail::Vertex &last_vertex = MakeGraph();
unsigned int lower = 0;
while (true) {
last_vertex.LowerBound(lower);
@@ -252,14 +270,7 @@ bool Union::Evaluate() {
}
template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) {
- assert(!hashes_.empty());
- // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
- Vertex vertices[hashes_.size()];
- // One for every substring.
- Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
- BuildGraph(substrings_, hashes_, vertices, arcs);
- Vertex &last_vertex = vertices[hashes_.size() - 1];
-
+ detail::Vertex &last_vertex = MakeGraph();
unsigned int lower = 0;
while (true) {
last_vertex.LowerBound(lower);
diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh
index b4edff41..e8e85835 100644
--- a/klm/lm/filter/phrase.hh
+++ b/klm/lm/filter/phrase.hh
@@ -103,11 +103,33 @@ template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::
}
}
+class Vertex;
+class Arc;
+
+class ConditionCommon {
+ protected:
+ ConditionCommon(const Substrings &substrings);
+ ConditionCommon(const ConditionCommon &from);
+
+ ~ConditionCommon();
+
+ detail::Vertex &MakeGraph();
+
+ // Temporaries in PassNGram and Evaluate to avoid reallocation.
+ std::vector<Hash> hashes_;
+
+ private:
+ std::vector<detail::Vertex> vertices_;
+ std::vector<detail::Arc> arcs_;
+
+ const Substrings &substrings_;
+};
+
} // namespace detail
-class Union {
+class Union : public detail::ConditionCommon {
public:
- explicit Union(const Substrings &substrings) : substrings_(substrings) {}
+ explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
detail::MakeHashes(begin, end, hashes_);
@@ -116,23 +138,19 @@ class Union {
private:
bool Evaluate();
-
- std::vector<Hash> hashes_;
-
- const Substrings &substrings_;
};
-class Multiple {
+class Multiple : public detail::ConditionCommon {
public:
- explicit Multiple(const Substrings &substrings) : substrings_(substrings) {}
+ explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
detail::MakeHashes(begin, end, hashes_);
if (hashes_.empty()) {
output.AddNGram(line);
- return;
+ } else {
+ Evaluate(line, output);
}
- Evaluate(line, output);
}
template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
@@ -143,10 +161,6 @@ class Multiple {
private:
template <class Output> void Evaluate(const StringPiece &line, Output &output);
-
- std::vector<Hash> hashes_;
-
- const Substrings &substrings_;
};
} // namespace phrase
diff --git a/klm/lm/filter/phrase_table_vocab_main.cc b/klm/lm/filter/phrase_table_vocab_main.cc
new file mode 100644
index 00000000..e0f47d89
--- /dev/null
+++ b/klm/lm/filter/phrase_table_vocab_main.cc
@@ -0,0 +1,165 @@
+#include "util/fake_ofstream.hh"
+#include "util/file_piece.hh"
+#include "util/murmur_hash.hh"
+#include "util/pool.hh"
+#include "util/string_piece.hh"
+#include "util/string_piece_hash.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_map.hpp>
+#include <boost/unordered_set.hpp>
+
+#include <cstddef>
+#include <vector>
+
+namespace {
+
+struct MutablePiece {
+ mutable StringPiece behind;
+ bool operator==(const MutablePiece &other) const {
+ return behind == other.behind;
+ }
+};
+
+std::size_t hash_value(const MutablePiece &m) {
+ return hash_value(m.behind);
+}
+
+class InternString {
+ public:
+ const char *Add(StringPiece str) {
+ MutablePiece mut;
+ mut.behind = str;
+ std::pair<boost::unordered_set<MutablePiece>::iterator, bool> res(strs_.insert(mut));
+ if (res.second) {
+ void *mem = backing_.Allocate(str.size() + 1);
+ memcpy(mem, str.data(), str.size());
+ static_cast<char*>(mem)[str.size()] = 0;
+ res.first->behind = StringPiece(static_cast<char*>(mem), str.size());
+ }
+ return res.first->behind.data();
+ }
+
+ private:
+ util::Pool backing_;
+ boost::unordered_set<MutablePiece> strs_;
+};
+
+class TargetWords {
+ public:
+ void Introduce(StringPiece source) {
+ vocab_.resize(vocab_.size() + 1);
+ std::vector<unsigned int> temp(1, vocab_.size() - 1);
+ Add(temp, source);
+ }
+
+ void Add(const std::vector<unsigned int> &sentences, StringPiece target) {
+ if (sentences.empty()) return;
+ interns_.clear();
+ for (util::TokenIter<util::SingleCharacter, true> i(target, ' '); i; ++i) {
+ interns_.push_back(intern_.Add(*i));
+ }
+ for (std::vector<unsigned int>::const_iterator i(sentences.begin()); i != sentences.end(); ++i) {
+ boost::unordered_set<const char *> &vocab = vocab_[*i];
+ for (std::vector<const char *>::const_iterator j = interns_.begin(); j != interns_.end(); ++j) {
+ vocab.insert(*j);
+ }
+ }
+ }
+
+ void Print() const {
+ util::FakeOFStream out(1);
+ for (std::vector<boost::unordered_set<const char *> >::const_iterator i = vocab_.begin(); i != vocab_.end(); ++i) {
+ for (boost::unordered_set<const char *>::const_iterator j = i->begin(); j != i->end(); ++j) {
+ out << *j << ' ';
+ }
+ out << '\n';
+ }
+ }
+
+ private:
+ InternString intern_;
+
+ std::vector<boost::unordered_set<const char *> > vocab_;
+
+ // Temporary in Add.
+ std::vector<const char *> interns_;
+};
+
+class Input {
+ public:
+ explicit Input(std::size_t max_length)
+ : max_length_(max_length), sentence_id_(0), empty_() {}
+
+ void AddSentence(StringPiece sentence, TargetWords &targets) {
+ canonical_.clear();
+ starts_.clear();
+ starts_.push_back(0);
+ for (util::TokenIter<util::AnyCharacter, true> i(sentence, StringPiece("\0 \t", 3)); i; ++i) {
+ canonical_.append(i->data(), i->size());
+ canonical_ += ' ';
+ starts_.push_back(canonical_.size());
+ }
+ targets.Introduce(canonical_);
+ for (std::size_t i = 0; i < starts_.size() - 1; ++i) {
+ std::size_t subtract = starts_[i];
+ const char *start = &canonical_[subtract];
+ for (std::size_t j = i + 1; j < std::min(starts_.size(), i + max_length_ + 1); ++j) {
+ map_[util::MurmurHash64A(start, &canonical_[starts_[j]] - start - 1)].push_back(sentence_id_);
+ }
+ }
+ ++sentence_id_;
+ }
+
+ // Assumes single space-delimited phrase with no space at the beginning or end.
+ const std::vector<unsigned int> &Matches(StringPiece phrase) const {
+ Map::const_iterator i = map_.find(util::MurmurHash64A(phrase.data(), phrase.size()));
+ return i == map_.end() ? empty_ : i->second;
+ }
+
+ private:
+ const std::size_t max_length_;
+
+ // hash of phrase is the key, array of sentences is the value.
+ typedef boost::unordered_map<uint64_t, std::vector<unsigned int> > Map;
+ Map map_;
+
+ std::size_t sentence_id_;
+
+ // Temporaries in AddSentence.
+ std::string canonical_;
+ std::vector<std::size_t> starts_;
+
+ const std::vector<unsigned int> empty_;
+};
+
+} // namespace
+
+int main(int argc, char *argv[]) {
+ if (argc != 2) {
+ std::cerr << "Expected source text on the command line" << std::endl;
+ return 1;
+ }
+ Input input(7);
+ TargetWords targets;
+ try {
+ util::FilePiece inputs(argv[1], &std::cerr);
+ while (true)
+ input.AddSentence(inputs.ReadLine(), targets);
+ } catch (const util::EndOfFileException &e) {}
+
+ util::FilePiece table(0, NULL, &std::cerr);
+ StringPiece line;
+ const StringPiece pipes("|||");
+ while (true) {
+ try {
+ line = table.ReadLine();
+ } catch (const util::EndOfFileException &e) { break; }
+ util::TokenIter<util::MultiCharacter> it(line, pipes);
+ StringPiece source(*it);
+ if (!source.empty() && source[source.size() - 1] == ' ')
+ source.remove_suffix(1);
+ targets.Add(input.Matches(source), *++it);
+ }
+ targets.Print();
+}
diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc
index 7ee4e84b..011ab599 100644
--- a/klm/lm/filter/vocab.cc
+++ b/klm/lm/filter/vocab.cc
@@ -4,7 +4,6 @@
#include <iostream>
#include <ctype.h>
-#include <err.h>
namespace lm {
namespace vocab {
diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh
index 90b07a08..eb657501 100644
--- a/klm/lm/filter/wrapper.hh
+++ b/klm/lm/filter/wrapper.hh
@@ -39,17 +39,15 @@ template <class FilterT> class ContextFilter {
explicit ContextFilter(Filter &backend) : backend_(backend) {}
template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
- pieces_.clear();
- // TODO: this copy could be avoided by a lookahead iterator.
- std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_));
- backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output);
+ // Find beginning of string or last space.
+ const char *last_space;
+ for (last_space = ngram.data() + ngram.size() - 1; last_space > ngram.data() && *last_space != ' '; --last_space) {}
+ backend_.AddNGram(StringPiece(ngram.data(), last_space - ngram.data()), line, output);
}
void Flush() const {}
private:
- std::vector<StringPiece> pieces_;
-
Filter backend_;
};
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index a26654a6..a5a16bf8 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -34,23 +34,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}
-template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
- LoadLM(file, config, *this);
-
- // g++ prints warnings unless these are fully initialized.
- State begin_sentence = State();
- begin_sentence.length = 1;
- begin_sentence.words[0] = vocab_.BeginSentence();
- typename Search::Node ignored_node;
- bool ignored_independent_left;
- uint64_t ignored_extend_left;
- begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
- State null_context = State();
- null_context.length = 0;
- P::Init(begin_sentence, null_context, vocab_, search_.Order());
+namespace {
+void ComplainAboutARPA(const Config &config, ModelType model_type) {
+ if (config.write_mmap || !config.messages) return;
+ if (config.arpa_complain == Config::ALL) {
+ *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
+ } else if (config.arpa_complain == Config::EXPENSIVE &&
+ (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
+ *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+ }
}
-namespace {
void CheckCounts(const std::vector<uint64_t> &counts) {
UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
if (sizeof(uint64_t) > sizeof(std::size_t)) {
@@ -59,18 +53,45 @@ void CheckCounts(const std::vector<uint64_t> &counts) {
}
}
}
+
} // namespace
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
- CheckCounts(params.counts);
- SetupMemory(start, params.counts, config);
- vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
- search_.LoadedBinary();
+template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) {
+ util::scoped_fd fd(util::OpenReadOrThrow(file));
+ if (IsBinaryFormat(fd.get())) {
+ Parameters parameters;
+ int fd_shallow = fd.release();
+ backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters);
+ CheckCounts(parameters.counts);
+
+ Config new_config(init_config);
+ new_config.probing_multiplier = parameters.fixed.probing_multiplier;
+ Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config);
+ UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+
+ SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config);
+ vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset());
+ } else {
+ ComplainAboutARPA(init_config, kModelType);
+ InitializeFromARPA(fd.release(), file, init_config);
+ }
+
+ // g++ prints warnings unless these are fully initialized.
+ State begin_sentence = State();
+ begin_sentence.length = 1;
+ begin_sentence.words[0] = vocab_.BeginSentence();
+ typename Search::Node ignored_node;
+ bool ignored_independent_left;
+ uint64_t ignored_extend_left;
+ begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
+ State null_context = State();
+ null_context.length = 0;
+ P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
- // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
- util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) {
+ // Backing file is the ARPA.
+ util::FilePiece f(fd, file, config.ProgressMessages());
try {
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
@@ -81,13 +102,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
// Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
- vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
+ vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config);
- if (config.write_mmap) {
+ if (config.write_mmap && config.include_vocab) {
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));
+ void *vocab_rebase, *search_rebase;
+ backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase);
+ // Due to writing at the end of file, mmap may have relocated data. So remap.
+ vocab_.Relocate(vocab_rebase);
+ search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config);
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
@@ -99,18 +124,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.UnknownUnigram().backoff = 0.0;
search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
- FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
+ backing_.FinishFile(config, kModelType, kVersion, counts);
} catch (util::Exception &e) {
e << " Byte: " << f.Offset();
throw;
}
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
- Search::UpdateConfigFromBinary(fd, counts, config);
-}
-
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) {
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 60f55110..e75da93b 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -67,7 +67,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
/* Get the state for a context. Don't use this if you can avoid it. Use
- * BeginSentenceState or EmptyContextState and extend from those. If
+ * BeginSentenceState or NullContextState and extend from those. If
* you're only going to use this state to call FullScore once, use
* FullScoreForgotState.
* To use this function, make an array of WordIndex containing the context
@@ -104,10 +104,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
}
private:
- friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
-
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
-
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
// Score bigrams and above. Do not include backoff.
@@ -116,15 +112,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
// Appears after Size in the cc file.
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
- void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
-
- void InitializeFromARPA(const char *file, const Config &config);
+ void InitializeFromARPA(int fd, const char *file, const Config &config);
float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;
- Backing &MutableBacking() { return backing_; }
-
- Backing backing_;
+ BinaryFormat backing_;
VocabularyT vocab_;
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index eb159094..7005b05e 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -360,10 +360,11 @@ BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
LoadingTest<QuantArrayTrieModel>();
}
-template <class ModelT> void BinaryTest() {
+template <class ModelT> void BinaryTest(Config::WriteMethod write_method) {
Config config;
config.write_mmap = "test.binary";
config.messages = NULL;
+ config.write_method = write_method;
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
@@ -406,6 +407,11 @@ template <class ModelT> void BinaryTest() {
unlink("test_nounk.binary");
}
+template <class ModelT> void BinaryTest() {
+ BinaryTest<ModelT>(Config::WRITE_MMAP);
+ BinaryTest<ModelT>(Config::WRITE_AFTER);
+}
+
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
BinaryTest<ProbingModel>();
}
diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh
index dfcda170..ec2590f4 100644
--- a/klm/lm/ngram_query.hh
+++ b/klm/lm/ngram_query.hh
@@ -11,21 +11,25 @@
#include <istream>
#include <string>
+#include <math.h>
+
namespace lm {
namespace ngram {
template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
- std::cerr << "Loading statistics:\n";
- util::PrintUsage(std::cerr);
typename Model::State state, out;
lm::FullScoreReturn ret;
std::string word;
+ double corpus_total = 0.0;
+ uint64_t corpus_oov = 0;
+ uint64_t corpus_tokens = 0;
+
while (in_stream) {
state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
float total = 0.0;
bool got = false;
- unsigned int oov = 0;
+ uint64_t oov = 0;
while (in_stream >> word) {
got = true;
lm::WordIndex vocab = model.GetVocabulary().Index(word);
@@ -33,6 +37,7 @@ template <class Model> void Query(const Model &model, bool sentence_context, std
ret = model.FullScore(state, vocab, out);
total += ret.prob;
out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
+ ++corpus_tokens;
state = out;
char c;
while (true) {
@@ -50,12 +55,14 @@ template <class Model> void Query(const Model &model, bool sentence_context, std
if (sentence_context) {
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
+ ++corpus_tokens;
out_stream << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
}
out_stream << "Total: " << total << " OOV: " << oov << '\n';
+ corpus_total += total;
+ corpus_oov += oov;
}
- std::cerr << "After queries:\n";
- util::PrintUsage(std::cerr);
+ out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl;
}
template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc
index b58c3f3f..273ea398 100644
--- a/klm/lm/quantize.cc
+++ b/klm/lm/quantize.cc
@@ -38,13 +38,13 @@ const char kSeparatelyQuantizeVersion = 2;
} // namespace
-void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) {
- char version;
- util::ReadOrThrow(fd, &version, 1);
- util::ReadOrThrow(fd, &config.prob_bits, 1);
- util::ReadOrThrow(fd, &config.backoff_bits, 1);
+void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
+ unsigned char buffer[3];
+ file.ReadForConfig(buffer, 3, offset);
+ char version = buffer[0];
+ config.prob_bits = buffer[1];
+ config.backoff_bits = buffer[2];
if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion);
- util::AdvanceOrThrow(fd, -3);
}
void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) {
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh
index 8ce2378a..9d3a2f43 100644
--- a/klm/lm/quantize.hh
+++ b/klm/lm/quantize.hh
@@ -18,12 +18,13 @@ namespace lm {
namespace ngram {
struct Config;
+class BinaryFormat;
/* Store values directly and don't quantize. */
class DontQuantize {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
- static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {}
static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
@@ -136,7 +137,7 @@ class SeparatelyQuantize {
public:
static const ModelType kModelTypeAdd = kQuantAdd;
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
+ static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint8_t order, const Config &config) {
uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
diff --git a/klm/lm/query_main.cc b/klm/lm/query_main.cc
index 27d3a1a5..bd4fde62 100644
--- a/klm/lm/query_main.cc
+++ b/klm/lm/query_main.cc
@@ -1,42 +1,65 @@
#include "lm/ngram_query.hh"
+#ifdef WITH_NPLM
+#include "lm/wrappers/nplm.hh"
+#endif
+
+#include <stdlib.h>
+
+void Usage(const char *name) {
+ std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl;
+ std::cerr << "Usage: " << name << " [-n] lm_file" << std::endl;
+ std::cerr << "Input is wrapped in <s> and </s> unless -n is passed." << std::endl;
+ exit(1);
+}
+
int main(int argc, char *argv[]) {
- if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) {
- std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl;
- std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl;
- std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl;
- return 1;
+ bool sentence_context = true;
+ const char *file = NULL;
+ for (char **arg = argv + 1; arg != argv + argc; ++arg) {
+ if (!strcmp(*arg, "-n")) {
+ sentence_context = false;
+ } else if (!strcmp(*arg, "-h") || !strcmp(*arg, "--help") || file) {
+ Usage(argv[0]);
+ } else {
+ file = *arg;
+ }
}
+ if (!file) Usage(argv[0]);
try {
- bool sentence_context = (argc == 2);
using namespace lm::ngram;
ModelType model_type;
- if (RecognizeBinary(argv[1], model_type)) {
+ if (RecognizeBinary(file, model_type)) {
switch(model_type) {
case PROBING:
- Query<lm::ngram::ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<lm::ngram::ProbingModel>(file, sentence_context, std::cin, std::cout);
break;
case REST_PROBING:
- Query<lm::ngram::RestProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<lm::ngram::RestProbingModel>(file, sentence_context, std::cin, std::cout);
break;
case TRIE:
- Query<TrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<TrieModel>(file, sentence_context, std::cin, std::cout);
break;
case QUANT_TRIE:
- Query<QuantTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<QuantTrieModel>(file, sentence_context, std::cin, std::cout);
break;
case ARRAY_TRIE:
- Query<ArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<ArrayTrieModel>(file, sentence_context, std::cin, std::cout);
break;
case QUANT_ARRAY_TRIE:
- Query<QuantArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<QuantArrayTrieModel>(file, sentence_context, std::cin, std::cout);
break;
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
abort();
}
+#ifdef WITH_NPLM
+ } else if (lm::np::Model::Recognize(file)) {
+ lm::np::Model model(file);
+ Query(model, sentence_context, std::cin, std::cout);
+#endif
} else {
- Query<ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<ProbingModel>(file, sentence_context, std::cin, std::cout);
}
std::cerr << "Total time including destruction:\n";
util::PrintUsage(std::cerr);
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index 9ea08798..fb8bbfa2 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -19,7 +19,7 @@
namespace lm {
-// 1 for '\t', '\n', and ' '. This is stricter than isspace.
+// 1 for '\t', '\n', and ' '. This is stricter than isspace.
const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
namespace {
@@ -50,7 +50,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
// In general, ARPA files can have arbitrary text before "\data\"
// But in KenLM, we require such lines to start with "#", so that
// we can do stricter error checking
- while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) {
+ while (IsEntirelyWhiteSpace(line) || starts_with(line, "#")) {
line = in.ReadLine();
}
@@ -58,7 +58,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
}
- if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
+ if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?");
UTIL_THROW_IF(line.size() >= 4 && StringPiece(line.data(), 4) == "blmt", FormatLoadException, "This looks like an IRSTLM binary file. Did you forget to pass --text yes to compile-lm?");
UTIL_THROW_IF(line == "iARPA", FormatLoadException, "This looks like an IRSTLM iARPA file. You need an ARPA file. Run\n compile-lm --text yes " << in.FileName() << " " << in.FileName() << ".arpa\nfirst.");
@@ -66,7 +66,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
}
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \"");
- // So strtol doesn't go off the end of line.
+ // So strtol doesn't go off the end of line.
std::string remaining(line.data() + 6, line.size() - 6);
char *end_ptr;
unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10);
@@ -102,8 +102,8 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
}
void ReadBackoff(util::FilePiece &in, float &backoff) {
- // Always make zero negative.
- // Negative zero means that no (n+1)-gram has this n-gram as context.
+ // Always make zero negative.
+ // Negative zero means that no (n+1)-gram has this n-gram as context.
// Therefore the hypothesis state can be shorter. Of course, many n-grams
// are context for (n+1)-grams. An algorithm in the data structure will go
// back and set the backoff to positive zero in these cases.
@@ -150,7 +150,7 @@ void PositiveProbWarn::Warn(float prob) {
case THROW_UP:
UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. Error");
case COMPLAIN:
- std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl;
+ std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapped to 0 log probability." << std::endl;
action_ = SILENT;
break;
case SILENT:
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 62275d27..354a56b4 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -204,9 +204,10 @@ template <class Build, class Activate, class Store> void ReadNGrams(
namespace detail {
template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t allocated = Unigram::Size(counts[0]);
- unigram_ = Unigram(start, counts[0], allocated);
- start += allocated;
+ unigram_ = Unigram(start, counts[0]);
+ start += Unigram::Size(counts[0]);
+ std::size_t allocated;
+ middle_.clear();
for (unsigned int n = 2; n < counts.size(); ++n) {
allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
middle_.push_back(Middle(start, allocated));
@@ -218,9 +219,21 @@ template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start,
return start;
}
-template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) {
- // TODO: fix sorted.
- SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);
+/*template <class Value> void HashedSearch<Value>::Relocate(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+ unigram_ = Unigram(start, counts[0]);
+ start += Unigram::Size(counts[0]);
+ for (unsigned int n = 2; n < counts.size(); ++n) {
+ middle[n-2].Relocate(start);
+ start += Middle::Size(counts[n - 1], config.probing_multiplier)
+ }
+ longest_.Relocate(start);
+}*/
+
+template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) {
+ void *vocab_rebase;
+ void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase);
+ vocab.Relocate(vocab_rebase);
+ SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn);
@@ -277,14 +290,6 @@ template <class Value> template <class Build> void HashedSearch<Value>::ApplyBui
ReadEnd(f);
}
-template <class Value> void HashedSearch<Value>::LoadedBinary() {
- unigram_.LoadedBinary();
- for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) {
- i->LoadedBinary();
- }
- longest_.LoadedBinary();
-}
-
template class HashedSearch<BackoffValue>;
template class HashedSearch<RestValue>;
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 9d067bc2..8193262b 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -18,7 +18,7 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
-struct Backing;
+class BinaryFormat;
class ProbingVocabulary;
namespace detail {
@@ -72,7 +72,7 @@ template <class Value> class HashedSearch {
static const unsigned int kVersion = 0;
// TODO: move probing_multiplier here with next binary file format update.
- static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
uint64_t ret = Unigram::Size(counts[0]);
@@ -84,9 +84,7 @@ template <class Value> class HashedSearch {
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing);
-
- void LoadedBinary();
+ void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing);
unsigned char Order() const {
return middle_.size() + 2;
@@ -148,7 +146,7 @@ template <class Value> class HashedSearch {
public:
Unigram() {}
- Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
+ Unigram(void *start, uint64_t count) :
unigram_(static_cast<typename Value::Weights*>(start))
#ifdef DEBUG
, count_(count)
@@ -168,8 +166,6 @@ template <class Value> class HashedSearch {
typename Value::Weights &Unknown() { return unigram_[0]; }
- void LoadedBinary() {}
-
// For building.
typename Value::Weights *Raw() { return unigram_; }
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 1b0d9b26..4a88194e 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -253,11 +253,6 @@ class FindBlanks {
++counts_.back();
}
- // Unigrams wrote one past.
- void Cleanup() {
- --counts_[0];
- }
-
const std::vector<uint64_t> &Counts() const {
return counts_;
}
@@ -310,8 +305,6 @@ template <class Quant, class Bhiksha> class WriteEntries {
typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);
}
- void Cleanup() {}
-
private:
RecordReader *contexts_;
const Quant &quant_;
@@ -385,14 +378,14 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con
util::ErsatzProgress progress(unigram_count + 1, progress_out, message);
WordIndex unigram = 0;
std::priority_queue<Gram> grams;
- grams.push(Gram(&unigram, 1));
+ if (unigram_count) grams.push(Gram(&unigram, 1));
for (unsigned char i = 2; i <= total_order; ++i) {
if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i));
}
BlankManager<Doing> blank(total_order, doing);
- while (true) {
+ while (!grams.empty()) {
Gram top = grams.top();
grams.pop();
unsigned char order = top.end - top.begin;
@@ -400,8 +393,7 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con
blank.Visit(&unigram, 1, doing.UnigramProb(unigram));
doing.Unigram(unigram);
progress.Set(unigram);
- if (++unigram == unigram_count + 1) break;
- grams.push(top);
+ if (++unigram < unigram_count) grams.push(top);
} else {
if (order == total_order) {
blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob);
@@ -414,8 +406,6 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con
if (++reader) grams.push(top);
}
}
- assert(grams.empty());
- doing.Cleanup();
}
void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
@@ -469,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c
} // namespace
-template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) {
RecordReader inputs[KENLM_MAX_ORDER - 1];
RecordReader contexts[KENLM_MAX_ORDER - 1];
@@ -498,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
- out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+ void *vocab_relocate;
+ void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate);
+ vocab.Relocate(vocab_relocate);
+ out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config);
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
@@ -524,6 +517,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
{
WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
+ // Write the last unigram entry, which is the end pointer for the bigrams.
+ writer.Unigram(counts[0]);
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
@@ -579,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
-template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
- unigram_.LoadedBinary();
- for (Middle *i = middle_begin_; i != middle_end_; ++i) {
- i->LoadedBinary();
- }
- longest_.LoadedBinary();
-}
-
-template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) {
std::string temporary_prefix;
if (config.temporary_directory_prefix) {
temporary_prefix = config.temporary_directory_prefix;
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 60be416b..299262a5 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -11,18 +11,19 @@
#include "util/file_piece.hh"
#include <vector>
+#include <cstdlib>
#include <assert.h>
namespace lm {
namespace ngram {
-struct Backing;
+class BinaryFormat;
class SortedVocabulary;
namespace trie {
template <class Quant, class Bhiksha> class TrieSearch;
class SortedFiles;
-template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
template <class Quant, class Bhiksha> class TrieSearch {
public:
@@ -38,11 +39,11 @@ template <class Quant, class Bhiksha> class TrieSearch {
static const unsigned int kVersion = 1;
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- Quant::UpdateConfigFromBinary(fd, counts, config);
- util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) {
+ Quant::UpdateConfigFromBinary(file, offset, config);
// Currently the unigram pointers are not compresssed, so there will only be a header for order > 2.
- if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config);
+ if (counts.size() > 2)
+ Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);
}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -59,9 +60,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- void LoadedBinary();
-
- void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
+ void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);
unsigned char Order() const {
return middle_end_ - middle_begin_ + 2;
@@ -102,14 +101,14 @@ template <class Quant, class Bhiksha> class TrieSearch {
}
private:
- friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
- // Middles are managed manually so we can delay construction and they don't have to be copyable.
+ // Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
i->~Middle();
}
- free(middle_begin_);
+ std::free(middle_begin_);
}
typedef trie::BitPackedMiddle<Bhiksha> Middle;
diff --git a/klm/lm/state.hh b/klm/lm/state.hh
index a6b9accb..543df37c 100644
--- a/klm/lm/state.hh
+++ b/klm/lm/state.hh
@@ -102,7 +102,7 @@ struct ChartState {
}
bool operator<(const ChartState &other) const {
- return Compare(other) == -1;
+ return Compare(other) < 0;
}
void ZeroRemaining() {
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index 9ea3c546..d858ab5e 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -62,8 +62,6 @@ class Unigram {
return unigram_;
}
- void LoadedBinary() {}
-
UnigramPointer Find(WordIndex word, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
next.begin = val->next;
@@ -108,8 +106,6 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked {
void FinishedLoading(uint64_t next_end, const Config &config);
- void LoadedBinary() { bhiksha_.LoadedBinary(); }
-
util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
@@ -138,14 +134,9 @@ class BitPackedLongest : public BitPacked {
BaseInit(base, max_vocab, quant_bits);
}
- void LoadedBinary() {}
-
util::BitAddress Insert(WordIndex word);
util::BitAddress Find(WordIndex word, const NodeRange &node) const;
-
- private:
- uint8_t quant_bits_;
};
} // namespace trie
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index dc542bb3..126d43ab 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -50,6 +50,10 @@ class PartialViewProxy {
const void *Data() const { return inner_.Data(); }
void *Data() { return inner_.Data(); }
+ friend void swap(PartialViewProxy first, PartialViewProxy second) {
+ std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data()));
+ }
+
private:
friend class util::ProxyIterator<PartialViewProxy>;
diff --git a/klm/lm/value_build.cc b/klm/lm/value_build.cc
index 6124f8da..3ec3dce2 100644
--- a/klm/lm/value_build.cc
+++ b/klm/lm/value_build.cc
@@ -9,6 +9,7 @@ namespace ngram {
template <class Model> LowerRestBuild<Model>::LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab) {
UTIL_THROW_IF(config.rest_lower_files.size() != order - 1, ConfigException, "This model has order " << order << " so there should be " << (order - 1) << " lower-order models for rest cost purposes.");
Config for_lower = config;
+ for_lower.write_mmap = NULL;
for_lower.rest_lower_files.clear();
// Unigram models aren't supported, so this is a custom loader.
diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh
index 17f064b2..7a3e2379 100644
--- a/klm/lm/virtual_interface.hh
+++ b/klm/lm/virtual_interface.hh
@@ -125,10 +125,13 @@ class Model {
void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }
// Requires in_state != out_state
- virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+ virtual float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
// Requires in_state != out_state
- virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+ virtual FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+
+ // Prefer to use FullScore. The context words should be provided in reverse order.
+ virtual FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0;
unsigned char Order() const { return order_; }
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index fd7f96dc..7f0878f4 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
-void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) {
+void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
+ util::SeekOrThrow(fd, offset);
// Check that we're at the right place by reading <unk> which is always first.
char check_unk[6];
util::ReadOrThrow(fd, check_unk, 6);
@@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
-void WriteWordsWrapper::Write(int fd, uint64_t start) {
- util::SeekOrThrow(fd, start);
- util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
-}
-
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
@@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size
saw_unk_ = false;
}
+void SortedVocabulary::Relocate(void *new_start) {
+ std::size_t delta = end_ - begin_;
+ begin_ = reinterpret_cast<uint64_t*>(new_start) + 1;
+ end_ = begin_ + delta;
+}
+
void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
enumerate_ = to;
if (enumerate_) {
@@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
-void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
namespace {
@@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz
saw_unk_ = false;
}
+void ProbingVocabulary::Relocate(void *new_start) {
+ header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start);
+ lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)));
+}
+
void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
enumerate_ = to;
if (enumerate_) {
@@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
-void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
- lookup_.LoadedBinary();
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 226ae438..074b74d8 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -36,7 +36,7 @@ class WriteWordsWrapper : public EnumerateVocab {
void Add(WordIndex index, const StringPiece &str);
- void Write(int fd, uint64_t start);
+ const std::string &Buffer() const { return buffer_; }
private:
EnumerateVocab *inner_;
@@ -71,6 +71,8 @@ class SortedVocabulary : public base::Vocabulary {
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+ void Relocate(void *new_start);
+
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@@ -83,15 +85,13 @@ class SortedVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
uint64_t *begin_, *end_;
WordIndex bound_;
- WordIndex highest_value_;
-
bool saw_unk_;
EnumerateVocab *enumerate_;
@@ -140,6 +140,8 @@ class ProbingVocabulary : public base::Vocabulary {
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+ void Relocate(void *new_start);
+
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@@ -152,7 +154,7 @@ class ProbingVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
void InternalFinishedLoading();