summary refs log tree commit diff
path: root/klm/lm
diff options
context:
space:
mode:
author Kenneth Heafield <github@kheafield.com> 2011-08-18 12:14:01 +0100
committer Kenneth Heafield <github@kheafield.com> 2011-08-18 12:14:01 +0100
commit 2c14cf2218031c29a9884bccf17e9273c71a33b2 (patch)
tree c6afcdffb542dea214fe0bd3fad865527e65eb5c /klm/lm
parent d73b5d25bd0af14a4a83490d67ba2553b6af9884 (diff)
KenLM update: Bhiksha's trick, simple test for lms without unk, auto-detect binary files instead of requiring them to be specified at runtime.
Diffstat (limited to 'klm/lm')
-rw-r--r-- klm/lm/Makefile.am 1
-rw-r--r-- klm/lm/bhiksha.cc 93
-rw-r--r-- klm/lm/bhiksha.hh 108
-rw-r--r-- klm/lm/binary_format.cc 13
-rw-r--r-- klm/lm/binary_format.hh 9
-rw-r--r-- klm/lm/build_binary.cc 54
-rw-r--r-- klm/lm/config.cc 1
-rw-r--r-- klm/lm/config.hh 5
-rw-r--r-- klm/lm/model.cc 67
-rw-r--r-- klm/lm/model.hh 12
-rw-r--r-- klm/lm/model_test.cc 73
-rw-r--r-- klm/lm/ngram_query.cc 9
-rw-r--r-- klm/lm/quantize.cc 1
-rw-r--r-- klm/lm/quantize.hh 4
-rw-r--r-- klm/lm/read_arpa.cc 6
-rw-r--r-- klm/lm/search_hashed.cc 2
-rw-r--r-- klm/lm/search_hashed.hh 3
-rw-r--r-- klm/lm/search_trie.cc 45
-rw-r--r-- klm/lm/search_trie.hh 20
-rw-r--r-- klm/lm/test_nounk.arpa 120
-rw-r--r-- klm/lm/trie.cc 57
-rw-r--r-- klm/lm/trie.hh 24
-rw-r--r-- klm/lm/vocab.cc 6
-rw-r--r-- klm/lm/vocab.hh 4
24 files changed, 586 insertions, 151 deletions
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am
index 395494bc..fae6b41a 100644
--- a/klm/lm/Makefile.am
+++ b/klm/lm/Makefile.am
@@ -12,6 +12,7 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz
noinst_LIBRARIES = libklm.a
libklm_a_SOURCES = \
+ bhiksha.cc \
binary_format.cc \
config.cc \
lm_exception.cc \
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc
new file mode 100644
index 00000000..bf86fd4b
--- /dev/null
+++ b/klm/lm/bhiksha.cc
@@ -0,0 +1,93 @@
+#include "lm/bhiksha.hh"
+#include "lm/config.hh"
+
+#include <limits>
+
+namespace lm {
+namespace ngram {
+namespace trie {
+
+// Uncompressed policy: next pointers are stored inline, so the only input
+// used is max_next, which determines the bit width and mask kept in next_.
+// base, max_offset and config are accepted but ignored.
+DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) :
+ next_(util::BitsMask::ByMax(max_next)) {}
+
+// Format version byte stored in binary files by FinishedLoading;
+// UpdateConfigFromBinary rejects any other value.
+const uint8_t kArrayBhikshaVersion = 0;
+
+// Read the two-byte header (version, configured bit limit) from fd at its
+// current position and copy the bit limit into config.pointer_bhiksha_bits.
+// Throws util::ErrnoException if either byte cannot be read and
+// FormatLoadException if the version differs from kArrayBhikshaVersion.
+void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) {
+  uint8_t version;
+  uint8_t configured_bits;
+  // && short-circuits, so the second byte is only attempted once the first was read.
+  const bool got_header = (read(fd, &version, 1) == 1) && (read(fd, &configured_bits, 1) == 1);
+  if (!got_header) {
+    UTIL_THROW(util::ErrnoException, "Could not read from binary file");
+  }
+  if (version != kArrayBhikshaVersion) {
+    UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);
+  }
+  config.pointer_bhiksha_bits = configured_bits;
+}
+
+namespace {
+
+// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset)
+// The search is additionally capped at config.pointer_bhiksha_bits.  Ties go
+// to the smallest chop because only a strictly lower cost replaces the best.
+uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) {
+  uint8_t required = util::RequiredBits(max_next);
+  // The bound does not change inside the loop, so compute it once.
+  const uint8_t chop_limit = std::min(required, config.pointer_bhiksha_bits);
+  int64_t lowest_change = std::numeric_limits<int64_t>::max();
+  uint8_t best_chop = 0;
+  // There are probably faster ways but I don't care because this is only done once per order at construction time.
+  for (uint8_t chop = 0; chop <= chop_limit; ++chop) {
+    // Table cost in bits: 64 per value the chopped high bits can take.
+    const uint64_t table_cost = (max_next >> (required - chop)) * 64;
+    // Savings in bits: chop bits dropped from each of max_offset entries.
+    const uint64_t savings = max_offset * static_cast<int64_t>(chop);
+    // Unsigned difference reinterpreted as signed, exactly as a single expression would be.
+    const int64_t change = table_cost - savings;
+    if (change < lowest_change) {
+      lowest_change = change;
+      best_chop = chop;
+    }
+  }
+  return best_chop;
+}
+
+// Number of entries in the offset array: one per value the chopped high bits
+// of a pointer can take, from 0 up to the high bits of max_next inclusive.
+std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) {
+  const uint8_t total_bits = util::RequiredBits(max_next);
+  const uint8_t chopped_bits = ChopBits(max_offset, max_next, config);
+  const uint64_t highest_prefix = max_next >> (total_bits - chopped_bits);
+  return highest_prefix + 1 /* we store 0 too */;
+}
+} // namespace
+
+// Bytes to reserve for the offset array: one 64-bit header word plus one
+// 64-bit word per array entry, with 7 spare bytes so the array can be moved
+// up to the next 8-byte boundary.
+std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
+  const std::size_t words = 1 /* header */ + ArrayCount(max_offset, max_next, config);
+  return sizeof(uint64_t) * words + 7 /* 8-byte alignment */;
+}
+
+// Bits of each next pointer that remain stored inline: the full width needed
+// for max_next minus the bits ChopBits decides to move into the array.
+uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) {
+  const uint8_t full_width = util::RequiredBits(max_next);
+  const uint8_t chopped = ChopBits(max_offset, max_next, config);
+  return full_width - chopped;
+}
+
+namespace {
+
+// Round an address up to the next multiple of 8; an address that is already
+// a multiple of 8 is returned unchanged.
+void *AlignTo8(void *from) {
+  uint8_t *const start = reinterpret_cast<uint8_t*>(from);
+  const std::size_t misalignment = reinterpret_cast<std::size_t>(start) & 7;
+  return misalignment ? start + (8 - misalignment) : start;
+}
+
+} // namespace
+
+// Lay the offset array over the memory starting at base.  The array begins
+// one 64-bit word past the first 8-byte boundary at or after base.
+// Initializers run in member declaration order (next_inline_, offset_begin_,
+// offset_end_, write_to_, original_base_), so offset_begin_ is already set
+// when offset_end_ is computed from it.
+ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config)
+ : next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))),
+ offset_begin_(reinterpret_cast<const uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */),
+ offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)),
+ // Entry 0 is filled in by FinishedLoading, so writing starts at entry 1.
+ write_to_(reinterpret_cast<uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */),
+ // Kept unaligned: FinishedLoading writes the version and bit-limit bytes here.
+ original_base_(base) {}
+
+// Finalize the offset array once every WriteNext call has been made: store 0
+// in the first entry, check that the write cursor reached the end of the
+// array (throwing util::Exception otherwise), then write the two header bytes
+// (version, configured bit limit) at the original, unaligned base.
+void ArrayBhiksha::FinishedLoading(const Config &config) {
+  // offset_begin_ points to const, so reach the same slot through the
+  // mutable write cursor instead of using a const_cast.
+  uint64_t *const first_entry = write_to_ - (write_to_ - offset_begin_);
+  *first_entry = 0;
+
+  if (write_to_ != offset_end_) {
+    UTIL_THROW(util::Exception, "Did not get all the array entries that were expected.");
+  }
+
+  uint8_t *const header = reinterpret_cast<uint8_t*>(original_base_);
+  header[0] = kArrayBhikshaVersion;
+  header[1] = config.pointer_bhiksha_bits;
+}
+
+// Intentionally empty: no work is done here.
+// NOTE(review): presumably a hook invoked after loading a binary file — confirm against the caller.
+void ArrayBhiksha::LoadedBinary() {
+}
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
new file mode 100644
index 00000000..cfb2b053
--- /dev/null
+++ b/klm/lm/bhiksha.hh
@@ -0,0 +1,108 @@
+/* Simple implementation of
+ * @inproceedings{bhikshacompression,
+ * author={Bhiksha Raj and Ed Whittaker},
+ * year={2003},
+ * title={Lossless Compression of Language Model Structure and Word Identifiers},
+ * booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing},
+ * pages={388--391},
+ * }
+ *
+ * Currently only used for next pointers.
+ */
+
+#include <inttypes.h>
+
+#include "lm/binary_format.hh"
+#include "lm/trie.hh"
+#include "util/bit_packing.hh"
+#include "util/sorted_uniform.hh"
+
+namespace lm {
+namespace ngram {
+class Config;
+
+namespace trie {
+
+// Pointer policy that performs no compression: each next pointer is stored
+// inline using the bit width and mask held in next_.
+class DontBhiksha {
+  public:
+    // Zero, in contrast to ArrayBhiksha which uses kArrayAdd.
+    static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
+
+    // No-op: this policy reads nothing from the binary file.
+    static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
+
+    // No side table is kept, so no extra bytes are needed.
+    static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) {
+      return 0;
+    }
+
+    // Inline width is util::RequiredBits(max_next) bits.
+    static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) {
+      return util::RequiredBits(max_next);
+    }
+
+    DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config);
+
+    // Fill out.begin from the pointer at bit_offset and out.end from the
+    // pointer stored total_bits further along.  index is unused here.
+    void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const {
+      const uint64_t following_offset = bit_offset + total_bits;
+      out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask);
+      out.end = util::ReadInt57(base, following_offset, next_.bits, next_.mask);
+      //assert(out.end >= out.begin);
+    }
+
+    // Store value at bit_offset using next_.bits bits.  index is unused here.
+    void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) {
+      util::WriteInt57(base, bit_offset, next_.bits, value);
+    }
+
+    // Nothing to finalize for this policy.
+    void FinishedLoading(const Config &/*config*/) {}
+
+    // Nothing to do here for this policy.
+    void LoadedBinary() {}
+
+    // Width, in bits, of each inline pointer.
+    uint8_t InlineBits() const { return next_.bits; }
+
+  private:
+    // Bit count and mask for a pointer, derived from max_next in the constructor.
+    util::BitsMask next_;
+};
+
+// Pointer policy implementing the array form of Bhiksha compression: only the
+// low next_inline_.bits bits of each next pointer are stored inline; the high
+// bits are recovered from the position of the entry's index within a sorted
+// array of indices (offset_begin_ .. offset_end_).
+class ArrayBhiksha {
+  public:
+    static const ModelType kModelTypeAdd = kArrayAdd;
+
+    // Read the version byte and bit limit from fd into config; throws on a
+    // failed read or a version mismatch.
+    static void UpdateConfigFromBinary(int fd, Config &config);
+
+    // Bytes to reserve for the offset array, header word and alignment slack.
+    static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
+
+    // Number of pointer bits that stay inline after chopping.
+    static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config);
+
+    // base is the start of the memory sized by Size(); max_offset and
+    // max_next must match the values given to Size() and InlineBits().
+    ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config);
+
+    // Fill out.begin with the next pointer of entry `index` and out.end with
+    // that of entry index + 1 (whose inline bits live total_bits further on).
+    void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const {
+      // The array position found for `index` supplies the high bits.
+      // NOTE(review): relies on util::BinaryBelow returning the last entry <= index — confirm in util/sorted_uniform.hh.
+      const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor<uint64_t>(), offset_begin_, offset_end_, index);
+      const uint64_t *end_it;
+      // Walk past every entry <= index + 1, then step back onto the last one:
+      // its position supplies the high bits for entry index + 1.
+      for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {}
+      --end_it;
+      out.begin = ((begin_it - offset_begin_) << next_inline_.bits) |
+        util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask);
+      out.end = ((end_it - offset_begin_) << next_inline_.bits) |
+        util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask);
+    }
+
+    // Record `value` as the next pointer of entry `index`.
+    // NOTE(review): the array fill assumes values arrive in non-decreasing order — confirm against the caller.
+    void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) {
+      // High bits of value select how far the array must be filled.
+      uint64_t encode = value >> next_inline_.bits;
+      // Every still-unwritten slot up to and including `encode` gets this index.
+      for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index;
+      // Low bits are stored inline.
+      util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask);
+    }
+
+    // Zero the first array entry, verify the array was completely written and
+    // write the two header bytes at the original base.
+    void FinishedLoading(const Config &config);
+
+    void LoadedBinary();
+
+    // Width, in bits, of the inline part of each pointer.
+    uint8_t InlineBits() const { return next_inline_.bits; }
+
+  private:
+    // Bit count and mask for the inline (low) part of a pointer.
+    const util::BitsMask next_inline_;
+
+    // Bounds of the offset array; offset_begin_[0] is set to 0 by FinishedLoading.
+    const uint64_t *const offset_begin_;
+    const uint64_t *const offset_end_;
+
+    // Next array slot WriteNext will fill; starts at offset_begin_ + 1.
+    uint64_t *write_to_;
+
+    // Unaligned start of this object's memory, where the header bytes go.
+    void *original_base_;
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 92b1008b..e02e621a 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -40,7 +40,7 @@ struct Sanity {
}
};
-const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"};
+const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
std::size_t Align8(std::size_t in) {
std::size_t off = in % 8;
@@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
}
}
-uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) {
+uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
+ std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
// Grow the file to accomodate the search, using zeros.
- if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size))
- UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed");
+ if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size))
+ UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed");
// We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
off_t page_size = sysconf(_SC_PAGE_SIZE);
- off_t alignment_cruft = backing.vocab.size() % page_size;
- backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ off_t alignment_cruft = adjusted_vocab % page_size;
+ backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
} else {
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index 2b32b450..d28cb6c5 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -16,7 +16,12 @@
namespace lm {
namespace ngram {
-typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType;
+/* Not the best numbering system, but it grew this way for historical reasons
+ * and I want to preserve existing binary files. */
+typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType;
+
+const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE_SORTED - TRIE_SORTED);
+const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE_SORTED - TRIE_SORTED);
/*Inspect a file to determine if it is a binary lm. If not, return false.
* If so, return true and set recognized to the type. This is the only API in
@@ -55,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off);
// Create just enough of a binary file to write vocabulary to it.
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
-uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing);
+uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing);
// Write header to binary file. This is done last to prevent incomplete files
// from loading.
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index 4552c419..b7aee4de 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -15,12 +15,12 @@ namespace ngram {
namespace {
void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n"
-"-u sets the default log10 probability for <unk> if the ARPA file does not have\n"
-"one.\n"
+ // Keep this synopsis in sync with the getopt() option string in main(): array compression is -a.
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
+"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
+" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
-"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
-"type is either probing or trie:\n\n"
+"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n"
+"type is either probing or trie. Default is probing.\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
"trie is a straightforward trie with bit-level packing. It uses the least\n"
@@ -29,10 +29,11 @@ void Usage(const char *name) {
"-t is the temporary directory prefix. Default is the output file name.\n"
"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n"
"-q turns quantization on and sets the number of bits (e.g. -q 8).\n"
-"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n"
-"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n"
-"Passing only an input file will print memory usage of each data structure.\n"
-"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n";
+"-b sets backoff quantization bits. Requires -q and defaults to that value.\n"
+"-a compresses pointers using an array of offsets. The parameter is the\n"
+" maximum number of bits encoded by the array. Memory is minimized subject\n"
+" to the maximum, so pick 255 to minimize memory.\n\n"
+"Get a memory estimate by passing an ARPA file without an output file name.\n";
exit(1);
}
@@ -63,12 +64,14 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::vector<uint64_t> counts;
util::FilePiece f(file);
lm::ReadARPACounts(f, counts);
- std::size_t sizes[3];
+ std::size_t sizes[5];
sizes[0] = ProbingModel::Size(counts, config);
sizes[1] = TrieModel::Size(counts, config);
sizes[2] = QuantTrieModel::Size(counts, config);
- std::size_t max_length = *std::max_element(sizes, sizes + 3);
- std::size_t min_length = *std::max_element(sizes, sizes + 3);
+ sizes[3] = ArrayTrieModel::Size(counts, config);
+ sizes[4] = QuantArrayTrieModel::Size(counts, config);
+ std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
+ std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
std::size_t divide;
char prefix;
if (min_length < (1 << 10) * 10) {
@@ -91,7 +94,9 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::cout << prefix << "B\n"
"probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
"trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n"
- "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n";
+ "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
+ "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
+ "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
}
void ProbingQuantizationUnsupported() {
@@ -106,11 +111,11 @@ void ProbingQuantizationUnsupported() {
int main(int argc, char *argv[]) {
using namespace lm::ngram;
- bool quantize = false, set_backoff_bits = false;
try {
+ bool quantize = false, set_backoff_bits = false, bhiksha = false;
lm::ngram::Config config;
int opt;
- while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) {
+ while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -121,6 +126,9 @@ int main(int argc, char *argv[]) {
config.backoff_bits = ParseBitCount(optarg);
set_backoff_bits = true;
break;
+ case 'a':
+ // Array pointer compression; the argument caps the number of bits encoded by the array.
+ config.pointer_bhiksha_bits = ParseBitCount(optarg);
+ bhiksha = true;
+ // Without this break control falls into case 'u' and optarg is re-parsed as the <unk> log probability.
+ break;
case 'u':
config.unknown_missing_logprob = ParseFloat(optarg);
break;
@@ -162,9 +170,17 @@ int main(int argc, char *argv[]) {
ProbingModel(from_file, config);
} else if (!strcmp(model_type, "trie")) {
if (quantize) {
- QuantTrieModel(from_file, config);
+ if (bhiksha) {
+ QuantArrayTrieModel(from_file, config);
+ } else {
+ QuantTrieModel(from_file, config);
+ }
} else {
- TrieModel(from_file, config);
+ if (bhiksha) {
+ ArrayTrieModel(from_file, config);
+ } else {
+ TrieModel(from_file, config);
+ }
}
} else {
Usage(argv[0]);
@@ -173,9 +189,9 @@ int main(int argc, char *argv[]) {
Usage(argv[0]);
}
}
- catch (std::exception &e) {
+ catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
- abort();
+ return 1;
}
return 0;
}
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index 08e1af5c..297589a4 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -20,6 +20,7 @@ Config::Config() :
include_vocab(true),
prob_bits(8),
backoff_bits(8),
+ pointer_bhiksha_bits(22),
load_method(util::POPULATE_OR_READ) {}
} // namespace ngram
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index dcc7cf35..227b8512 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -73,9 +73,12 @@ struct Config {
// Quantization options. Only effective for QuantTrieModel. One value is
// reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
- // to quantize.
+ // to quantize (and one of the remaining backoffs will be 0).
uint8_t prob_bits, backoff_bits;
+ // Bhiksha compression (simple form). Only works with trie.
+ uint8_t pointer_bhiksha_bits;
+
// ONLY EFFECTIVE WHEN READING BINARY
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index a1d10b3d..27e24b1c 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -21,6 +21,8 @@ size_t hash_value(const State &state) {
namespace detail {
+template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
+
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
@@ -56,35 +58,40 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
util::FilePiece f(backing_.file.release(), file, config.messages);
- std::vector<uint64_t> counts;
- // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
- ReadARPACounts(f, counts);
-
- if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
- if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
- if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
-
- std::size_t vocab_size = VocabularyT::Size(counts[0], config);
- // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
- vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
-
- if (config.write_mmap) {
- WriteWordsWrapper wrap(config.enumerate_vocab);
- vocab_.ConfigureEnumerate(&wrap, counts[0]);
- search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get());
- } else {
- vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
- search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- }
+ try {
+ std::vector<uint64_t> counts;
+ // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
+ ReadARPACounts(f, counts);
+
+ if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
+ if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+ if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
+
+ std::size_t vocab_size = VocabularyT::Size(counts[0], config);
+ // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
+ vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
+
+ if (config.write_mmap) {
+ WriteWordsWrapper wrap(config.enumerate_vocab);
+ vocab_.ConfigureEnumerate(&wrap, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+ wrap.Write(backing_.file.get());
+ } else {
+ vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+ }
- if (!vocab_.SawUnk()) {
- assert(config.unknown_missing != THROW_UP);
- // Default probabilities for unknown.
- search_.unigram.Unknown().backoff = 0.0;
- search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ if (!vocab_.SawUnk()) {
+ assert(config.unknown_missing != THROW_UP);
+ // Default probabilities for unknown.
+ search_.unigram.Unknown().backoff = 0.0;
+ search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ }
+ FinishFile(config, kModelType, counts, backing_);
+ } catch (util::Exception &e) {
+ e << " Byte: " << f.Offset();
+ throw;
}
- FinishFile(config, kModelType, counts, backing_);
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
@@ -225,8 +232,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING
-template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>; // TRIE_SORTED
-template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>; // TRIE_SORTED_QUANT
+template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED
+template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
} // namespace detail
} // namespace ngram
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 1f49a382..21595321 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -1,6 +1,7 @@
#ifndef LM_MODEL__
#define LM_MODEL__
+#include "lm/bhiksha.hh"
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
@@ -71,6 +72,9 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
private:
typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
public:
+ // This is the model type returned by RecognizeBinary.
+ static const ModelType kModelType;
+
/* Get the size of memory that will be mapped given ngram counts. This
* does not include small non-mapped control structures, such as this class
* itself.
@@ -131,8 +135,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
Backing &MutableBacking() { return backing_; }
- static const ModelType kModelType = Search::kModelType;
-
Backing backing_;
VocabularyT vocab_;
@@ -152,9 +154,11 @@ typedef ProbingModel Model;
// Smaller implementation.
typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
-typedef detail::GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary> TrieModel; // TRIE_SORTED
+typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary> TrieModel; // TRIE_SORTED
+typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary> ArrayTrieModel;
-typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED
+typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED
+typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary> QuantArrayTrieModel;
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 8bf040ff..57c7291c 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -193,6 +193,14 @@ template <class M> void Stateless(const M &model) {
BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);
}
+// Score word index 0 in a one-word context that is also index 0 and check
+// that the probability is within 0.001 percent of -100.
+// NOTE(review): index 0 is presumably <unk> and -100 the default unknown_missing_logprob — confirm in vocab.hh / config.cc.
+template <class M> void NoUnkCheck(const M &model) {
+  WordIndex unk_index = 0;
+  WordIndex *const context_begin = &unk_index;
+  WordIndex *const context_end = context_begin + 1;
+  State state;
+
+  FullScoreReturn ret = model.FullScoreForgotState(context_begin, context_end, unk_index, state);
+  BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001);
+}
+
template <class M> void Everything(const M &m) {
Starters(m);
Continuation(m);
@@ -231,25 +239,38 @@ template <class ModelT> void LoadingTest() {
Config config;
config.arpa_complain = Config::NONE;
config.messages = NULL;
- ExpectEnumerateVocab enumerate;
- config.enumerate_vocab = &enumerate;
config.probing_multiplier = 2.0;
- ModelT m("test.arpa", config);
- enumerate.Check(m.GetVocabulary());
- Everything(m);
+ {
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+ ModelT m("test.arpa", config);
+ enumerate.Check(m.GetVocabulary());
+ Everything(m);
+ }
+ {
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+ ModelT m("test_nounk.arpa", config);
+ enumerate.Check(m.GetVocabulary());
+ NoUnkCheck(m);
+ }
}
BOOST_AUTO_TEST_CASE(probing) {
LoadingTest<Model>();
}
-
BOOST_AUTO_TEST_CASE(trie) {
LoadingTest<TrieModel>();
}
-
-BOOST_AUTO_TEST_CASE(quant) {
+BOOST_AUTO_TEST_CASE(quant_trie) {
LoadingTest<QuantTrieModel>();
}
+BOOST_AUTO_TEST_CASE(bhiksha_trie) {
+ LoadingTest<ArrayTrieModel>();
+}
+BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
+ LoadingTest<QuantArrayTrieModel>();
+}
template <class ModelT> void BinaryTest() {
Config config;
@@ -267,10 +288,34 @@ template <class ModelT> void BinaryTest() {
config.write_mmap = NULL;
- ModelT binary("test.binary", config);
- enumerate.Check(binary.GetVocabulary());
- Everything(binary);
+ ModelType type;
+ BOOST_REQUIRE(RecognizeBinary("test.binary", type));
+ BOOST_CHECK_EQUAL(ModelT::kModelType, type);
+
+ {
+ ModelT binary("test.binary", config);
+ enumerate.Check(binary.GetVocabulary());
+ Everything(binary);
+ }
unlink("test.binary");
+
+ // Now test without <unk>.
+ config.write_mmap = "test_nounk.binary";
+ config.messages = NULL;
+ enumerate.Clear();
+ {
+ ModelT copy_model("test_nounk.arpa", config);
+ enumerate.Check(copy_model.GetVocabulary());
+ enumerate.Clear();
+ NoUnkCheck(copy_model);
+ }
+ config.write_mmap = NULL;
+ {
+ ModelT binary("test_nounk.binary", config);
+ enumerate.Check(binary.GetVocabulary());
+ NoUnkCheck(binary);
+ }
+ unlink("test_nounk.binary");
}
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
@@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {
BinaryTest<QuantTrieModel>();
}
+BOOST_AUTO_TEST_CASE(write_and_read_array_trie) {
+ BinaryTest<ArrayTrieModel>();
+}
+BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) {
+ BinaryTest<QuantArrayTrieModel>();
+}
} // namespace
} // namespace ngram
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index 9454a6d1..d9db4aa2 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -99,6 +99,15 @@ int main(int argc, char *argv[]) {
case lm::ngram::TRIE_SORTED:
Query<lm::ngram::TrieModel>(argv[1], sentence_context);
break;
+ case lm::ngram::QUANT_TRIE_SORTED:
+ Query<lm::ngram::QuantTrieModel>(argv[1], sentence_context);
+ break;
+ case lm::ngram::ARRAY_TRIE_SORTED:
+ Query<lm::ngram::ArrayTrieModel>(argv[1], sentence_context);
+ break;
+ case lm::ngram::QUANT_ARRAY_TRIE_SORTED:
+ Query<lm::ngram::QuantArrayTrieModel>(argv[1], sentence_context);
+ break;
case lm::ngram::HASH_SORTED:
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc
index 4bb6b1b8..fd371cc8 100644
--- a/klm/lm/quantize.cc
+++ b/klm/lm/quantize.cc
@@ -43,6 +43,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64
if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1)
UTIL_THROW(util::ErrnoException, "Failed to read header for quantization.");
if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion);
+ AdvanceOrThrow(fd, -3);
}
void SeparatelyQuantize::SetupMemory(void *start, const Config &config) {
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh
index aae72b34..0b71d14a 100644
--- a/klm/lm/quantize.hh
+++ b/klm/lm/quantize.hh
@@ -21,7 +21,7 @@ class Config;
/* Store values directly and don't quantize. */
class DontQuantize {
public:
- static const ModelType kModelType = TRIE_SORTED;
+ static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
@@ -108,7 +108,7 @@ class SeparatelyQuantize {
};
public:
- static const ModelType kModelType = QUANT_TRIE_SORTED;
+ static const ModelType kModelTypeAdd = kQuantAdd;
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index 060a97ea..455bc4ba 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -31,15 +31,15 @@ const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
number.clear();
StringPiece line;
- if (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
+ while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
+ if (line != "\\data\\") {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
}
if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?");
- UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank");
+ UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\.");
}
- if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \"");
// So strtol doesn't go off the end of line.
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index c56ba7b8..82c53ec8 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -98,7 +98,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
// TODO: fix sorted.
- SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config);
+ SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index f3acdefc..c62985e4 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -52,12 +52,11 @@ struct HashedSearch {
Unigram unigram;
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
+ void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
const ProbBackoff &entry = unigram.Lookup(word);
prob = entry.prob;
backoff = entry.backoff;
next = static_cast<Node>(word);
- return true;
}
};
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 91f87f1c..05059ffb 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -1,6 +1,7 @@
/* This is where the trie is built. It's on-disk. */
#include "lm/search_trie.hh"
+#include "lm/bhiksha.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/max_order.hh"
@@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
// In case <unk> appears.
- size_t extra_count = counts[0] + 1;
- util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff));
+ size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff);
+ util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out);
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
CheckSpecials(config, vocab);
if (!vocab.SawUnk()) ++counts[0];
@@ -610,9 +611,9 @@ class JustCount {
};
// Phase to actually write n-grams to the trie.
-template <class Quant> class WriteEntries {
+template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
+ WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
@@ -649,7 +650,7 @@ template <class Quant> class WriteEntries {
private:
ContextReader *contexts_;
UnigramValue *const unigrams_;
- BitPackedMiddle<typename Quant::Middle> *const middle_;
+ BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_;
BitPackedLongest<typename Quant::Longest> &longest_;
BitPacked &bigram_pack_;
};
@@ -821,7 +822,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, So
} // namespace
-template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
std::vector<SortedFileReader> inputs(counts.size() - 1);
std::vector<ContextReader> contexts(counts.size() - 1);
@@ -846,7 +847,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
- out.SetupMemory(GrowForSearch(config, TrieSearch<Quant>::Size(fixed_counts, config), backing), fixed_counts, config);
+ out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
if (Quant::kTrain) {
util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
@@ -863,7 +864,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto
UnigramValue *unigrams = out.unigram.Raw();
// Fill entries except unigram probabilities.
{
- RecursiveInsert<WriteEntries<Quant> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
}
@@ -901,14 +902,14 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto
/* Set ending offsets so the last entry will be sized properly */
// Last entry for unigrams was already set.
if (out.middle_begin_ != out.middle_end_) {
- for (typename TrieSearch<Quant>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
- i->FinishedLoading((i+1)->InsertIndex());
+ for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
+ i->FinishedLoading((i+1)->InsertIndex(), config);
}
- (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex());
+ (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config);
}
}
-template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
quant_.SetupMemory(start, config);
start += Quant::Size(counts.size(), config);
unigram.Init(start);
@@ -919,22 +920,24 @@ template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, c
std::vector<uint8_t*> middle_starts(counts.size() - 2);
for (unsigned char i = 2; i < counts.size(); ++i) {
middle_starts[i-2] = start;
- start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]);
+ start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config);
}
- // Crazy backwards thing so we initialize in the correct order.
+ // Crazy backwards thing so we initialize using pointers to ones that have already been initialized
for (unsigned char i = counts.size() - 1; i >= 2; --i) {
new (middle_begin_ + i - 2) Middle(
middle_starts[i-2],
quant_.Mid(i),
+ counts[i-1],
counts[0],
counts[i],
- (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]));
+ (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]),
+ config);
}
longest.Init(start, quant_.Long(counts.size()), counts[0]);
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
-template <class Quant> void TrieSearch<Quant>::LoadedBinary() {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
unigram.LoadedBinary();
for (Middle *i = middle_begin_; i != middle_end_; ++i) {
i->LoadedBinary();
@@ -942,7 +945,7 @@ template <class Quant> void TrieSearch<Quant>::LoadedBinary() {
longest.LoadedBinary();
}
-template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {
temporary_directory = config.temporary_directory_prefix;
@@ -966,14 +969,16 @@ template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *fi
// At least 1MB sorting memory.
ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
- BuildTrie(temporary_directory, counts, config, *this, quant_, backing);
+ BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {
*config.messages << "Failed to delete " << temporary_directory << std::endl;
}
}
-template class TrieSearch<DontQuantize>;
-template class TrieSearch<SeparatelyQuantize>;
+template class TrieSearch<DontQuantize, DontBhiksha>;
+template class TrieSearch<DontQuantize, ArrayBhiksha>;
+template class TrieSearch<SeparatelyQuantize, DontBhiksha>;
+template class TrieSearch<SeparatelyQuantize, ArrayBhiksha>;
} // namespace trie
} // namespace ngram
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 0a52acb5..2f39c09f 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -13,31 +13,33 @@ struct Backing;
class SortedVocabulary;
namespace trie {
-template <class Quant> class TrieSearch;
-template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
+template <class Quant, class Bhiksha> class TrieSearch;
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
-template <class Quant> class TrieSearch {
+template <class Quant, class Bhiksha> class TrieSearch {
public:
typedef NodeRange Node;
typedef ::lm::ngram::trie::Unigram Unigram;
Unigram unigram;
- typedef trie::BitPackedMiddle<typename Quant::Middle> Middle;
+ typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;
typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
Longest longest;
- static const ModelType kModelType = Quant::kModelType;
+ static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
Quant::UpdateConfigFromBinary(fd, counts, config);
+ AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ Bhiksha::UpdateConfigFromBinary(fd, config);
}
static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
- ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]);
+ ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
}
return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
@@ -55,8 +57,8 @@ template <class Quant> class TrieSearch {
void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
- return unigram.Find(word, prob, backoff, node);
+ void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+ unigram.Find(word, prob, backoff, node);
}
bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
@@ -83,7 +85,7 @@ template <class Quant> class TrieSearch {
}
private:
- friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa
new file mode 100644
index 00000000..060733d9
--- /dev/null
+++ b/klm/lm/test_nounk.arpa
@@ -0,0 +1,120 @@
+
+\data\
+ngram 1=36
+ngram 2=45
+ngram 3=10
+ngram 4=6
+ngram 5=4
+
+\1-grams:
+-1.383514 , -0.30103
+-1.139057 . -0.845098
+-1.029493 </s>
+-99 <s> -0.4149733
+-1.285941 a -0.69897
+-1.687872 also -0.30103
+-1.687872 beyond -0.30103
+-1.687872 biarritz -0.30103
+-1.687872 call -0.30103
+-1.687872 concerns -0.30103
+-1.687872 consider -0.30103
+-1.687872 considering -0.30103
+-1.687872 for -0.30103
+-1.509559 higher -0.30103
+-1.687872 however -0.30103
+-1.687872 i -0.30103
+-1.687872 immediate -0.30103
+-1.687872 in -0.30103
+-1.687872 is -0.30103
+-1.285941 little -0.69897
+-1.383514 loin -0.30103
+-1.687872 look -0.30103
+-1.285941 looking -0.4771212
+-1.206319 more -0.544068
+-1.509559 on -0.4771212
+-1.509559 screening -0.4771212
+-1.687872 small -0.30103
+-1.687872 the -0.30103
+-1.687872 to -0.30103
+-1.687872 watch -0.30103
+-1.687872 watching -0.30103
+-1.687872 what -0.30103
+-1.687872 would -0.30103
+-3.141592 foo
+-2.718281 bar 3.0
+-6.535897 baz -0.0
+
+\2-grams:
+-0.6925742 , .
+-0.7522095 , however
+-0.7522095 , is
+-0.0602359 . </s>
+-0.4846522 <s> looking -0.4771214
+-1.051485 <s> screening
+-1.07153 <s> the
+-1.07153 <s> watching
+-1.07153 <s> what
+-0.09132547 a little -0.69897
+-0.2922095 also call
+-0.2922095 beyond immediate
+-0.2705918 biarritz .
+-0.2922095 call for
+-0.2922095 concerns in
+-0.2922095 consider watch
+-0.2922095 considering consider
+-0.2834328 for ,
+-0.5511513 higher more
+-0.5845945 higher small
+-0.2834328 however ,
+-0.2922095 i would
+-0.2922095 immediate concerns
+-0.2922095 in biarritz
+-0.2922095 is to
+-0.09021038 little more -0.1998621
+-0.7273645 loin ,
+-0.6925742 loin .
+-0.6708385 loin </s>
+-0.2922095 look beyond
+-0.4638903 looking higher
+-0.4638903 looking on -0.4771212
+-0.5136299 more . -0.4771212
+-0.3561665 more loin
+-0.1649931 on a -0.4771213
+-0.1649931 screening a -0.4771213
+-0.2705918 small .
+-0.287799 the screening
+-0.2922095 to look
+-0.2622373 watch </s>
+-0.2922095 watching considering
+-0.2922095 what i
+-0.2922095 would also
+-2 also would -6
+-6 foo bar
+
+\3-grams:
+-0.01916512 more . </s>
+-0.0283603 on a little -0.4771212
+-0.0283603 screening a little -0.4771212
+-0.01660496 a little more -0.09409451
+-0.3488368 <s> looking higher
+-0.3488368 <s> looking on -0.4771212
+-0.1892331 little more loin
+-0.04835128 looking on a -0.4771212
+-3 also would consider -7
+-7 to look good
+
+\4-grams:
+-0.009249173 looking on a little -0.4771212
+-0.005464747 on a little more -0.4771212
+-0.005464747 screening a little more
+-0.1453306 a little more loin
+-0.01552657 <s> looking on a -0.4771212
+-4 also would consider higher -8
+
+\5-grams:
+-0.003061223 <s> looking on a little
+-0.001813953 looking on a little more
+-0.0432557 on a little more loin
+-5 also would consider higher looking
+
+\end\
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 63c2a612..8c536e66 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -1,5 +1,6 @@
#include "lm/trie.hh"
+#include "lm/bhiksha.hh"
#include "lm/quantize.hh"
#include "util/bit_packing.hh"
#include "util/exception.hh"
@@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
max_vocab_ = max_vocab;
}
-template <class Quant> std::size_t BitPackedMiddle<Quant>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
- return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr));
+template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
+ return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}
-template <class Quant> BitPackedMiddle<Quant>::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) {
- if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
- BaseInit(base, max_vocab, quant.TotalBits() + next_bits_);
+template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
+ BitPacked(),
+ quant_(quant),
+ // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.
+ bhiksha_(base, entries + 1, max_next, config),
+ next_source_(&next_source) {
+ if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
+ BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits());
}
-template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float prob, float backoff) {
+template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) {
assert(word <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
@@ -75,47 +81,42 @@ template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float
quant_.Write(base_, at_pointer, prob, backoff);
at_pointer += quant_.TotalBits();
uint64_t next = next_source_->InsertIndex();
- assert(next <= next_mask_);
- util::WriteInt57(base_, at_pointer, next_bits_, next);
+ bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);
++insert_index_;
}
-template <class Quant> bool BitPackedMiddle<Quant>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
+template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
return false;
}
+ uint64_t index = at_pointer;
at_pointer *= total_bits_;
at_pointer += word_bits_;
quant_.Read(base_, at_pointer, prob, backoff);
at_pointer += quant_.TotalBits();
- range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
- // Read the next entry's pointer.
- at_pointer += total_bits_;
- range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
+ bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
+
return true;
}
-template <class Quant> bool BitPackedMiddle<Quant>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
- uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
- at_pointer *= total_bits_;
+template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
+ uint64_t index;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false;
+ uint64_t at_pointer = index * total_bits_;
at_pointer += word_bits_;
quant_.ReadBackoff(base_, at_pointer, backoff);
at_pointer += quant_.TotalBits();
- range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
- // Read the next entry's pointer.
- at_pointer += total_bits_;
- range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
+ bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
return true;
}
-template <class Quant> void BitPackedMiddle<Quant>::FinishedLoading(uint64_t next_end) {
- assert(next_end <= next_mask_);
- uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
- util::WriteInt57(base_, last_next_write, next_bits_, next_end);
+template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
+ uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits();
+ bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end);
+ bhiksha_.FinishedLoading(config);
}
template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) {
@@ -135,8 +136,10 @@ template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float
return true;
}
-template class BitPackedMiddle<DontQuantize::Middle>;
-template class BitPackedMiddle<SeparatelyQuantize::Middle>;
+template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>;
+template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>;
+template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>;
+template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>;
template class BitPackedLongest<DontQuantize::Longest>;
template class BitPackedLongest<SeparatelyQuantize::Longest>;
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index 8fa21aaf..53612064 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -10,6 +10,7 @@
namespace lm {
namespace ngram {
+class Config;
namespace trie {
struct NodeRange {
@@ -46,13 +47,12 @@ class Unigram {
void LoadedBinary() {}
- bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+ void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
prob = val->weights.prob;
backoff = val->weights.backoff;
next.begin = val->next;
next.end = (val+1)->next;
- return true;
}
private:
@@ -67,8 +67,6 @@ class BitPacked {
return insert_index_;
}
- void LoadedBinary() {}
-
protected:
static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
@@ -83,30 +81,30 @@ class BitPacked {
uint64_t insert_index_, max_vocab_;
};
-template <class Quant> class BitPackedMiddle : public BitPacked {
+template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {
public:
- static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next);
+ static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
// next_source need not be initialized.
- BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source);
+ BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
void Insert(WordIndex word, float prob, float backoff);
+ void FinishedLoading(uint64_t next_end, const Config &config);
+
+ void LoadedBinary() { bhiksha_.LoadedBinary(); }
+
bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
- void FinishedLoading(uint64_t next_end);
-
private:
Quant quant_;
- uint8_t next_bits_;
- uint64_t next_mask_;
+ Bhiksha bhiksha_;
const BitPacked *next_source_;
};
-
template <class Quant> class BitPackedLongest : public BitPacked {
public:
static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
@@ -120,6 +118,8 @@ template <class Quant> class BitPackedLongest : public BitPacked {
BaseInit(base, max_vocab, quant_.TotalBits());
}
+ void LoadedBinary() {}
+
void Insert(WordIndex word, float prob);
bool Find(WordIndex word, float &prob, const NodeRange &node) const;
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 7defd5c1..04979d51 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -37,14 +37,14 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
WordIndex index = 0;
while (true) {
ssize_t got = read(fd, &buf[0], kInitialRead);
- if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
+ UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words");
if (got == 0) return index;
buf.resize(got);
while (buf[buf.size() - 1]) {
char next_char;
ssize_t ret = read(fd, &next_char, 1);
- if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
- if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word.");
+ UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words");
+ UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word.");
buf.push_back(next_char);
}
// Ok now we have null terminated strings.
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index c92518e4..9d218fff 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary {
}
}
+ // Size for purposes of file writing
static size_t Size(std::size_t entries, const Config &config);
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
@@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary {
// Reorders reorder_vocab so that the IDs are sorted.
void FinishedLoading(ProbBackoff *reorder_vocab);
+ // Trie stores the correct counts including <unk> in the header. If this was previously sized based on a count excluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>.
+ std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }
+
bool SawUnk() const { return saw_unk_; }
void LoadedBinary(int fd, EnumerateVocab *to);