summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/Jamfile14
-rw-r--r--klm/lm/Makefile.am41
-rw-r--r--klm/lm/bhiksha.cc2
-rw-r--r--klm/lm/bhiksha.hh4
-rw-r--r--klm/lm/binary_format.cc41
-rw-r--r--klm/lm/binary_format.hh4
-rw-r--r--klm/lm/build_binary_main.cc (renamed from klm/lm/build_binary.cc)80
-rw-r--r--klm/lm/builder/Makefile.am28
-rw-r--r--klm/lm/builder/README.md47
-rw-r--r--klm/lm/builder/TODO5
-rw-r--r--klm/lm/builder/adjust_counts.cc216
-rw-r--r--klm/lm/builder/adjust_counts.hh44
-rw-r--r--klm/lm/builder/adjust_counts_test.cc106
-rw-r--r--klm/lm/builder/corpus_count.cc224
-rw-r--r--klm/lm/builder/corpus_count.hh42
-rw-r--r--klm/lm/builder/corpus_count_test.cc76
-rw-r--r--klm/lm/builder/discount.hh26
-rw-r--r--klm/lm/builder/header_info.hh20
-rw-r--r--klm/lm/builder/initial_probabilities.cc136
-rw-r--r--klm/lm/builder/initial_probabilities.hh34
-rw-r--r--klm/lm/builder/interpolate.cc65
-rw-r--r--klm/lm/builder/interpolate.hh27
-rw-r--r--klm/lm/builder/joint_order.hh43
-rw-r--r--klm/lm/builder/lmplz_main.cc94
-rw-r--r--klm/lm/builder/multi_stream.hh180
-rw-r--r--klm/lm/builder/ngram.hh84
-rw-r--r--klm/lm/builder/ngram_stream.hh55
-rw-r--r--klm/lm/builder/pipeline.cc320
-rw-r--r--klm/lm/builder/pipeline.hh40
-rw-r--r--klm/lm/builder/print.cc135
-rw-r--r--klm/lm/builder/print.hh102
-rw-r--r--klm/lm/builder/sort.hh103
-rw-r--r--klm/lm/config.cc1
-rw-r--r--klm/lm/config.hh59
-rw-r--r--klm/lm/filter/arpa_io.cc108
-rw-r--r--klm/lm/filter/arpa_io.hh115
-rw-r--r--klm/lm/filter/count_io.hh91
-rw-r--r--klm/lm/filter/filter_main.cc248
-rw-r--r--klm/lm/filter/format.hh250
-rw-r--r--klm/lm/filter/phrase.cc281
-rw-r--r--klm/lm/filter/phrase.hh154
-rw-r--r--klm/lm/filter/thread.hh167
-rw-r--r--klm/lm/filter/vocab.cc54
-rw-r--r--klm/lm/filter/vocab.hh133
-rw-r--r--klm/lm/filter/wrapper.hh58
-rw-r--r--klm/lm/fragment_main.cc37
-rw-r--r--klm/lm/kenlm_max_order_main.cc (renamed from klm/lm/max_order.cc)0
-rw-r--r--klm/lm/left.hh66
-rw-r--r--klm/lm/max_order.hh5
-rw-r--r--klm/lm/model.cc54
-rw-r--r--klm/lm/model.hh2
-rw-r--r--klm/lm/model_test.cc10
-rw-r--r--klm/lm/partial.hh167
-rw-r--r--klm/lm/partial_test.cc199
-rw-r--r--klm/lm/quantize.hh8
-rw-r--r--klm/lm/query_main.cc (renamed from klm/lm/ngram_query.cc)0
-rw-r--r--klm/lm/read_arpa.cc33
-rw-r--r--klm/lm/search_hashed.cc10
-rw-r--r--klm/lm/search_hashed.hh10
-rw-r--r--klm/lm/search_trie.cc49
-rw-r--r--klm/lm/search_trie.hh4
-rw-r--r--klm/lm/sizes.cc63
-rw-r--r--klm/lm/sizes.hh17
-rw-r--r--klm/lm/sri_test.cc65
-rw-r--r--klm/lm/state.hh8
-rw-r--r--klm/lm/trie.cc4
-rw-r--r--klm/lm/trie.hh8
-rw-r--r--klm/lm/trie_sort.cc47
-rw-r--r--klm/lm/trie_sort.hh5
-rw-r--r--klm/lm/vocab.cc15
-rw-r--r--klm/lm/vocab.hh11
-rw-r--r--klm/lm/word_index.hh3
72 files changed, 4711 insertions, 346 deletions
diff --git a/klm/lm/Jamfile b/klm/lm/Jamfile
deleted file mode 100644
index b1971d88..00000000
--- a/klm/lm/Jamfile
+++ /dev/null
@@ -1,14 +0,0 @@
-lib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quantize.cc read_arpa.cc search_hashed.cc search_trie.cc trie.cc trie_sort.cc value_build.cc virtual_interface.cc vocab.cc ../util//kenutil : <include>.. : : <include>.. <library>../util//kenutil ;
-
-import testing ;
-
-run left_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa ;
-run model_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa test_nounk.arpa ;
-
-exe query : ngram_query.cc kenlm ../util//kenutil ;
-exe build_binary : build_binary.cc kenlm ../util//kenutil ;
-
-install legacy : build_binary query
- : <location>$(TOP)/klm/lm <install-type>EXE <install-dependencies>on <link>shared:<dll-path>$(TOP)/klm/lm <link>shared:<install-type>LIB ;
-
-alias programs : build_binary query ;
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am
index a12c5f03..48b0ba34 100644
--- a/klm/lm/Makefile.am
+++ b/klm/lm/Makefile.am
@@ -1,7 +1,10 @@
-bin_PROGRAMS = build_binary
+bin_PROGRAMS = build_binary ngram_query
-build_binary_SOURCES = build_binary.cc
-build_binary_LDADD = libklm.a ../util/libklm_util.a -lz
+build_binary_SOURCES = build_binary_main.cc
+build_binary_LDADD = libklm.a ../util/libklm_util.a ../util/double-conversion/libklm_util_double.a -lz
+
+ngram_query_SOURCES = query_main.cc
+ngram_query_LDADD = libklm.a ../util/libklm_util.a ../util/double-conversion/libklm_util_double.a -lz
#noinst_PROGRAMS = \
# ngram_test
@@ -12,21 +15,49 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz
noinst_LIBRARIES = libklm.a
libklm_a_SOURCES = \
+ bhiksha.hh \
+ binary_format.hh \
+ blank.hh \
+ config.hh \
+ enumerate_vocab.hh \
+ facade.hh \
+ left.hh \
+ lm_exception.hh \
+ max_order.hh \
+ model.hh \
+ model_type.hh \
+ ngram_query.hh \
+ partial.hh \
+ quantize.hh \
+ read_arpa.hh \
+ return.hh \
+ search_hashed.hh \
+ search_trie.hh \
+ sizes.hh \
+ state.hh \
+ trie.hh \
+ trie_sort.hh \
+ value.hh \
+ value_build.hh \
+ virtual_interface.hh \
+ vocab.hh \
+ weights.hh \
+ word_index.hh \
bhiksha.cc \
binary_format.cc \
config.cc \
lm_exception.cc \
quantize.cc \
model.cc \
- ngram_query.cc \
read_arpa.cc \
search_hashed.cc \
search_trie.cc \
+ sizes.cc \
trie.cc \
trie_sort.cc \
value_build.cc \
virtual_interface.cc \
vocab.cc
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc
index 870a4eee..088ea98d 100644
--- a/klm/lm/bhiksha.cc
+++ b/klm/lm/bhiksha.cc
@@ -50,7 +50,7 @@ std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &con
}
} // namespace
-std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
+uint64_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */;
}
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
index 9734f3ab..8ff88654 100644
--- a/klm/lm/bhiksha.hh
+++ b/klm/lm/bhiksha.hh
@@ -33,7 +33,7 @@ class DontBhiksha {
static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
- static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
+ static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) {
return util::RequiredBits(max_next);
@@ -67,7 +67,7 @@ class ArrayBhiksha {
static void UpdateConfigFromBinary(int fd, Config &config);
- static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
+ static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config);
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index a56e998e..39c4a9b6 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -16,11 +16,11 @@ namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
-// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
+// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 5;
-// Old binary files built on 32-bit machines have this header.
+// Old binary files built on 32-bit machines have this header.
// TODO: eliminate with next binary release.
struct OldSanity {
char magic[sizeof(kMagicBytes)];
@@ -39,7 +39,7 @@ struct OldSanity {
};
-// Test values aligned to 8 bytes.
+// Test values aligned to 8 bytes.
struct Sanity {
char magic[ALIGN8(sizeof(kMagicBytes))];
float zero_f, one_f, minus_half_f;
@@ -83,7 +83,13 @@ void WriteHeader(void *to, const Parameters &params) {
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
- backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED);
+ backing.file.reset(util::CreateOrThrow(config.write_mmap));
+ if (config.write_method == Config::WRITE_MMAP) {
+ backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
+ } else {
+ util::ResizeOrThrow(backing.file.get(), 0);
+ util::MapAnonymous(total, backing.vocab);
+ }
strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
@@ -95,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
- // Grow the file to accomodate the search, using zeros.
+ // Grow the file to accomodate the search, using zeros.
try {
util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
} catch (util::ErrnoException &e) {
@@ -108,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
return reinterpret_cast<uint8_t*>(backing.search.get());
}
// mmap it now.
- // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
+ // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
std::size_t page_size = util::SizePage();
std::size_t alignment_cruft = adjusted_vocab % page_size;
backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
@@ -116,23 +122,25 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
} else {
util::MapAnonymous(memory_size, backing.search);
return reinterpret_cast<uint8_t*>(backing.search.get());
- }
+ }
}
void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
if (!config.write_mmap) return;
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
switch (config.write_method) {
case Config::WRITE_MMAP:
+ util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
util::SyncOrThrow(backing.search.get(), backing.search.size());
break;
case Config::WRITE_AFTER:
+ util::SeekOrThrow(backing.file.get(), 0);
+ util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
util::FSyncOrThrow(backing.file.get());
break;
}
- // header and vocab share the same mmap. The header is written here because we know the counts.
+ // header and vocab share the same mmap. The header is written here because we know the counts.
Parameters params = Parameters();
params.counts = counts;
params.fixed.order = counts.size();
@@ -141,6 +149,10 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_
params.fixed.has_vocabulary = config.include_vocab;
params.fixed.search_version = search_version;
WriteHeader(backing.vocab.get(), params);
+ if (config.write_method == Config::WRITE_AFTER) {
+ util::SeekOrThrow(backing.file.get(), 0);
+ util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
+ }
}
namespace detail {
@@ -148,7 +160,7 @@ namespace detail {
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
- // Try reading the header.
+ // Try reading the header.
util::scoped_memory memory;
try {
util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
@@ -200,10 +212,10 @@ void SeekPastHeader(int fd, const Parameters &params) {
util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
}
-uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
+uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
const uint64_t file_size = util::SizeFile(backing.file.get());
- // The header is smaller than a page, so we have to map the whole header as well.
- std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size;
+ // The header is smaller than a page, so we have to map the whole header as well.
+ std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
@@ -221,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) {
if (config.write_mmap || !config.messages) return;
if (config.arpa_complain == Config::ALL) {
*config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
- } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) {
+ } else if (config.arpa_complain == Config::EXPENSIVE &&
+ (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
*config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
}
}
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index dd795f62..bf699d5f 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -70,7 +70,7 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
void SeekPastHeader(int fd, const Parameters &params);
-uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing);
+uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing);
void ComplainAboutARPA(const Config &config, ModelType model_type);
@@ -90,7 +90,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
new_config.probing_multiplier = params.fixed.probing_multiplier;
detail::SeekPastHeader(backing.file.get(), params);
To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config);
- std::size_t memory_size = To::Size(params.counts, new_config);
+ uint64_t memory_size = To::Size(params.counts, new_config);
uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
to.InitializeFromBinary(start, params, new_config, backing.file.get());
} else {
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary_main.cc
index c2ca1101..ab2c0c32 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary_main.cc
@@ -1,25 +1,30 @@
#include "lm/model.hh"
+#include "lm/sizes.hh"
#include "util/file_piece.hh"
+#include "util/usage.hh"
+#include <algorithm>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <iomanip>
+#include <limits>
#include <math.h>
#include <stdlib.h>
-#include <unistd.h>
#ifdef WIN32
#include "util/getopt.hh"
+#else
+#include <unistd.h>
#endif
namespace lm {
namespace ngram {
namespace {
-void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
+void Usage(const char *name, const char *default_mem) {
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
@@ -37,8 +42,11 @@ void Usage(const char *name) {
"trie is a straightforward trie with bit-level packing. It uses the least\n"
"memory and is still faster than SRI or IRST. Building the trie format uses an\n"
"on-disk sort to save memory.\n"
-"-t is the temporary directory prefix. Default is the output file name.\n"
-"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n"
+"-T is the temporary directory prefix. Default is the output file name.\n"
+"-S determines memory use for sorting. Default is " << default_mem << ". This is compatible\n"
+" with GNU sort. The number is followed by a unit: \% for percent of physical\n"
+" memory, b for bytes, K for Kilobytes, M for megabytes, then G,T,P,E,Z,Y. \n"
+" Default unit is K for Kilobytes.\n"
"-q turns quantization on and sets the number of bits (e.g. -q 8).\n"
"-b sets backoff quantization bits. Requires -q and defaults to that value.\n"
"-a compresses pointers using an array of offsets. The parameter is the\n"
@@ -82,47 +90,6 @@ void ParseFileList(const char *from, std::vector<std::string> &to) {
}
}
-void ShowSizes(const char *file, const lm::ngram::Config &config) {
- std::vector<uint64_t> counts;
- util::FilePiece f(file);
- lm::ReadARPACounts(f, counts);
- std::size_t sizes[6];
- sizes[0] = ProbingModel::Size(counts, config);
- sizes[1] = RestProbingModel::Size(counts, config);
- sizes[2] = TrieModel::Size(counts, config);
- sizes[3] = QuantTrieModel::Size(counts, config);
- sizes[4] = ArrayTrieModel::Size(counts, config);
- sizes[5] = QuantArrayTrieModel::Size(counts, config);
- std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
- std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
- std::size_t divide;
- char prefix;
- if (min_length < (1 << 10) * 10) {
- prefix = ' ';
- divide = 1;
- } else if (min_length < (1 << 20) * 10) {
- prefix = 'k';
- divide = 1 << 10;
- } else if (min_length < (1ULL << 30) * 10) {
- prefix = 'M';
- divide = 1 << 20;
- } else {
- prefix = 'G';
- divide = 1 << 30;
- }
- long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide))));
- std::cout << "Memory estimate:\ntype ";
- // right align bytes.
- for (long int i = 0; i < length - 2; ++i) std::cout << ' ';
- std::cout << prefix << "B\n"
- "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
- "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n"
- "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
- "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
- "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
- "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
-}
-
void ProbingQuantizationUnsupported() {
std::cerr << "Quantization is only implemented in the trie data structure." << std::endl;
exit(1);
@@ -135,11 +102,14 @@ void ProbingQuantizationUnsupported() {
int main(int argc, char *argv[]) {
using namespace lm::ngram;
+ const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G";
+
try {
bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;
lm::ngram::Config config;
+ config.building_memory = util::ParseSize(default_mem);
int opt;
- while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:sir:")) != -1) {
+ while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -160,12 +130,16 @@ int main(int argc, char *argv[]) {
case 'p':
config.probing_multiplier = ParseFloat(optarg);
break;
- case 't':
+ case 't': // legacy
+ case 'T':
config.temporary_directory_prefix = optarg;
break;
- case 'm':
+ case 'm': // legacy
config.building_memory = ParseUInt(optarg) * 1048576;
break;
+ case 'S':
+ config.building_memory = std::min(static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), util::ParseSize(optarg));
+ break;
case 'w':
set_write_method = true;
if (!strcmp(optarg, "mmap")) {
@@ -173,7 +147,7 @@ int main(int argc, char *argv[]) {
} else if (!strcmp(optarg, "after")) {
config.write_method = Config::WRITE_AFTER;
} else {
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
break;
case 's':
@@ -188,7 +162,7 @@ int main(int argc, char *argv[]) {
config.rest_function = Config::REST_LOWER;
break;
default:
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
}
if (!quantize && set_backoff_bits) {
@@ -211,7 +185,7 @@ int main(int argc, char *argv[]) {
from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
} else {
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
if (!strcmp(model_type, "probing")) {
if (!set_write_method) config.write_method = Config::WRITE_AFTER;
@@ -241,7 +215,7 @@ int main(int argc, char *argv[]) {
}
}
} else {
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
}
catch (const std::exception &e) {
diff --git a/klm/lm/builder/Makefile.am b/klm/lm/builder/Makefile.am
new file mode 100644
index 00000000..317e03ce
--- /dev/null
+++ b/klm/lm/builder/Makefile.am
@@ -0,0 +1,28 @@
+bin_PROGRAMS = builder
+
+builder_SOURCES = \
+ lmplz_main.cc \
+ adjust_counts.cc \
+ adjust_counts.hh \
+ corpus_count.cc \
+ corpus_count.hh \
+ discount.hh \
+ header_info.hh \
+ initial_probabilities.cc \
+ initial_probabilities.hh \
+ interpolate.cc \
+ interpolate.hh \
+ joint_order.hh \
+ multi_stream.hh \
+ ngram.hh \
+ ngram_stream.hh \
+ pipeline.cc \
+ pipeline.hh \
+ print.cc \
+ print.hh \
+ sort.hh
+
+builder_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_THREAD_LIBS)
+
+AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm
+
diff --git a/klm/lm/builder/README.md b/klm/lm/builder/README.md
new file mode 100644
index 00000000..be0d35e2
--- /dev/null
+++ b/klm/lm/builder/README.md
@@ -0,0 +1,47 @@
+Dependencies
+============
+
+Boost >= 1.42.0 is required.
+
+For Ubuntu,
+```bash
+sudo apt-get install libboost1.48-all-dev
+```
+
+Alternatively, you can download, compile, and install it yourself:
+
+```bash
+wget http://sourceforge.net/projects/boost/files/boost/1.52.0/boost_1_52_0.tar.gz/download -O boost_1_52_0.tar.gz
+tar -xvzf boost_1_52_0.tar.gz
+cd boost_1_52_0
+./bootstrap.sh
+./b2
+sudo ./b2 install
+```
+
+Local install options (in a user-space prefix directory) are also possible. See http://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html.
+
+
+Building
+========
+
+```bash
+bjam
+```
+Your distribution might package bjam and boost-build separately from Boost. Both are required.
+
+Usage
+=====
+
+Run
+```bash
+$ bin/lmplz
+```
+to see command line arguments
+
+Running
+=======
+
+```bash
+bin/lmplz -o 5 <text >text.arpa
+```
diff --git a/klm/lm/builder/TODO b/klm/lm/builder/TODO
new file mode 100644
index 00000000..cb5aef3a
--- /dev/null
+++ b/klm/lm/builder/TODO
@@ -0,0 +1,5 @@
+More tests!
+Sharding.
+Some way to manage all the crazy config options.
+Option to build the binary file directly.
+Interpolation of different orders.
diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc
new file mode 100644
index 00000000..a6f48011
--- /dev/null
+++ b/klm/lm/builder/adjust_counts.cc
@@ -0,0 +1,216 @@
+#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/multi_stream.hh"
+#include "util/stream/timer.hh"
+
+#include <algorithm>
+
+namespace lm { namespace builder {
+
+BadDiscountException::BadDiscountException() throw() {}
+BadDiscountException::~BadDiscountException() throw() {}
+
+namespace {
+// Return last word in full that is different.
+const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
+ const WordIndex *cur_word = full.end() - 1;
+ const WordIndex *pre_word = lower_last.end() - 1;
+ // Find last difference.
+ for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
+ return cur_word;
+}
+
+class StatCollector {
+ public:
+ StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+ : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) {
+ memset(&orders_[0], 0, sizeof(OrderStat) * order);
+ }
+
+ ~StatCollector() {}
+
+ void CalculateDiscounts() {
+ counts_.resize(orders_.size());
+ discounts_.resize(orders_.size());
+ for (std::size_t i = 0; i < orders_.size(); ++i) {
+ const OrderStat &s = orders_[i];
+ counts_[i] = s.count;
+
+ for (unsigned j = 1; j < 4; ++j) {
+ // TODO: Specialize error message for j == 3, meaning 3+
+ UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
+ << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
+ << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
+ }
+
+ // See equation (26) in Chen and Goodman.
+ discounts_[i].amount[0] = 0.0;
+ float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
+ for (unsigned j = 1; j < 4; ++j) {
+ discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
+ UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+ }
+ }
+ }
+
+ void Add(std::size_t order_minus_1, uint64_t count) {
+ OrderStat &stat = orders_[order_minus_1];
+ ++stat.count;
+ if (count < 5) ++stat.n[count];
+ }
+
+ void AddFull(uint64_t count) {
+ ++full_.count;
+ if (count < 5) ++full_.n[count];
+ }
+
+ private:
+ struct OrderStat {
+ // n_1 in equation 26 of Chen and Goodman etc
+ uint64_t n[5];
+ uint64_t count;
+ };
+
+ std::vector<OrderStat> orders_;
+ OrderStat &full_;
+
+ std::vector<uint64_t> &counts_;
+ std::vector<Discount> &discounts_;
+};
+
+// Reads all entries in order like NGramStream does.
+// But deletes any entries that have <s> in the 1st (not 0th) position on the
+// way out by putting other entries in their place. This disrupts the sort
+// order but we don't care because the data is going to be sorted again.
+class CollapseStream {
+ public:
+ CollapseStream(const util::stream::ChainPosition &position) :
+ current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
+ block_(position) {
+ StartBlock();
+ }
+
+ const NGram &operator*() const { return current_; }
+ const NGram *operator->() const { return &current_; }
+
+ operator bool() const { return block_; }
+
+ CollapseStream &operator++() {
+ assert(block_);
+ if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
+ memcpy(current_.Base(), copy_from_, current_.TotalSize());
+ UpdateCopyFrom();
+ }
+ current_.NextInMemory();
+ uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
+ if (current_.Base() == block_base + block_->ValidSize()) {
+ block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base);
+ ++block_;
+ StartBlock();
+ }
+ return *this;
+ }
+
+ private:
+ void StartBlock() {
+ for (; ; ++block_) {
+ if (!block_) return;
+ if (block_->ValidSize()) break;
+ }
+ current_.ReBase(block_->Get());
+ copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
+ UpdateCopyFrom();
+ }
+
+ // Find last without bos.
+ void UpdateCopyFrom() {
+ for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
+ if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
+ }
+ }
+
+ NGram current_;
+
+ // Goes backwards in the block
+ uint8_t *copy_from_;
+
+ util::stream::Link block_;
+};
+
+} // namespace
+
+void AdjustCounts::Run(const ChainPositions &positions) {
+ UTIL_TIMER("(%w s) Adjusted counts\n");
+
+ const std::size_t order = positions.size();
+ StatCollector stats(order, counts_, discounts_);
+ if (order == 1) {
+ // Only unigrams. Just collect stats.
+ for (NGramStream full(positions[0]); full; ++full)
+ stats.AddFull(full->Count());
+ stats.CalculateDiscounts();
+ return;
+ }
+
+ NGramStreams streams;
+ streams.Init(positions, positions.size() - 1);
+ CollapseStream full(positions[positions.size() - 1]);
+
+ // Initialization: <unk> has count 0 and so does <s>.
+ NGramStream *lower_valid = streams.begin();
+ streams[0]->Count() = 0;
+ *streams[0]->begin() = kUNK;
+ stats.Add(0, 0);
+ (++streams[0])->Count() = 0;
+ *streams[0]->begin() = kBOS;
+ // not in stats because it will get put in later.
+
+ // iterate over full (the stream of the highest order ngrams)
+ for (; full; ++full) {
+ const WordIndex *different = FindDifference(*full, **lower_valid);
+ std::size_t same = full->end() - 1 - different;
+ // Increment the adjusted count.
+ if (same) ++streams[same - 1]->Count();
+
+ // Output all the valid ones that changed.
+ for (; lower_valid >= &streams[same]; --lower_valid) {
+ stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count());
+ ++*lower_valid;
+ }
+
+ // This is here because bos is also const WordIndex *, so copy gets
+ // consistent argument types.
+ const WordIndex *full_end = full->end();
+ // Initialize and mark as valid up to bos.
+ const WordIndex *bos;
+ for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
+ ++lower_valid;
+ std::copy(bos, full_end, (*lower_valid)->begin());
+ (*lower_valid)->Count() = 1;
+ }
+ // Now bos indicates where <s> is or is the 0th word of full.
+ if (bos != full->begin()) {
+ // There is an <s> beyond the 0th word.
+ NGramStream &to = *++lower_valid;
+ std::copy(bos, full_end, to->begin());
+ to->Count() = full->Count();
+ } else {
+ stats.AddFull(full->Count());
+ }
+ assert(lower_valid >= &streams[0]);
+ }
+
+ // Output everything valid.
+ for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
+ stats.Add(s - streams.begin(), (*s)->Count());
+ ++*s;
+ }
+ // Poison everyone! Except the N-grams which were already poisoned by the input.
+ for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
+ s->Poison();
+
+ stats.CalculateDiscounts();
+
+ // NOTE: See special early-return case for unigrams near the top of this function
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh
new file mode 100644
index 00000000..f38ff79d
--- /dev/null
+++ b/klm/lm/builder/adjust_counts.hh
@@ -0,0 +1,44 @@
+#ifndef LM_BUILDER_ADJUST_COUNTS__
+#define LM_BUILDER_ADJUST_COUNTS__
+
+#include "lm/builder/discount.hh"
+#include "util/exception.hh"
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+
+class ChainPositions;
+
+class BadDiscountException : public util::Exception {
+ public:
+ BadDiscountException() throw();
+ ~BadDiscountException() throw();
+};
+
+/* Compute adjusted counts.
+ * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.
+ * Output: [1,N]-grams with adjusted counts.
+ * [1,N)-grams are in suffix order
+ * N-grams are in undefined order (they're going to be sorted anyway).
+ */
+class AdjustCounts {
+ public:
+ AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+ : counts_(counts), discounts_(discounts) {}
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ std::vector<uint64_t> &counts_;
+ std::vector<Discount> &discounts_;
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_ADJUST_COUNTS__
+
diff --git a/klm/lm/builder/adjust_counts_test.cc b/klm/lm/builder/adjust_counts_test.cc
new file mode 100644
index 00000000..68b5f33e
--- /dev/null
+++ b/klm/lm/builder/adjust_counts_test.cc
@@ -0,0 +1,106 @@
+#include "lm/builder/adjust_counts.hh"
+
+#include "lm/builder/multi_stream.hh"
+#include "util/scoped.hh"
+
+#include <boost/thread/thread.hpp>
+#define BOOST_TEST_MODULE AdjustCounts
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+class KeepCopy {
+ public:
+ KeepCopy() : size_(0) {}
+
+ void Run(const util::stream::ChainPosition &position) {
+ for (util::stream::Link link(position); link; ++link) {
+ mem_.call_realloc(size_ + link->ValidSize());
+ memcpy(static_cast<uint8_t*>(mem_.get()) + size_, link->Get(), link->ValidSize());
+ size_ += link->ValidSize();
+ }
+ }
+
+ uint8_t *Get() { return static_cast<uint8_t*>(mem_.get()); }
+ std::size_t Size() const { return size_; }
+
+ private:
+ util::scoped_malloc mem_;
+ std::size_t size_;
+};
+
+struct Gram4 {
+ WordIndex ids[4];
+ uint64_t count;
+};
+
+class WriteInput {
+ public:
+ void Run(const util::stream::ChainPosition &position) {
+ NGramStream input(position);
+ Gram4 grams[] = {
+ {{0,0,0,0},10},
+ {{0,0,3,0},3},
+ // bos
+ {{1,1,1,2},5},
+ {{0,0,3,2},5},
+ };
+ for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) {
+ memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4);
+ input->Count() = grams[i].count;
+ }
+ input.Poison();
+ }
+};
+
+BOOST_AUTO_TEST_CASE(Simple) {
+ KeepCopy outputs[4];
+ std::vector<uint64_t> counts;
+ std::vector<Discount> discount;
+ {
+ util::stream::ChainConfig config;
+ config.total_memory = 100;
+ config.block_count = 1;
+ Chains chains(4);
+ for (unsigned i = 0; i < 4; ++i) {
+ config.entry_size = NGram::TotalSize(i + 1);
+ chains.push_back(config);
+ }
+
+ chains[3] >> WriteInput();
+ ChainPositions for_adjust(chains);
+ for (unsigned i = 0; i < 4; ++i) {
+ chains[i] >> boost::ref(outputs[i]);
+ }
+ chains >> util::stream::kRecycle;
+ BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException);
+ }
+ BOOST_REQUIRE_EQUAL(4UL, counts.size());
+ BOOST_CHECK_EQUAL(4UL, counts[0]);
+ // These are no longer set because the discounts are bad.
+/* BOOST_CHECK_EQUAL(4UL, counts[1]);
+ BOOST_CHECK_EQUAL(3UL, counts[2]);
+ BOOST_CHECK_EQUAL(3UL, counts[3]);*/
+ BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size());
+ NGram uni(outputs[0].Get(), 1);
+ BOOST_CHECK_EQUAL(kUNK, *uni.begin());
+ BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(kBOS, *uni.begin());
+ BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(0UL, *uni.begin());
+ BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ BOOST_CHECK_EQUAL(2UL, *uni.begin());
+
+ BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size());
+ NGram bi(outputs[1].Get(), 2);
+ BOOST_CHECK_EQUAL(0UL, *bi.begin());
+ BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1));
+ BOOST_CHECK_EQUAL(1ULL, bi.Count());
+ bi.NextInMemory();
+}
+
+}}} // namespaces
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
new file mode 100644
index 00000000..abea4ed0
--- /dev/null
+++ b/klm/lm/builder/corpus_count.cc
@@ -0,0 +1,224 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/lm_exception.hh"
+#include "lm/word_index.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/murmur_hash.hh"
+#include "util/probing_hash_table.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/timer.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_set.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <functional>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+namespace {
+
+class VocabHandout {
+ public:
+ explicit VocabHandout(int fd) {
+ util::scoped_fd duped(util::DupOrThrow(fd));
+ word_list_.reset(util::FDOpenOrThrow(duped));
+
+ Lookup("<unk>"); // Force 0
+ Lookup("<s>"); // Force 1
+ Lookup("</s>"); // Force 2
+ }
+
+ WordIndex Lookup(const StringPiece &word) {
+ uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
+ std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
+ if (ret.second) {
+ char null_delimit = 0;
+ util::WriteOrThrow(word_list_.get(), word.data(), word.size());
+ util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
+ UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ }
+ return ret.first->second;
+ }
+
+ WordIndex Size() const {
+ return seen_.size();
+ }
+
+ private:
+ typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+
+ Seen seen_;
+
+ util::scoped_FILE word_list_;
+};
+
+class DedupeHash : public std::unary_function<const WordIndex *, bool> {
+ public:
+ explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+ std::size_t operator()(const WordIndex *start) const {
+ return util::MurmurHashNative(start, size_);
+ }
+
+ private:
+ const std::size_t size_;
+};
+
+class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> {
+ public:
+ explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+ bool operator()(const WordIndex *first, const WordIndex *second) const {
+ return !memcmp(first, second, size_);
+ }
+
+ private:
+ const std::size_t size_;
+};
+
+struct DedupeEntry {
+ typedef WordIndex *Key;
+ Key GetKey() const { return key; }
+ Key key;
+ static DedupeEntry Construct(WordIndex *at) {
+ DedupeEntry ret;
+ ret.key = at;
+ return ret;
+ }
+};
+
+typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
+
+const float kProbingMultiplier = 1.5;
+
+class Writer {
+ public:
+ Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
+ : block_(position), gram_(block_->Get(), order),
+ dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
+ dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
+ buffer_(new WordIndex[order - 1]),
+ block_size_(position.GetChain().BlockSize()) {
+ dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
+ if (order == 1) {
+ // Add special words. AdjustCounts is responsible if order != 1.
+ AddUnigramWord(kUNK);
+ AddUnigramWord(kBOS);
+ }
+ }
+
+ ~Writer() {
+ block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
+ (++block_).Poison();
+ }
+
+ // Write context with a bunch of <s>
+ void StartSentence() {
+ for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
+ *i = kBOS;
+ }
+ }
+
+ void Append(WordIndex word) {
+ *(gram_.end() - 1) = word;
+ Dedupe::MutableIterator at;
+ bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
+ if (found) {
+ // Already present.
+ NGram already(at->key, gram_.Order());
+ ++(already.Count());
+ // Shift left by one.
+ memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
+ return;
+ }
+ // Complete the write.
+ gram_.Count() = 1;
+ // Prepare the next n-gram.
+ if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
+ NGram last(gram_);
+ gram_.NextInMemory();
+ std::copy(last.begin() + 1, last.end(), gram_.begin());
+ return;
+ }
+ // Block end. Need to store the context in a temporary buffer.
+ std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
+ dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ block_->SetValidSize(block_size_);
+ gram_.ReBase((++block_)->Get());
+ std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
+ }
+
+ private:
+ void AddUnigramWord(WordIndex index) {
+ *gram_.begin() = index;
+ gram_.Count() = 0;
+ gram_.NextInMemory();
+ if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
+ block_->SetValidSize(block_size_);
+ gram_.ReBase((++block_)->Get());
+ }
+ }
+
+ util::stream::Link block_;
+
+ NGram gram_;
+
+ // This is the memory behind the invalid value in dedupe_.
+ std::vector<WordIndex> dedupe_invalid_;
+ // Hash table combiner implementation.
+ Dedupe dedupe_;
+
+ // Small buffer to hold existing ngrams when shifting across a block boundary.
+ boost::scoped_array<WordIndex> buffer_;
+
+ const std::size_t block_size_;
+};
+
+} // namespace
+
+float CorpusCount::DedupeMultiplier(std::size_t order) {
+ return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
+}
+
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
+ : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
+ dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
+ dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
+ token_count_ = 0;
+ type_count_ = 0;
+}
+
+void CorpusCount::Run(const util::stream::ChainPosition &position) {
+ UTIL_TIMER("(%w s) Counted n-grams\n");
+
+ VocabHandout vocab(vocab_write_);
+ const WordIndex end_sentence = vocab.Lookup("</s>");
+ Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
+ uint64_t count = 0;
+ StringPiece delimiters("\0\t\r ", 4);
+ try {
+ while(true) {
+ StringPiece line(from_.ReadLine());
+ writer.StartSentence();
+ for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) {
+ WordIndex word = vocab.Lookup(*w);
+ UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
+ writer.Append(word);
+ ++count;
+ }
+ writer.Append(end_sentence);
+ }
+ } catch (const util::EndOfFileException &e) {}
+ token_count_ = count;
+ type_count_ = vocab.Size();
+}
+
+} // namespace builder
+} // namespace lm
diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh
new file mode 100644
index 00000000..e255bad1
--- /dev/null
+++ b/klm/lm/builder/corpus_count.hh
@@ -0,0 +1,42 @@
+#ifndef LM_BUILDER_CORPUS_COUNT__
+#define LM_BUILDER_CORPUS_COUNT__
+
+#include "lm/word_index.hh"
+#include "util/scoped.hh"
+
+#include <cstddef>
+#include <string>
+#include <stdint.h>
+
+namespace util {
+class FilePiece;
+namespace stream {
+class ChainPosition;
+} // namespace stream
+} // namespace util
+
+namespace lm {
+namespace builder {
+
+class CorpusCount {
+ public:
+ // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size
+ static float DedupeMultiplier(std::size_t order);
+
+ CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
+
+ void Run(const util::stream::ChainPosition &position);
+
+ private:
+ util::FilePiece &from_;
+ int vocab_write_;
+ uint64_t &token_count_;
+ WordIndex &type_count_;
+
+ std::size_t dedupe_mem_size_;
+ util::scoped_malloc dedupe_mem_;
+};
+
+} // namespace builder
+} // namespace lm
+#endif // LM_BUILDER_CORPUS_COUNT__
diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc
new file mode 100644
index 00000000..8d53ca9d
--- /dev/null
+++ b/klm/lm/builder/corpus_count_test.cc
@@ -0,0 +1,76 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/ngram_stream.hh"
+
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/tokenize_piece.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/stream.hh"
+
+#define BOOST_TEST_MODULE CorpusCountTest
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+#define Check(str, count) { \
+ BOOST_REQUIRE(stream); \
+ w = stream->begin(); \
+ for (util::TokenIter<util::AnyCharacter, true> t(str, " "); t; ++t, ++w) { \
+ BOOST_CHECK_EQUAL(*t, v[*w]); \
+ } \
+ BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \
+ ++stream; \
+}
+
+BOOST_AUTO_TEST_CASE(Short) {
+ util::scoped_fd input_file(util::MakeTemp("corpus_count_test_temp"));
+ const char input[] = "looking on a little more loin\non a little more loin\non foo little more loin\nbar\n\n";
+ // Blocks of 10 are
+ // looking on a little more loin </s> on a little[duplicate] more[duplicate] loin[duplicate] </s>[duplicate] on[duplicate] foo
+ // little more loin </s> bar </s> </s>
+
+ util::WriteOrThrow(input_file.get(), input, sizeof(input) - 1);
+ util::FilePiece input_piece(input_file.release(), "temp file");
+
+ util::stream::ChainConfig config;
+ config.entry_size = NGram::TotalSize(3);
+ config.total_memory = config.entry_size * 20;
+ config.block_count = 2;
+
+ util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab"));
+
+ util::stream::Chain chain(config);
+ NGramStream stream;
+ uint64_t token_count;
+ WordIndex type_count;
+ CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
+
+ const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"};
+
+ WordIndex *w;
+
+ Check("<s> <s> looking", 1);
+ Check("<s> looking on", 1);
+ Check("looking on a", 1);
+ Check("on a little", 2);
+ Check("a little more", 2);
+ Check("little more loin", 2);
+ Check("more loin </s>", 2);
+ Check("<s> <s> on", 2);
+ Check("<s> on a", 1);
+ Check("<s> on foo", 1);
+ Check("on foo little", 1);
+ Check("foo little more", 1);
+ Check("little more loin", 1);
+ Check("more loin </s>", 1);
+ Check("<s> <s> bar", 1);
+ Check("<s> bar </s>", 1);
+ Check("<s> <s> </s>", 1);
+ BOOST_CHECK(!stream);
+ BOOST_CHECK_EQUAL(sizeof(v) / sizeof(const char*), type_count);
+}
+
+}}} // namespaces
diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh
new file mode 100644
index 00000000..4d0aa4fd
--- /dev/null
+++ b/klm/lm/builder/discount.hh
@@ -0,0 +1,26 @@
+#ifndef BUILDER_DISCOUNT__
+#define BUILDER_DISCOUNT__
+
+#include <algorithm>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+
+struct Discount {
+ float amount[4];
+
+ float Get(uint64_t count) const {
+ return amount[std::min<uint64_t>(count, 3)];
+ }
+
+ float Apply(uint64_t count) const {
+ return static_cast<float>(count) - Get(count);
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // BUILDER_DISCOUNT__
diff --git a/klm/lm/builder/header_info.hh b/klm/lm/builder/header_info.hh
new file mode 100644
index 00000000..ccca1456
--- /dev/null
+++ b/klm/lm/builder/header_info.hh
@@ -0,0 +1,20 @@
+#ifndef LM_BUILDER_HEADER_INFO__
+#define LM_BUILDER_HEADER_INFO__
+
+#include <string>
+#include <stdint.h>
+
+// Some configuration info that is used to add
+// comments to the beginning of an ARPA file
+struct HeaderInfo {
+ const std::string input_file;
+ const uint64_t token_count;
+
+ HeaderInfo(const std::string& input_file_in, uint64_t token_count_in)
+ : input_file(input_file_in), token_count(token_count_in) {}
+
+ // TODO: Add smoothing type
+ // TODO: More info if multiple models were interpolated
+};
+
+#endif
diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc
new file mode 100644
index 00000000..58b42a20
--- /dev/null
+++ b/klm/lm/builder/initial_probabilities.cc
@@ -0,0 +1,136 @@
+#include "lm/builder/initial_probabilities.hh"
+
+#include "lm/builder/discount.hh"
+#include "lm/builder/ngram_stream.hh"
+#include "lm/builder/sort.hh"
+#include "util/file.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/io.hh"
+#include "util/stream/stream.hh"
+
+#include <vector>
+
+namespace lm { namespace builder {
+
+namespace {
+struct BufferEntry {
+ // Gamma from page 20 of Chen and Goodman.
+ float gamma;
+ // \sum_w a(c w) for all w.
+ float denominator;
+};
+
+// Extract an array of gamma from an array of BufferEntry.
+class OnlyGamma {
+ public:
+ void Run(const util::stream::ChainPosition &position) {
+ for (util::stream::Link block_it(position); block_it; ++block_it) {
+ float *out = static_cast<float*>(block_it->Get());
+ const float *in = out;
+ const float *end = static_cast<const float*>(block_it->ValidEnd());
+ for (out += 1, in += 2; in < end; out += 1, in += 2) {
+ *out = *in;
+ }
+ block_it->SetValidSize(block_it->ValidSize() / 2);
+ }
+ }
+};
+
+class AddRight {
+ public:
+ AddRight(const Discount &discount, const util::stream::ChainPosition &input)
+ : discount_(discount), input_(input) {}
+
+ void Run(const util::stream::ChainPosition &output) {
+ NGramStream in(input_);
+ util::stream::Stream out(output);
+
+ std::vector<WordIndex> previous(in->Order() - 1);
+ const std::size_t size = sizeof(WordIndex) * previous.size();
+ for(; in; ++out) {
+ memcpy(&previous[0], in->begin(), size);
+ uint64_t denominator = 0;
+ uint64_t counts[4];
+ memset(counts, 0, sizeof(counts));
+ do {
+ denominator += in->Count();
+ ++counts[std::min(in->Count(), static_cast<uint64_t>(3))];
+ } while (++in && !memcmp(&previous[0], in->begin(), size));
+ BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());
+ entry.denominator = static_cast<float>(denominator);
+ entry.gamma = 0.0;
+ for (unsigned i = 1; i <= 3; ++i) {
+ entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);
+ }
+ entry.gamma /= entry.denominator;
+ }
+ out.Poison();
+ }
+
+ private:
+ const Discount &discount_;
+ const util::stream::ChainPosition input_;
+};
+
+class MergeRight {
+ public:
+ MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount)
+ : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {}
+
+ // calculate the initial probability of each n-gram (before order-interpolation)
+ // Run() gets invoked once for each order
+ void Run(const util::stream::ChainPosition &primary) {
+ util::stream::Stream summed(from_adder_);
+
+ NGramStream grams(primary);
+
+ // Without interpolation, the interpolation weight goes to <unk>.
+ if (grams->Order() == 1 && !interpolate_unigrams_) {
+ BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get()));
+ assert(*grams->begin() == kUNK);
+ grams->Value().uninterp.prob = sums.gamma;
+ grams->Value().uninterp.gamma = 0.0;
+ while (++grams) {
+ grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator;
+ grams->Value().uninterp.gamma = 0.0;
+ }
+ ++summed;
+ return;
+ }
+
+ std::vector<WordIndex> previous(grams->Order() - 1);
+ const std::size_t size = sizeof(WordIndex) * previous.size();
+ for (; grams; ++summed) {
+ memcpy(&previous[0], grams->begin(), size);
+ const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
+ do {
+ Payload &pay = grams->Value();
+ pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator;
+ pay.uninterp.gamma = sums.gamma;
+ } while (++grams && !memcmp(&previous[0], grams->begin(), size));
+ }
+ }
+
+ private:
+ bool interpolate_unigrams_;
+ util::stream::ChainPosition from_adder_;
+ Discount discount_;
+};
+
+} // namespace
+
+void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) {
+ util::stream::ChainConfig gamma_config = config.adder_out;
+ gamma_config.entry_size = sizeof(BufferEntry);
+ for (size_t i = 0; i < primary.size(); ++i) {
+ util::stream::ChainPosition second(second_in[i].Add());
+ second_in[i] >> util::stream::kRecycle;
+ gamma_out.push_back(gamma_config);
+ gamma_out[i] >> AddRight(discounts[i], second);
+ primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
+ // Don't bother with the OnlyGamma thread for something to discard.
+ if (i) gamma_out[i] >> OnlyGamma();
+ }
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/initial_probabilities.hh b/klm/lm/builder/initial_probabilities.hh
new file mode 100644
index 00000000..626388eb
--- /dev/null
+++ b/klm/lm/builder/initial_probabilities.hh
@@ -0,0 +1,34 @@
+#ifndef LM_BUILDER_INITIAL_PROBABILITIES__
+#define LM_BUILDER_INITIAL_PROBABILITIES__
+
+#include "lm/builder/discount.hh"
+#include "util/stream/config.hh"
+
+#include <vector>
+
+namespace lm {
+namespace builder {
+class Chains;
+
+struct InitialProbabilitiesConfig {
+ // These should be small buffers to keep the adder from getting too far ahead
+ util::stream::ChainConfig adder_in;
+ util::stream::ChainConfig adder_out;
+ // SRILM doesn't normally interpolate unigrams.
+ bool interpolate_unigrams;
+};
+
+/* Compute initial (uninterpolated) probabilities
+ * primary: the normal chain of n-grams. Incoming is context sorted adjusted
+ * counts. Outgoing has uninterpolated probabilities for use by Interpolate.
+ * second_in: a second copy of the primary input. Discard the output.
+ * gamma_out: Computed gamma values are output on these chains in suffix order.
+ * The values are bare floats and should be buffered for interpolation to
+ * use.
+ */
+void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out);
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_INITIAL_PROBABILITIES__
diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc
new file mode 100644
index 00000000..50026806
--- /dev/null
+++ b/klm/lm/builder/interpolate.cc
@@ -0,0 +1,65 @@
+#include "lm/builder/interpolate.hh"
+
+#include "lm/builder/joint_order.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/sort.hh"
+#include "lm/lm_exception.hh"
+
+#include <assert.h>
+
+namespace lm { namespace builder {
+namespace {
+
+class Callback {
+ public:
+ Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) {
+ probs_[0] = uniform_prob;
+ for (std::size_t i = 0; i < backoffs.size(); ++i) {
+ backoffs_.push_back(backoffs[i]);
+ }
+ }
+
+ ~Callback() {
+ for (std::size_t i = 0; i < backoffs_.size(); ++i) {
+ if (backoffs_[i]) {
+ std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
+ abort();
+ }
+ }
+ }
+
+ void Enter(unsigned order_minus_1, NGram &gram) {
+ Payload &pay = gram.Value();
+ pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
+ probs_[order_minus_1 + 1] = pay.complete.prob;
+ pay.complete.prob = log10(pay.complete.prob);
+ // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
+ pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
+ ++backoffs_[order_minus_1];
+ } else {
+ // Not a context.
+ pay.complete.backoff = 0.0;
+ }
+ }
+
+ void Exit(unsigned, const NGram &) const {}
+
+ private:
+ FixedArray<util::stream::Stream> backoffs_;
+
+ std::vector<float> probs_;
+};
+} // namespace
+
+Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
+ : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {}
+
+// perform order-wise interpolation
+void Interpolate::Run(const ChainPositions &positions) {
+ assert(positions.size() == backoffs_.size() + 1);
+ Callback callback(uniform_prob_, backoffs_);
+ JointOrder<Callback, SuffixOrder>(positions, callback);
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/interpolate.hh b/klm/lm/builder/interpolate.hh
new file mode 100644
index 00000000..9268d404
--- /dev/null
+++ b/klm/lm/builder/interpolate.hh
@@ -0,0 +1,27 @@
+#ifndef LM_BUILDER_INTERPOLATE__
+#define LM_BUILDER_INTERPOLATE__
+
+#include <stdint.h>
+
+#include "lm/builder/multi_stream.hh"
+
+namespace lm { namespace builder {
+
+/* Interpolate step.
+ * Input: suffix sorted n-grams with (p_uninterpolated, gamma) from
+ * InitialProbabilities.
+ * Output: suffix sorted n-grams with complete probability
+ */
+class Interpolate {
+ public:
+ explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs);
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ float uniform_prob_;
+ ChainPositions backoffs_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_INTERPOLATE__
diff --git a/klm/lm/builder/joint_order.hh b/klm/lm/builder/joint_order.hh
new file mode 100644
index 00000000..b5620144
--- /dev/null
+++ b/klm/lm/builder/joint_order.hh
@@ -0,0 +1,43 @@
+#ifndef LM_BUILDER_JOINT_ORDER__
+#define LM_BUILDER_JOINT_ORDER__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/lm_exception.hh"
+
+#include <string.h>
+
+namespace lm { namespace builder {
+
+template <class Callback, class Compare> void JointOrder(const ChainPositions &positions, Callback &callback) {
+ // Allow matching to reference streams[-1].
+ NGramStreams streams_with_dummy;
+ streams_with_dummy.InitWithDummy(positions);
+ NGramStream *streams = streams_with_dummy.begin() + 1;
+
+ unsigned int order;
+ for (order = 0; order < positions.size() && streams[order]; ++order) {}
+ assert(order); // should always have <unk>.
+ unsigned int current = 0;
+ while (true) {
+ // Does the context match the lower one?
+ if (!memcmp(streams[static_cast<int>(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) {
+ callback.Enter(current, *streams[current]);
+ // Transition to looking for extensions.
+ if (++current < order) continue;
+ }
+ // No extension left.
+ while(true) {
+ assert(current > 0);
+ --current;
+ callback.Exit(current, *streams[current]);
+ if (++streams[current]) break;
+ UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix");
+ order = current;
+ if (!order) return;
+ }
+ }
+}
+
+}} // namespaces
+
+#endif // LM_BUILDER_JOINT_ORDER__
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
new file mode 100644
index 00000000..90b9dca2
--- /dev/null
+++ b/klm/lm/builder/lmplz_main.cc
@@ -0,0 +1,94 @@
+#include "lm/builder/pipeline.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/usage.hh"
+
+#include <iostream>
+
+#include <boost/program_options.hpp>
+
+namespace {
+class SizeNotify {
+ public:
+ SizeNotify(std::size_t &out) : behind_(out) {}
+
+ void operator()(const std::string &from) {
+ behind_ = util::ParseSize(from);
+ }
+
+ private:
+ std::size_t &behind_;
+};
+
+boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) {
+ return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
+}
+
+} // namespace
+
+int main(int argc, char *argv[]) {
+ try {
+ namespace po = boost::program_options;
+ po::options_description options("Language model building options");
+ lm::builder::PipelineConfig pipeline;
+
+ options.add_options()
+ ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model")
+ ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
+ ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
+ ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
+ ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
+ ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
+ ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+ ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
+ ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
+ ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
+ if (argc == 1) {
+ std::cerr <<
+ "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
+ "Please cite:\n"
+ "@inproceedings{kenlm,\n"
+ "author = {Kenneth Heafield},\n"
+ "title = {{KenLM}: Faster and Smaller Language Model Queries},\n"
+ "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
+ "month = {July}, year={2011},\n"
+ "address = {Edinburgh, UK},\n"
+ "publisher = {Association for Computational Linguistics},\n"
+ "}\n\n"
+ "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n"
+ "the model (-o) is the only mandatory option. As this is an on-disk program,\n"
+ "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
+ "Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
+ "Valid units are \% for percentage of memory (supported platforms only) and (in\n"
+ "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n";
+ std::cerr << options << std::endl;
+ return 1;
+ }
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, options), vm);
+ po::notify(vm);
+
+ util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
+
+ lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
+ // TODO: evaluate options for these.
+ initial.adder_in.total_memory = 32768;
+ initial.adder_in.block_count = 2;
+ initial.adder_out.total_memory = 32768;
+ initial.adder_out.block_count = 2;
+ pipeline.read_backoffs = initial.adder_out;
+
+ // Read from stdin
+ try {
+ lm::builder::Pipeline(pipeline, 0, 1);
+ } catch (const util::MallocException &e) {
+ std::cerr << e.what() << std::endl;
+ std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
+ return 1;
+ }
+ util::PrintUsage(std::cerr);
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
+ }
+}
diff --git a/klm/lm/builder/multi_stream.hh b/klm/lm/builder/multi_stream.hh
new file mode 100644
index 00000000..707a98c7
--- /dev/null
+++ b/klm/lm/builder/multi_stream.hh
@@ -0,0 +1,180 @@
+#ifndef LM_BUILDER_MULTI_STREAM__
+#define LM_BUILDER_MULTI_STREAM__
+
+#include "lm/builder/ngram_stream.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+
+#include <cstddef>
+#include <new>
+
+#include <assert.h>
+#include <stdlib.h>
+
+namespace lm { namespace builder {
+
+template <class T> class FixedArray {
+ public:
+ explicit FixedArray(std::size_t count) {
+ Init(count);
+ }
+
+ FixedArray() : newed_end_(NULL) {}
+
+ void Init(std::size_t count) {
+ assert(!block_.get());
+ block_.reset(malloc(sizeof(T) * count));
+ if (!block_.get()) throw std::bad_alloc();
+ newed_end_ = begin();
+ }
+
+ FixedArray(const FixedArray &from) {
+ std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get());
+ Init(size);
+ for (std::size_t i = 0; i < size; ++i) {
+ new(end()) T(from[i]);
+ Constructed();
+ }
+ }
+
+ ~FixedArray() { clear(); }
+
+ T *begin() { return static_cast<T*>(block_.get()); }
+ const T *begin() const { return static_cast<const T*>(block_.get()); }
+ // Always call Constructed after successful completion of new.
+ T *end() { return newed_end_; }
+ const T *end() const { return newed_end_; }
+
+ T &back() { return *(end() - 1); }
+ const T &back() const { return *(end() - 1); }
+
+ std::size_t size() const { return end() - begin(); }
+ bool empty() const { return begin() == end(); }
+
+ T &operator[](std::size_t i) { return begin()[i]; }
+ const T &operator[](std::size_t i) const { return begin()[i]; }
+
+ template <class C> void push_back(const C &c) {
+ new (end()) T(c);
+ Constructed();
+ }
+
+ void clear() {
+ for (T *i = begin(); i != end(); ++i)
+ i->~T();
+ newed_end_ = begin();
+ }
+
+ protected:
+ void Constructed() {
+ ++newed_end_;
+ }
+
+ private:
+ util::scoped_malloc block_;
+
+ T *newed_end_;
+};
+
+class Chains;
+
+class ChainPositions : public FixedArray<util::stream::ChainPosition> {
+ public:
+ ChainPositions() {}
+
+ void Init(Chains &chains);
+
+ explicit ChainPositions(Chains &chains) {
+ Init(chains);
+ }
+};
+
+class Chains : public FixedArray<util::stream::Chain> {
+ private:
+ template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun {
+ typedef Chains type;
+ };
+
+ public:
+ explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {}
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ Chains &operator>>(const util::stream::Recycler &recycler) {
+ for (util::stream::Chain *i = begin(); i != end(); ++i)
+ *i >> recycler;
+ return *this;
+ }
+
+ void Wait(bool release_memory = true) {
+ threads_.clear();
+ for (util::stream::Chain *i = begin(); i != end(); ++i) {
+ i->Wait(release_memory);
+ }
+ }
+
+ private:
+ boost::ptr_vector<util::stream::Thread> threads_;
+
+ Chains(const Chains &);
+ void operator=(const Chains &);
+};
+
+inline void ChainPositions::Init(Chains &chains) {
+ FixedArray<util::stream::ChainPosition>::Init(chains.size());
+ for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) {
+ new (end()) util::stream::ChainPosition(i->Add()); Constructed();
+ }
+}
+
+inline Chains &operator>>(Chains &chains, ChainPositions &positions) {
+ positions.Init(chains);
+ return chains;
+}
+
+class NGramStreams : public FixedArray<NGramStream> {
+ public:
+ NGramStreams() {}
+
+ // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning).
+ void InitWithDummy(const ChainPositions &positions) {
+ FixedArray<NGramStream>::Init(positions.size() + 1);
+ new (end()) NGramStream(); Constructed();
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) {
+ push_back(*i);
+ }
+ }
+
+ // Limit restricts to positions[0,limit)
+ void Init(const ChainPositions &positions, std::size_t limit) {
+ FixedArray<NGramStream>::Init(limit);
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) {
+ push_back(*i);
+ }
+ }
+ void Init(const ChainPositions &positions) {
+ Init(positions, positions.size());
+ }
+
+ NGramStreams(const ChainPositions &positions) {
+ Init(positions);
+ }
+};
+
+inline Chains &operator>>(Chains &chains, NGramStreams &streams) {
+ ChainPositions positions;
+ chains >> positions;
+ streams.Init(positions);
+ return chains;
+}
+
+}} // namespaces
+#endif // LM_BUILDER_MULTI_STREAM__
diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh
new file mode 100644
index 00000000..2984ed0b
--- /dev/null
+++ b/klm/lm/builder/ngram.hh
@@ -0,0 +1,84 @@
+#ifndef LM_BUILDER_NGRAM__
+#define LM_BUILDER_NGRAM__
+
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+
+#include <cstddef>
+
+#include <assert.h>
+#include <stdint.h>
+#include <string.h>
+
+namespace lm {
+namespace builder {
+
+struct Uninterpolated {
+ float prob; // Uninterpolated probability.
+ float gamma; // Interpolation weight for lower order.
+};
+
+union Payload {
+ uint64_t count;
+ Uninterpolated uninterp;
+ ProbBackoff complete;
+};
+
+class NGram {
+ public:
+ NGram(void *begin, std::size_t order)
+ : begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {}
+
+ const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); }
+ uint8_t *Base() { return reinterpret_cast<uint8_t*>(begin_); }
+
+ void ReBase(void *to) {
+ std::size_t difference = end_ - begin_;
+ begin_ = reinterpret_cast<WordIndex*>(to);
+ end_ = begin_ + difference;
+ }
+
+ // Would do operator++ but that can get confusing for a stream.
+ void NextInMemory() {
+ ReBase(&Value() + 1);
+ }
+
+ // Lower-case in deference to STL.
+ const WordIndex *begin() const { return begin_; }
+ WordIndex *begin() { return begin_; }
+ const WordIndex *end() const { return end_; }
+ WordIndex *end() { return end_; }
+
+ const Payload &Value() const { return *reinterpret_cast<const Payload *>(end_); }
+ Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
+
+ uint64_t &Count() { return Value().count; }
+ const uint64_t Count() const { return Value().count; }
+
+ std::size_t Order() const { return end_ - begin_; }
+
+ static std::size_t TotalSize(std::size_t order) {
+ return order * sizeof(WordIndex) + sizeof(Payload);
+ }
+ std::size_t TotalSize() const {
+ // Compiler should optimize this.
+ return TotalSize(Order());
+ }
+ static std::size_t OrderFromSize(std::size_t size) {
+ std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex);
+ assert(size == TotalSize(ret));
+ return ret;
+ }
+
+ private:
+ WordIndex *begin_, *end_;
+};
+
+const WordIndex kUNK = 0;
+const WordIndex kBOS = 1;
+const WordIndex kEOS = 2;
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_NGRAM__
diff --git a/klm/lm/builder/ngram_stream.hh b/klm/lm/builder/ngram_stream.hh
new file mode 100644
index 00000000..3c994664
--- /dev/null
+++ b/klm/lm/builder/ngram_stream.hh
@@ -0,0 +1,55 @@
+#ifndef LM_BUILDER_NGRAM_STREAM__
+#define LM_BUILDER_NGRAM_STREAM__
+
+#include "lm/builder/ngram.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/stream.hh"
+
+#include <cstddef>
+
+namespace lm { namespace builder {
+
+class NGramStream {
+ public:
+ NGramStream() : gram_(NULL, 0) {}
+
+ NGramStream(const util::stream::ChainPosition &position) : gram_(NULL, 0) {
+ Init(position);
+ }
+
+ void Init(const util::stream::ChainPosition &position) {
+ stream_.Init(position);
+ gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize()));
+ }
+
+ NGram &operator*() { return gram_; }
+ const NGram &operator*() const { return gram_; }
+
+ NGram *operator->() { return &gram_; }
+ const NGram *operator->() const { return &gram_; }
+
+ void *Get() { return stream_.Get(); }
+ const void *Get() const { return stream_.Get(); }
+
+ operator bool() const { return stream_; }
+ bool operator!() const { return !stream_; }
+ void Poison() { stream_.Poison(); }
+
+ NGramStream &operator++() {
+ ++stream_;
+ gram_.ReBase(stream_.Get());
+ return *this;
+ }
+
+ private:
+ NGram gram_;
+ util::stream::Stream stream_;
+};
+
+inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) {
+ str.Init(chain.Add());
+ return chain;
+}
+
+}} // namespaces
+#endif // LM_BUILDER_NGRAM_STREAM__
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc
new file mode 100644
index 00000000..14a1f721
--- /dev/null
+++ b/klm/lm/builder/pipeline.cc
@@ -0,0 +1,320 @@
+#include "lm/builder/pipeline.hh"
+
+#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/corpus_count.hh"
+#include "lm/builder/initial_probabilities.hh"
+#include "lm/builder/interpolate.hh"
+#include "lm/builder/print.hh"
+#include "lm/builder/sort.hh"
+
+#include "lm/sizes.hh"
+
+#include "util/exception.hh"
+#include "util/file.hh"
+#include "util/stream/io.hh"
+
+#include <algorithm>
+#include <iostream>
+#include <vector>
+
+namespace lm { namespace builder {
+
+namespace {
+void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) {
+ std::cerr << "Statistics:\n";
+ for (size_t i = 0; i < counts.size(); ++i) {
+ std::cerr << (i + 1) << ' ' << counts[i];
+ for (size_t d = 1; d <= 3; ++d)
+ std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d];
+ std::cerr << '\n';
+ }
+}
+
+class Master {
+ public:
+ explicit Master(const PipelineConfig &config)
+ : config_(config), chains_(config.order), files_(config.order) {
+ config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block);
+ }
+
+ const PipelineConfig &Config() const { return config_; }
+
+ Chains &MutableChains() { return chains_; }
+
+ template <class T> Master &operator>>(const T &worker) {
+ chains_ >> worker;
+ return *this;
+ }
+
+ // This takes the (partially) sorted ngrams and sets up for adjusted counts.
+ void InitForAdjust(util::stream::Sort<SuffixOrder, AddCombiner> &ngrams, WordIndex types) {
+ const std::size_t each_order_min = config_.minimum_block * config_.block_count;
+ // We know how many unigrams there are. Don't allocate more than needed to them.
+ const std::size_t min_chains = (config_.order - 1) * each_order_min +
+ std::min(types * NGram::TotalSize(1), each_order_min);
+ // Do merge sort with calculated laziness.
+ const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy()));
+
+ std::vector<uint64_t> count_bounds(1, types);
+ CreateChains(config_.TotalMemory() - merge_using, count_bounds);
+ ngrams.Output(chains_.back(), merge_using);
+
+ // Setup unigram file.
+ files_.push_back(util::MakeTemp(config_.TempPrefix()));
+ }
+
+ // For initial probabilities, but this is generic.
+ void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) {
+ // Do merge first before allocating chain memory.
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts[i - 1].Merge(0);
+ }
+ // There's no lazy merge, so just divide memory amongst the chains.
+ CreateChains(config_.TotalMemory(), counts);
+ chains_.back().ActivateProgress();
+ chains_[0] >> files_[0].Source();
+ second_config.entry_size = NGram::TotalSize(1);
+ second.push_back(second_config);
+ second.back() >> files_[0].Source();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ util::scoped_fd fd(sorts[i - 1].StealCompleted());
+ chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get()));
+ chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true);
+ second_config.entry_size = NGram::TotalSize(i + 1);
+ second.push_back(second_config);
+ second.back() >> util::stream::PRead(fd.release(), true);
+ }
+ }
+
+ // There is no sort after this, so go for broke on lazy merging.
+ template <class Compare> void MaximumLazyInput(const std::vector<uint64_t> &counts, Sorts<Compare> &sorts) {
+ // Determine the minimum we can use for all the chains.
+ std::size_t min_chains = 0;
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block));
+ }
+ std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains);
+ std::vector<std::size_t> laziness;
+ // Prioritize longer n-grams.
+ for (util::stream::Sort<SuffixOrder> *i = sorts.end() - 1; i >= sorts.begin(); --i) {
+ laziness.push_back(i->Merge(for_merge));
+ assert(for_merge >= laziness.back());
+ for_merge -= laziness.back();
+ }
+ std::reverse(laziness.begin(), laziness.end());
+
+ CreateChains(for_merge + min_chains, counts);
+ chains_.back().ActivateProgress();
+ chains_[0] >> files_[0].Source();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts[i - 1].Output(chains_[i], laziness[i - 1]);
+ }
+ }
+
+ void BufferFinal(const std::vector<uint64_t> &counts) {
+ chains_[0] >> files_[0].Sink();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ files_.push_back(util::MakeTemp(config_.TempPrefix()));
+ chains_[i] >> files_[i].Sink();
+ }
+ chains_.Wait(true);
+ // Use less memory. Because we can.
+ CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts);
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ chains_[i] >> files_[i].Source();
+ }
+ }
+
+ template <class Compare> void SetupSorts(Sorts<Compare> &sorts) {
+ sorts.Init(config_.order - 1);
+ // Unigrams don't get sorted because their order is always the same.
+ chains_[0] >> files_[0].Sink();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts.push_back(chains_[i], config_.sort, Compare(i + 1));
+ }
+ chains_.Wait(true);
+ }
+
+ private:
+ // Create chains, allocating memory to them. Totally heuristic. Count
+ // bounds are upper bounds on the counts or not present.
+ void CreateChains(std::size_t remaining_mem, const std::vector<uint64_t> &count_bounds) {
+ std::vector<std::size_t> assignments;
+ assignments.reserve(config_.order);
+ // Start by assigning maximum memory usage (to be refined later).
+ for (std::size_t i = 0; i < count_bounds.size(); ++i) {
+ assignments.push_back(static_cast<std::size_t>(std::min(
+ static_cast<uint64_t>(remaining_mem),
+ count_bounds[i] * static_cast<uint64_t>(NGram::TotalSize(i + 1)))));
+ }
+ assignments.resize(config_.order, remaining_mem);
+
+ // Now we know how much memory everybody wants. How much will they get?
+ // Proportional to this.
+ std::vector<float> portions;
+ // Indices of orders that have yet to be assigned.
+ std::vector<std::size_t> unassigned;
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ portions.push_back(static_cast<float>((i+1) * NGram::TotalSize(i+1)));
+ unassigned.push_back(i);
+ }
+ /*If somebody doesn't eat their full dinner, give it to the rest of the
+ * family. Then somebody else might not eat their full dinner etc. Ends
+ * when everybody unassigned is hungry.
+ */
+ float sum;
+ bool found_more;
+ std::vector<std::size_t> block_count(config_.order);
+ do {
+ sum = 0.0;
+ for (std::size_t i = 0; i < unassigned.size(); ++i) {
+ sum += portions[unassigned[i]];
+ }
+ found_more = false;
+ // If the proportional assignment is more than needed, give it just what it needs.
+ for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end();) {
+ if (assignments[*i] <= remaining_mem * (portions[*i] / sum)) {
+ remaining_mem -= assignments[*i];
+ block_count[*i] = 1;
+ i = unassigned.erase(i);
+ found_more = true;
+ } else {
+ ++i;
+ }
+ }
+ } while (found_more);
+ for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end(); ++i) {
+ assignments[*i] = remaining_mem * (portions[*i] / sum);
+ block_count[*i] = config_.block_count;
+ }
+ chains_.clear();
+ std::cerr << "Chain sizes:";
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ std::cerr << ' ' << (i+1) << ":" << assignments[i];
+ chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i]));
+ }
+ std::cerr << std::endl;
+ }
+
+ PipelineConfig config_;
+
+ Chains chains_;
+ // Often only unigrams, but sometimes all orders.
+ FixedArray<util::stream::FileBuffer> files_;
+};
+
+void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) {
+ const PipelineConfig &config = master.Config();
+ std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
+
+ UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory());
+ std::size_t memory_for_chain =
+ // This much memory to work with after vocab hash table.
+ static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) /
+ // Solve for block size including the dedupe multiplier for one block.
+ (static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
+ // Chain likes memory expressed in terms of total memory.
+ static_cast<float>(config.block_count);
+ util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
+
+ WordIndex type_count;
+ util::FilePiece text(text_file, NULL, &std::cerr);
+ text_file_name = text.FileName();
+ CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ chain >> boost::ref(counter);
+
+ util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
+ chain.Wait(true);
+ std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
+ master.InitForAdjust(sorter, type_count);
+}
+
+void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+ const PipelineConfig &config = master.Config();
+ Chains second(config.order);
+
+ {
+ Sorts<ContextOrder> sorts;
+ master.SetupSorts(sorts);
+ PrintStatistics(counts, discounts);
+ lm::ngram::ShowSizes(counts);
+ std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl;
+ master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in);
+ }
+
+ Chains gamma_chains(config.order);
+ InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains);
+ // Don't care about gamma for 0.
+ gamma_chains[0] >> util::stream::kRecycle;
+ gammas.Init(config.order - 1);
+ for (std::size_t i = 1; i < config.order; ++i) {
+ gammas.push_back(util::MakeTemp(config.TempPrefix()));
+ gamma_chains[i] >> gammas[i - 1].Sink();
+ }
+ // Has to be done here due to gamma_chains scope.
+ master.SetupSorts(primary);
+}
+
+void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+ std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
+ const PipelineConfig &config = master.Config();
+ master.MaximumLazyInput(counts, primary);
+
+ Chains gamma_chains(config.order - 1);
+ util::stream::ChainConfig read_backoffs(config.read_backoffs);
+ read_backoffs.entry_size = sizeof(float);
+ for (std::size_t i = 0; i < config.order - 1; ++i) {
+ gamma_chains.push_back(read_backoffs);
+ gamma_chains.back() >> gammas[i].Source();
+ }
+ master >> Interpolate(counts[0], ChainPositions(gamma_chains));
+ gamma_chains >> util::stream::kRecycle;
+ master.BufferFinal(counts);
+}
+
+} // namespace
+
+void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
+ // Some fail-fast sanity checks.
+ if (config.sort.buffer_size * 4 > config.TotalMemory()) {
+ config.sort.buffer_size = config.TotalMemory() / 4;
+ std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl;
+ }
+ if (config.minimum_block < NGram::TotalSize(config.order)) {
+ config.minimum_block = NGram::TotalSize(config.order);
+ std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl;
+ }
+ UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << ".");
+ UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception,
+ "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size.");
+
+ UTIL_TIMER("(%w s) Total wall time elapsed\n");
+ Master master(config);
+
+ util::scoped_fd vocab_file(config.vocab_file.empty() ?
+ util::MakeTemp(config.TempPrefix()) :
+ util::CreateOrThrow(config.vocab_file.c_str()));
+ uint64_t token_count;
+ std::string text_file_name;
+ CountText(text_file, vocab_file.get(), master, token_count, text_file_name);
+
+ std::vector<uint64_t> counts;
+ std::vector<Discount> discounts;
+ master >> AdjustCounts(counts, discounts);
+
+ {
+ FixedArray<util::stream::FileBuffer> gammas;
+ Sorts<SuffixOrder> primary;
+ InitialProbabilities(counts, discounts, master, primary, gammas);
+ InterpolateProbabilities(counts, master, primary, gammas);
+ }
+
+ std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
+ VocabReconstitute vocab(vocab_file.get());
+ UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?");
+ HeaderInfo header_info(text_file_name, token_count);
+ master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle;
+ master.MutableChains().Wait(true);
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh
new file mode 100644
index 00000000..f1d6c5f6
--- /dev/null
+++ b/klm/lm/builder/pipeline.hh
@@ -0,0 +1,40 @@
+#ifndef LM_BUILDER_PIPELINE__
+#define LM_BUILDER_PIPELINE__
+
+#include "lm/builder/initial_probabilities.hh"
+#include "lm/builder/header_info.hh"
+#include "util/stream/config.hh"
+#include "util/file_piece.hh"
+
+#include <string>
+#include <cstddef>
+
+namespace lm { namespace builder {
+
+struct PipelineConfig {
+ std::size_t order;
+ std::string vocab_file;
+ util::stream::SortConfig sort;
+ InitialProbabilitiesConfig initial_probs;
+ util::stream::ChainConfig read_backoffs;
+ bool verbose_header;
+
+ // Amount of memory to assume that the vocabulary hash table will use. This
+ // is subtracted from total memory for CorpusCount.
+ std::size_t assume_vocab_hash_size;
+
+ // Minimum block size to tolerate.
+ std::size_t minimum_block;
+
+ // Number of blocks to use. This will be overridden to 1 if everything fits.
+ std::size_t block_count;
+
+ const std::string &TempPrefix() const { return sort.temp_prefix; }
+ std::size_t TotalMemory() const { return sort.total_memory; }
+};
+
+// Takes ownership of text_file.
+void Pipeline(PipelineConfig config, int text_file, int out_arpa);
+
+}} // namespaces
+#endif // LM_BUILDER_PIPELINE__
diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc
new file mode 100644
index 00000000..b0323221
--- /dev/null
+++ b/klm/lm/builder/print.cc
@@ -0,0 +1,135 @@
+#include "lm/builder/print.hh"
+
+#include "util/double-conversion/double-conversion.h"
+#include "util/double-conversion/utils.h"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/scoped.hh"
+#include "util/stream/timer.hh"
+
+#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
+#include <boost/lexical_cast.hpp>
+
+#include <sstream>
+
+#include <string.h>
+
+namespace lm { namespace builder {
+
+VocabReconstitute::VocabReconstitute(int fd) {
+ uint64_t size = util::SizeOrThrow(fd);
+ util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_);
+ const char *const start = static_cast<const char*>(memory_.get());
+ const char *i;
+ for (i = start; i != start + size; i += strlen(i) + 1) {
+ map_.push_back(i);
+ }
+ // Last one for LookupPiece.
+ map_.push_back(i);
+}
+
+namespace {
+class OutputManager {
+ public:
+ static const std::size_t kOutBuf = 1048576;
+
+ // Does not take ownership of out.
+ explicit OutputManager(int out)
+ : buf_(util::MallocOrThrow(kOutBuf)),
+ builder_(static_cast<char*>(buf_.get()), kOutBuf),
+ // Mostly the default but with inf instead. And no flags.
+ convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
+ fd_(out) {}
+
+ ~OutputManager() {
+ Flush();
+ }
+
+ OutputManager &operator<<(float value) {
+ // Odd, but this is the largest number found in the comments.
+ EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
+ convert_.ToShortestSingle(value, &builder_);
+ return *this;
+ }
+
+ OutputManager &operator<<(StringPiece str) {
+ if (str.size() > kOutBuf) {
+ Flush();
+ util::WriteOrThrow(fd_, str.data(), str.size());
+ } else {
+ EnsureRemaining(str.size());
+ builder_.AddSubstring(str.data(), str.size());
+ }
+ return *this;
+ }
+
+ // Inefficient!
+ OutputManager &operator<<(unsigned val) {
+ return *this << boost::lexical_cast<std::string>(val);
+ }
+
+ OutputManager &operator<<(char c) {
+ EnsureRemaining(1);
+ builder_.AddCharacter(c);
+ return *this;
+ }
+
+ void Flush() {
+ util::WriteOrThrow(fd_, buf_.get(), builder_.position());
+ builder_.Reset();
+ }
+
+ private:
+ void EnsureRemaining(std::size_t amount) {
+ if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) {
+ Flush();
+ }
+ }
+
+ util::scoped_malloc buf_;
+ double_conversion::StringBuilder builder_;
+ double_conversion::DoubleToStringConverter convert_;
+ int fd_;
+};
+} // namespace
+
+PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
+ : vocab_(vocab), out_fd_(out_fd) {
+ std::stringstream stream;
+
+ if (header_info) {
+ stream << "# Input file: " << header_info->input_file << '\n';
+ stream << "# Token count: " << header_info->token_count << '\n';
+ stream << "# Smoothing: Modified Kneser-Ney" << '\n';
+ }
+ stream << "\\data\\\n";
+ for (size_t i = 0; i < counts.size(); ++i) {
+ stream << "ngram " << (i+1) << '=' << counts[i] << '\n';
+ }
+ stream << '\n';
+ std::string as_string(stream.str());
+ util::WriteOrThrow(out_fd, as_string.data(), as_string.size());
+}
+
+void PrintARPA::Run(const ChainPositions &positions) {
+ UTIL_TIMER("(%w s) Wrote ARPA file\n");
+ OutputManager out(out_fd_);
+ for (unsigned order = 1; order <= positions.size(); ++order) {
+ out << "\\" << order << "-grams:" << '\n';
+ for (NGramStream stream(positions[order - 1]); stream; ++stream) {
+ // Correcting for numerical precision issues. Take that IRST.
+ out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin());
+ for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
+ out << ' ' << vocab_.Lookup(*i);
+ }
+ float backoff = stream->Value().complete.backoff;
+ if (backoff != 0.0)
+ out << '\t' << backoff;
+ out << '\n';
+ }
+ out << '\n';
+ }
+ out << "\\end\\\n";
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh
new file mode 100644
index 00000000..aa932e75
--- /dev/null
+++ b/klm/lm/builder/print.hh
@@ -0,0 +1,102 @@
+#ifndef LM_BUILDER_PRINT__
+#define LM_BUILDER_PRINT__
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/header_info.hh"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/string_piece.hh"
+
+#include <ostream>
+
+#include <assert.h>
+
+// Warning: print routines read all unigrams before all bigrams before all
+// trigrams etc. So if other parts of the chain move jointly, you'll have to
+// buffer.
+
+namespace lm { namespace builder {
+
+class VocabReconstitute {
+ public:
+ // fd must be alive for life of this object; does not take ownership.
+ explicit VocabReconstitute(int fd);
+
+ const char *Lookup(WordIndex index) const {
+ assert(index < map_.size() - 1);
+ return map_[index];
+ }
+
+ StringPiece LookupPiece(WordIndex index) const {
+ return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]);
+ }
+
+ std::size_t Size() const {
+ // There's an extra entry to support StringPiece lengths.
+ return map_.size() - 1;
+ }
+
+ private:
+ util::scoped_memory memory_;
+ std::vector<const char*> map_;
+};
+
+// Not defined, only specialized.
+template <class T> void PrintPayload(std::ostream &to, const Payload &payload);
+template <> inline void PrintPayload<uint64_t>(std::ostream &to, const Payload &payload) {
+ to << payload.count;
+}
+template <> inline void PrintPayload<Uninterpolated>(std::ostream &to, const Payload &payload) {
+ to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma);
+}
+template <> inline void PrintPayload<ProbBackoff>(std::ostream &to, const Payload &payload) {
+ to << payload.complete.prob << ' ' << payload.complete.backoff;
+}
+
+// template parameter is the type stored.
+template <class V> class Print {
+ public:
+ explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {}
+
+ void Run(const ChainPositions &chains) {
+ NGramStreams streams(chains);
+ for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
+ DumpStream(*s);
+ }
+ }
+
+ void Run(const util::stream::ChainPosition &position) {
+ NGramStream stream(position);
+ DumpStream(stream);
+ }
+
+ private:
+ void DumpStream(NGramStream &stream) {
+ for (; stream; ++stream) {
+ PrintPayload<V>(to_, stream->Value());
+ for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) {
+ to_ << ' ' << vocab_.Lookup(*w) << '=' << *w;
+ }
+ to_ << '\n';
+ }
+ }
+
+ const VocabReconstitute &vocab_;
+ std::ostream &to_;
+};
+
+class PrintARPA {
+ public:
+ // header_info may be NULL to disable the header
+ explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ const VocabReconstitute &vocab_;
+ int out_fd_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_PRINT__
diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh
new file mode 100644
index 00000000..9989389b
--- /dev/null
+++ b/klm/lm/builder/sort.hh
@@ -0,0 +1,103 @@
+#ifndef LM_BUILDER_SORT__
+#define LM_BUILDER_SORT__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/ngram.hh"
+#include "lm/word_index.hh"
+#include "util/stream/sort.hh"
+
+#include "util/stream/timer.hh"
+
+#include <functional>
+#include <string>
+
+namespace lm {
+namespace builder {
+
+template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> {
+ public:
+ explicit Comparator(std::size_t order) : order_(order) {}
+
+ inline bool operator()(const void *lhs, const void *rhs) const {
+ return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs));
+ }
+
+ std::size_t Order() const { return order_; }
+
+ protected:
+ std::size_t order_;
+};
+
+class SuffixOrder : public Comparator<SuffixOrder> {
+ public:
+ explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = order_ - 1; i != 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[0] < rhs[0];
+ }
+
+ static const unsigned kMatchOffset = 1;
+};
+
+class ContextOrder : public Comparator<ContextOrder> {
+ public:
+ explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (int i = order_ - 2; i >= 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[order_ - 1] < rhs[order_ - 1];
+ }
+};
+
+class PrefixOrder : public Comparator<PrefixOrder> {
+ public:
+ explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = 0; i < order_; ++i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return false;
+ }
+
+ static const unsigned kMatchOffset = 0;
+};
+
+// Sum counts for the same n-gram.
+struct AddCombiner {
+ bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
+ NGram first(first_void, compare.Order());
+ // There isn't a const version of NGram.
+ NGram second(const_cast<void*>(second_void), compare.Order());
+ if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
+ first.Count() += second.Count();
+ return true;
+ }
+};
+
+// The combiner is only used on a single chain, so I didn't bother to allow
+// that template.
+template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > {
+ private:
+ typedef util::stream::Sort<Compare> S;
+ typedef FixedArray<S> P;
+
+ public:
+ void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
+ new (P::end()) S(chain, config, compare);
+ P::Constructed();
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_SORT__
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index f9d988ca..9520c41c 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -6,6 +6,7 @@ namespace lm {
namespace ngram {
Config::Config() :
+ show_progress(true),
messages(&std::cerr),
enumerate_vocab(NULL),
unknown_missing(COMPLAIN),
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index 739cee9c..0de7b7c6 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -11,46 +11,52 @@
/* Configuration for ngram model. Separate header to reduce pollution. */
namespace lm {
-
+
class EnumerateVocab;
namespace ngram {
struct Config {
- // EFFECTIVE FOR BOTH ARPA AND BINARY READS
+ // EFFECTIVE FOR BOTH ARPA AND BINARY READS
+
+ // (default true) print progress bar to messages
+ bool show_progress;
// Where to log messages including the progress bar. Set to NULL for
// silence.
std::ostream *messages;
+ std::ostream *ProgressMessages() const {
+ return show_progress ? messages : 0;
+ }
+
// This will be called with every string in the vocabulary. See
// enumerate_vocab.hh for more detail. Config does not take ownership; you
- // are still responsible for deleting it (or stack allocating).
+ // are still responsible for deleting it (or stack allocating).
EnumerateVocab *enumerate_vocab;
-
// ONLY EFFECTIVE WHEN READING ARPA
- // What to do when <unk> isn't in the provided model.
+ // What to do when <unk> isn't in the provided model.
WarningAction unknown_missing;
- // What to do when <s> or </s> is missing from the model.
- // If THROW_UP, the exception will be of type util::SpecialWordMissingException.
+ // What to do when <s> or </s> is missing from the model.
+ // If THROW_UP, the exception will be of type util::SpecialWordMissingException.
WarningAction sentence_marker_missing;
// What to do with a positive log probability. For COMPLAIN and SILENT, map
- // to 0.
+ // to 0.
WarningAction positive_log_probability;
- // The probability to substitute for <unk> if it's missing from the model.
+ // The probability to substitute for <unk> if it's missing from the model.
// No effect if the model has <unk> or unknown_missing == THROW_UP.
float unknown_missing_logprob;
// Size multiplier for probing hash table. Must be > 1. Space is linear in
// this. Time is probing_multiplier / (probing_multiplier - 1). No effect
- // for sorted variant.
+ // for sorted variant.
// If you find yourself setting this to a low number, consider using the
- // TrieModel which has lower memory consumption.
+ // TrieModel which has lower memory consumption.
float probing_multiplier;
// Amount of memory to use for building. The actual memory usage will be
@@ -58,10 +64,10 @@ struct Config {
// models.
std::size_t building_memory;
- // Template for temporary directory appropriate for passing to mkdtemp.
+ // Template for temporary directory appropriate for passing to mkdtemp.
// The characters XXXXXX are appended before passing to mkdtemp. Only
// applies to trie. If NULL, defaults to write_mmap. If that's NULL,
- // defaults to input file name.
+ // defaults to input file name.
const char *temporary_directory_prefix;
// Level of complaining to do when loading from ARPA instead of binary format.
@@ -69,49 +75,46 @@ struct Config {
ARPALoadComplain arpa_complain;
// While loading an ARPA file, also write out this binary format file. Set
- // to NULL to disable.
+ // to NULL to disable.
const char *write_mmap;
enum WriteMethod {
- WRITE_MMAP, // Map the file directly.
- WRITE_AFTER // Write after we're done.
+ WRITE_MMAP, // Map the file directly.
+ WRITE_AFTER // Write after we're done.
};
WriteMethod write_method;
- // Include the vocab in the binary file? Only effective if write_mmap != NULL.
+ // Include the vocab in the binary file? Only effective if write_mmap != NULL.
bool include_vocab;
- // Left rest options. Only used when the model includes rest costs.
+ // Left rest options. Only used when the model includes rest costs.
enum RestFunction {
REST_MAX, // Maximum of any score to the left
- REST_LOWER, // Use lower-order files given below.
+ REST_LOWER, // Use lower-order files given below.
};
RestFunction rest_function;
- // Only used for REST_LOWER.
+ // Only used for REST_LOWER.
std::vector<std::string> rest_lower_files;
-
// Quantization options. Only effective for QuantTrieModel. One value is
// reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
- // to quantize (and one of the remaining backoffs will be 0).
+ // to quantize (and one of the remaining backoffs will be 0).
uint8_t prob_bits, backoff_bits;
// Bhiksha compression (simple form). Only works with trie.
uint8_t pointer_bhiksha_bits;
-
-
+
// ONLY EFFECTIVE WHEN READING BINARY
-
+
// How to get the giant array into memory: lazy mmap, populate, read etc.
- // See util/mmap.hh for details of MapMethod.
+ // See util/mmap.hh for details of MapMethod.
util::LoadMethod load_method;
-
- // Set defaults.
+ // Set defaults.
Config();
};
diff --git a/klm/lm/filter/arpa_io.cc b/klm/lm/filter/arpa_io.cc
new file mode 100644
index 00000000..f8568ac4
--- /dev/null
+++ b/klm/lm/filter/arpa_io.cc
@@ -0,0 +1,108 @@
+#include "lm/filter/arpa_io.hh"
+#include "util/file_piece.hh"
+
+#include <iostream>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include <ctype.h>
+#include <errno.h>
+#include <string.h>
+
+namespace lm {
+
+ARPAInputException::ARPAInputException(const StringPiece &message) throw() {
+ *this << message;
+}
+
+ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() {
+ *this << message << " in line " << line;
+}
+
+ARPAInputException::~ARPAInputException() throw() {}
+
+ARPAOutputException::ARPAOutputException(const char *message, const std::string &file_name) throw() {
+ *this << message << " in file " << file_name;
+}
+
+ARPAOutputException::~ARPAOutputException() throw() {}
+
+// Seeking is the responsibility of the caller.
+void WriteCounts(std::ostream &out, const std::vector<uint64_t> &number) {
+ out << "\n\\data\\\n";
+ for (unsigned int i = 0; i < number.size(); ++i) {
+ out << "ngram " << i+1 << "=" << number[i] << '\n';
+ }
+ out << '\n';
+}
+
+size_t SizeNeededForCounts(const std::vector<uint64_t> &number) {
+ std::ostringstream buf;
+ WriteCounts(buf, number);
+ return buf.tellp();
+}
+
+bool IsEntirelyWhiteSpace(const StringPiece &line) {
+ for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
+ if (!isspace(line.data()[i])) return false;
+ }
+ return true;
+}
+
+ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) : file_name_(name), buffer_(new char[buffer_size]) {
+ try {
+ file_.exceptions(std::ostream::eofbit | std::ostream::failbit | std::ostream::badbit);
+ if (!file_.rdbuf()->pubsetbuf(buffer_.get(), buffer_size)) {
+ std::cerr << "Warning: could not enlarge buffer for " << name << std::endl;
+ buffer_.reset();
+ }
+ file_.open(name, std::ios::out | std::ios::binary);
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Opening", file_name_);
+ }
+}
+
+void ARPAOutput::ReserveForCounts(std::streampos reserve) {
+ try {
+ for (std::streampos i = 0; i < reserve; i += std::streampos(1)) {
+ file_ << '\n';
+ }
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing blanks to reserve space for counts to ", file_name_);
+ }
+}
+
+void ARPAOutput::BeginLength(unsigned int length) {
+ fast_counter_ = 0;
+ try {
+ file_ << '\\' << length << "-grams:" << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing n-gram header to ", file_name_);
+ }
+}
+
+void ARPAOutput::EndLength(unsigned int length) {
+ try {
+ file_ << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing blank at end of count list to ", file_name_);
+ }
+ if (length > counts_.size()) {
+ counts_.resize(length);
+ }
+ counts_[length - 1] = fast_counter_;
+}
+
+void ARPAOutput::Finish() {
+ try {
+ file_ << "\\end\\\n";
+ file_.seekp(0);
+ WriteCounts(file_, counts_);
+ file_ << std::flush;
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Finishing including writing counts at beginning to ", file_name_);
+ }
+}
+
+} // namespace lm
diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh
new file mode 100644
index 00000000..5b31620b
--- /dev/null
+++ b/klm/lm/filter/arpa_io.hh
@@ -0,0 +1,115 @@
+#ifndef LM_FILTER_ARPA_IO__
+#define LM_FILTER_ARPA_IO__
+/* Input and output for ARPA format language model files.
+ */
+#include "lm/read_arpa.hh"
+#include "util/exception.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/scoped_array.hpp>
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include <err.h>
+#include <string.h>
+#include <stdint.h>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+
+class ARPAInputException : public util::Exception {
+ public:
+ explicit ARPAInputException(const StringPiece &message) throw();
+ explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw();
+ virtual ~ARPAInputException() throw();
+};
+
+class ARPAOutputException : public util::ErrnoException {
+ public:
+ ARPAOutputException(const char *prefix, const std::string &file_name) throw();
+ virtual ~ARPAOutputException() throw();
+
+ const std::string &File() const throw() { return file_name_; }
+
+ private:
+ const std::string file_name_;
+};
+
+// Handling for the counts of n-grams at the beginning of ARPA files.
+size_t SizeNeededForCounts(const std::vector<uint64_t> &number);
+
+/* Writes an ARPA file. This has to be seekable so the counts can be written
+ * at the end. Hence, I just have it own a std::fstream instead of accepting
+ * a separately held std::ostream. TODO: use the fast one from estimation.
+ */
+class ARPAOutput : boost::noncopyable {
+ public:
+ explicit ARPAOutput(const char *name, size_t buffer_size = 65536);
+
+ void ReserveForCounts(std::streampos reserve);
+
+ void BeginLength(unsigned int length);
+
+ void AddNGram(const StringPiece &line) {
+ try {
+ file_ << line << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing an n-gram", file_name_);
+ }
+ ++fast_counter_;
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ void EndLength(unsigned int length);
+
+ void Finish();
+
+ private:
+ const std::string file_name_;
+ boost::scoped_array<char> buffer_;
+ std::fstream file_;
+ size_t fast_counter_;
+ std::vector<uint64_t> counts_;
+};
+
+
+template <class Output> void ReadNGrams(util::FilePiece &in, unsigned int length, uint64_t number, Output &out) {
+ ReadNGramHeader(in, length);
+ out.BeginLength(length);
+ for (uint64_t i = 0; i < number; ++i) {
+ StringPiece line = in.ReadLine();
+ util::TokenIter<util::SingleCharacter> tabber(line, '\t');
+ if (!tabber) throw ARPAInputException("blank line", line);
+ if (!++tabber) throw ARPAInputException("no tab", line);
+
+ out.AddNGram(*tabber, line);
+ }
+ out.EndLength(length);
+}
+
+template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) {
+ std::vector<uint64_t> number;
+ ReadARPACounts(in_lm, number);
+ out.ReserveForCounts(SizeNeededForCounts(number));
+ for (unsigned int i = 0; i < number.size(); ++i) {
+ ReadNGrams(in_lm, i + 1, number[i], out);
+ }
+ ReadEnd(in_lm);
+ out.Finish();
+}
+
+} // namespace lm
+
+#endif // LM_FILTER_ARPA_IO__
diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh
new file mode 100644
index 00000000..97c0fa25
--- /dev/null
+++ b/klm/lm/filter/count_io.hh
@@ -0,0 +1,91 @@
+#ifndef LM_FILTER_COUNT_IO__
+#define LM_FILTER_COUNT_IO__
+
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include <err.h>
+
+#include "util/file_piece.hh"
+
+namespace lm {
+
+class CountOutput : boost::noncopyable {
+ public:
+ explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
+
+ void AddNGram(const StringPiece &line) {
+ if (!(file_ << line << '\n')) {
+ err(3, "Writing counts file failed");
+ }
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ private:
+ std::fstream file_;
+};
+
+class CountBatch {
+ public:
+ explicit CountBatch(std::streamsize initial_read)
+ : initial_read_(initial_read) {
+ buffer_.reserve(initial_read);
+ }
+
+ void Read(std::istream &in) {
+ buffer_.resize(initial_read_);
+ in.read(&*buffer_.begin(), initial_read_);
+ buffer_.resize(in.gcount());
+ char got;
+ while (in.get(got) && got != '\n')
+ buffer_.push_back(got);
+ }
+
+ template <class Output> void Send(Output &out) {
+ for (util::TokenIter<util::SingleCharacter> line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) {
+ util::TokenIter<util::SingleCharacter> tabber(*line, '\t');
+ if (!tabber) {
+ std::cerr << "Warning: empty n-gram count line being removed\n";
+ continue;
+ }
+ util::TokenIter<util::SingleCharacter, true> words(*tabber, ' ');
+ if (!words) {
+ std::cerr << "Line has a tab but no words.\n";
+ continue;
+ }
+ out.AddNGram(words, util::TokenIter<util::SingleCharacter, true>::end(), *line);
+ }
+ }
+
+ private:
+ std::streamsize initial_read_;
+
+ // This could have been a std::string but that's less happy with raw writes.
+ std::vector<char> buffer_;
+};
+
+template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) {
+ try {
+ while (true) {
+ StringPiece line = in_file.ReadLine();
+ util::TokenIter<util::SingleCharacter> tabber(line, '\t');
+ if (!tabber) {
+ std::cerr << "Warning: empty n-gram count line being removed\n";
+ continue;
+ }
+ out.AddNGram(*tabber, line);
+ }
+ } catch (const util::EndOfFileException &e) {}
+}
+
+} // namespace lm
+
+#endif // LM_FILTER_COUNT_IO__
diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc
new file mode 100644
index 00000000..1a4ba84f
--- /dev/null
+++ b/klm/lm/filter/filter_main.cc
@@ -0,0 +1,248 @@
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/format.hh"
+#include "lm/filter/phrase.hh"
+#ifndef NTHREAD
+#include "lm/filter/thread.hh"
+#endif
+#include "lm/filter/vocab.hh"
+#include "lm/filter/wrapper.hh"
+#include "util/file_piece.hh"
+
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+
+namespace lm {
+namespace {
+
+void DisplayHelp(const char *name) {
+ std::cerr
+ << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n"
+ "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n"
+ " parser.\n"
+ "single mode treats the entire input as a single sentence.\n"
+ "multiple mode filters to multiple sentences in parallel. Each sentence is on\n"
+ " a separate line. A separate file is created for each file by appending the\n"
+ " 0-indexed line number to the output file name.\n"
+ "union mode produces one filtered model that is the union of models created by\n"
+ " multiple mode.\n\n"
+ "context means only the context (all but last word) has to pass the filter, but\n"
+ " the entire n-gram is output.\n\n"
+ "phrase means that the vocabulary is actually tab-delimited phrases and that the\n"
+ " phrases can generate the n-gram when assembled in arbitrary order and\n"
+ " clipped. Currently works with multiple or union mode.\n\n"
+ "The file format is set by [raw|arpa] with default arpa:\n"
+ "raw means space-separated tokens, optionally followed by a tab and arbitrary\n"
+ " text. This is useful for ngram count files.\n"
+ "arpa means the ARPA file format for n-gram language models.\n\n"
+#ifndef NTHREAD
+ "threads:m sets m threads (default: conccurrency detected by boost)\n"
+ "batch_size:m sets the batch size for threading. Expect memory usage from this\n"
+ " of 2*threads*batch_size n-grams.\n\n"
+#else
+ "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n"
+ " threading, compile without this flag against Boost >=1.42.0.\n\n"
+#endif
+ "There are two inputs: vocabulary and model. Either may be given as a file\n"
+ " while the other is on stdin. Specify the type given as a file using\n"
+ " vocab: or model: before the file name. \n\n"
+ "For ARPA format, the output must be seekable. For raw format, it can be a\n"
+ " stream i.e. /dev/stdout\n";
+}
+
+typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION, MODE_UNSET} FilterMode;
+typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format;
+
+struct Config {
+ Config() :
+#ifndef NTHREAD
+ batch_size(25000),
+ threads(boost::thread::hardware_concurrency()),
+#endif
+ phrase(false),
+ context(false),
+ format(FORMAT_ARPA)
+ {
+#ifndef NTHREAD
+ if (!threads) threads = 1;
+#endif
+ }
+
+#ifndef NTHREAD
+ size_t batch_size;
+ size_t threads;
+#endif
+ bool phrase;
+ bool context;
+ FilterMode mode;
+ Format format;
+};
+
+template <class Format, class Filter, class OutputBuffer, class Output> void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) {
+#ifndef NTHREAD
+ if (config.threads == 1) {
+#endif
+ Format::RunFilter(in_lm, filter, output);
+#ifndef NTHREAD
+ } else {
+ typedef Controller<Filter, OutputBuffer, Output> Threaded;
+ Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output);
+ Format::RunFilter(in_lm, threading, output);
+ }
+#endif
+}
+
+template <class Format, class Filter, class OutputBuffer, class Output> void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) {
+ if (config.context) {
+ ContextFilter<Filter> context_filter(filter);
+ RunThreadedFilter<Format, ContextFilter<Filter>, OutputBuffer, Output>(config, in_lm, context_filter, output);
+ } else {
+ RunThreadedFilter<Format, Filter, OutputBuffer, Output>(config, in_lm, filter, output);
+ }
+}
+
+template <class Format, class Binary> void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) {
+ typedef BinaryFilter<Binary> Filter;
+ RunContextFilter<Format, Filter, BinaryOutputBuffer, typename Format::Output>(config, in_lm, Filter(binary), out);
+}
+
+template <class Format> void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) {
+ if (config.mode == MODE_MULTIPLE) {
+ if (config.phrase) {
+ typedef phrase::Multiple Filter;
+ phrase::Substrings substrings;
+ typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings));
+ RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(substrings), out);
+ } else {
+ typedef vocab::Multiple Filter;
+ boost::unordered_map<std::string, std::vector<unsigned int> > words;
+ typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words));
+ RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(words), out);
+ }
+ return;
+ }
+
+ typename Format::Output out(out_name);
+
+ if (config.mode == MODE_COPY) {
+ Format::Copy(in_lm, out);
+ return;
+ }
+
+ if (config.mode == MODE_SINGLE) {
+ vocab::Single::Words words;
+ vocab::ReadSingle(in_vocab, words);
+ DispatchBinaryFilter<Format, vocab::Single>(config, in_lm, vocab::Single(words), out);
+ return;
+ }
+
+ if (config.mode == MODE_UNION) {
+ if (config.phrase) {
+ phrase::Substrings substrings;
+ phrase::ReadMultiple(in_vocab, substrings);
+ DispatchBinaryFilter<Format, phrase::Union>(config, in_lm, phrase::Union(substrings), out);
+ } else {
+ vocab::Union::Words words;
+ vocab::ReadMultiple(in_vocab, words);
+ DispatchBinaryFilter<Format, vocab::Union>(config, in_lm, vocab::Union(words), out);
+ }
+ return;
+ }
+}
+
+} // namespace
+} // namespace lm
+
+int main(int argc, char *argv[]) {
+ if (argc < 4) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+
+ // I used to have boost::program_options, but some users didn't want to compile boost.
+ lm::Config config;
+ config.mode = lm::MODE_UNSET;
+ for (int i = 1; i < argc - 2; ++i) {
+ const char *str = argv[i];
+ if (!std::strcmp(str, "copy")) {
+ config.mode = lm::MODE_COPY;
+ } else if (!std::strcmp(str, "single")) {
+ config.mode = lm::MODE_SINGLE;
+ } else if (!std::strcmp(str, "multiple")) {
+ config.mode = lm::MODE_MULTIPLE;
+ } else if (!std::strcmp(str, "union")) {
+ config.mode = lm::MODE_UNION;
+ } else if (!std::strcmp(str, "phrase")) {
+ config.phrase = true;
+ } else if (!std::strcmp(str, "context")) {
+ config.context = true;
+ } else if (!std::strcmp(str, "arpa")) {
+ config.format = lm::FORMAT_ARPA;
+ } else if (!std::strcmp(str, "raw")) {
+ config.format = lm::FORMAT_COUNT;
+#ifndef NTHREAD
+ } else if (!std::strncmp(str, "threads:", 8)) {
+ config.threads = boost::lexical_cast<size_t>(str + 8);
+ if (!config.threads) {
+ std::cerr << "Specify at least one thread." << std::endl;
+ return 1;
+ }
+ } else if (!std::strncmp(str, "batch_size:", 11)) {
+ config.batch_size = boost::lexical_cast<size_t>(str + 11);
+ if (config.batch_size < 5000) {
+ std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
+ if (!config.batch_size) return 1;
+ }
+#endif
+ } else {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+ }
+
+ if (config.mode == lm::MODE_UNSET) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+
+ if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
+ std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
+ return 1;
+ }
+
+ bool cmd_is_model = true;
+ const char *cmd_input = argv[argc - 2];
+ if (!strncmp(cmd_input, "vocab:", 6)) {
+ cmd_is_model = false;
+ cmd_input += 6;
+ } else if (!strncmp(cmd_input, "model:", 6)) {
+ cmd_input += 6;
+ } else if (strchr(cmd_input, ':')) {
+ errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input);
+ } else {
+ std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
+ }
+ std::ifstream cmd_file;
+ std::istream *vocab;
+ if (cmd_is_model) {
+ vocab = &std::cin;
+ } else {
+ cmd_file.open(cmd_input, std::ios::in);
+ if (!cmd_file) {
+ err(2, "Could not open input file %s", cmd_input);
+ }
+ vocab = &cmd_file;
+ }
+
+ util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
+
+ if (config.format == lm::FORMAT_ARPA) {
+ lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
+ } else if (config.format == lm::FORMAT_COUNT) {
+ lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ }
+ return 0;
+}
diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh
new file mode 100644
index 00000000..7f945b0d
--- /dev/null
+++ b/klm/lm/filter/format.hh
@@ -0,0 +1,250 @@
+#ifndef LM_FILTER_FORMAT_H__
+#define LM_FITLER_FORMAT_H__
+
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/count_io.hh"
+
+#include <boost/lexical_cast.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <iosfwd>
+
+namespace lm {
+
+template <class Single> class MultipleOutput {
+ private:
+ typedef boost::ptr_vector<Single> Singles;
+ typedef typename Singles::iterator SinglesIterator;
+
+ public:
+ MultipleOutput(const char *prefix, size_t number) {
+ files_.reserve(number);
+ std::string tmp;
+ for (unsigned int i = 0; i < number; ++i) {
+ tmp = prefix;
+ tmp += boost::lexical_cast<std::string>(i);
+ files_.push_back(new Single(tmp.c_str()));
+ }
+ }
+
+ void AddNGram(const StringPiece &line) {
+ for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+ i->AddNGram(line);
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+ i->AddNGram(begin, end, line);
+ }
+
+ void SingleAddNGram(size_t offset, const StringPiece &line) {
+ files_[offset].AddNGram(line);
+ }
+
+ template <class Iterator> void SingleAddNGram(size_t offset, const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ files_[offset].AddNGram(begin, end, line);
+ }
+
+ protected:
+ Singles files_;
+};
+
+class MultipleARPAOutput : public MultipleOutput<ARPAOutput> {
+ public:
+ MultipleARPAOutput(const char *prefix, size_t number) : MultipleOutput<ARPAOutput>(prefix, number) {}
+
+ void ReserveForCounts(std::streampos reserve) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->ReserveForCounts(reserve);
+ }
+
+ void BeginLength(unsigned int length) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->BeginLength(length);
+ }
+
+ void EndLength(unsigned int length) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->EndLength(length);
+ }
+
+ void Finish() {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->Finish();
+ }
+};
+
+template <class Filter, class Output> class DispatchInput {
+ public:
+ DispatchInput(Filter &filter, Output &output) : filter_(filter), output_(output) {}
+
+/* template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ filter_.AddNGram(begin, end, line, output_);
+ }*/
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ filter_.AddNGram(ngram, line, output_);
+ }
+
+ protected:
+ Filter &filter_;
+ Output &output_;
+};
+
+template <class Filter, class Output> class DispatchARPAInput : public DispatchInput<Filter, Output> {
+ private:
+ typedef DispatchInput<Filter, Output> B;
+
+ public:
+ DispatchARPAInput(Filter &filter, Output &output) : B(filter, output) {}
+
+ void ReserveForCounts(std::streampos reserve) { B::output_.ReserveForCounts(reserve); }
+ void BeginLength(unsigned int length) { B::output_.BeginLength(length); }
+
+ void EndLength(unsigned int length) {
+ B::filter_.Flush();
+ B::output_.EndLength(length);
+ }
+ void Finish() { B::output_.Finish(); }
+};
+
+struct ARPAFormat {
+ typedef ARPAOutput Output;
+ typedef MultipleARPAOutput Multiple;
+ static void Copy(util::FilePiece &in, Output &out) {
+ ReadARPA(in, out);
+ }
+ template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
+ DispatchARPAInput<Filter, Out> dispatcher(filter, output);
+ ReadARPA(in, dispatcher);
+ }
+};
+
+struct CountFormat {
+ typedef CountOutput Output;
+ typedef MultipleOutput<Output> Multiple;
+ static void Copy(util::FilePiece &in, Output &out) {
+ ReadCount(in, out);
+ }
+ template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
+ DispatchInput<Filter, Out> dispatcher(filter, output);
+ ReadCount(in, dispatcher);
+ }
+};
+
+/* For multithreading, the buffer classes hold batches of filter inputs and
+ * outputs in memory. The strings get reused a lot, so keep them around
+ * instead of clearing each time.
+ */
+class InputBuffer {
+ public:
+ InputBuffer() : actual_(0) {}
+
+ void Reserve(size_t size) { lines_.reserve(size); }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ if (lines_.size() == actual_) lines_.resize(lines_.size() + 1);
+ // TODO avoid this copy.
+ std::string &copied = lines_[actual_].line;
+ copied.assign(line.data(), line.size());
+ lines_[actual_].ngram.set(copied.data() + (ngram.data() - line.data()), ngram.size());
+ ++actual_;
+ }
+
+ template <class Filter, class Output> void CallFilter(Filter &filter, Output &output) const {
+ for (std::vector<Line>::const_iterator i = lines_.begin(); i != lines_.begin() + actual_; ++i) {
+ filter.AddNGram(i->ngram, i->line, output);
+ }
+ }
+
+ void Clear() { actual_ = 0; }
+ bool Empty() { return actual_ == 0; }
+ size_t Size() { return actual_; }
+
+ private:
+ struct Line {
+ std::string line;
+ StringPiece ngram;
+ };
+
+ size_t actual_;
+
+ std::vector<Line> lines_;
+};
+
+class BinaryOutputBuffer {
+ public:
+ BinaryOutputBuffer() {}
+
+ void Reserve(size_t size) {
+ lines_.reserve(size);
+ }
+
+ void AddNGram(const StringPiece &line) {
+ lines_.push_back(line);
+ }
+
+ template <class Output> void Flush(Output &output) {
+ for (std::vector<StringPiece>::const_iterator i = lines_.begin(); i != lines_.end(); ++i) {
+ output.AddNGram(*i);
+ }
+ lines_.clear();
+ }
+
+ private:
+ std::vector<StringPiece> lines_;
+};
+
+class MultipleOutputBuffer {
+ public:
+ MultipleOutputBuffer() : last_(NULL) {}
+
+ void Reserve(size_t size) {
+ annotated_.reserve(size);
+ }
+
+ void AddNGram(const StringPiece &line) {
+ annotated_.resize(annotated_.size() + 1);
+ annotated_.back().line = line;
+ }
+
+ void SingleAddNGram(size_t offset, const StringPiece &line) {
+ if ((line.data() == last_.data()) && (line.length() == last_.length())) {
+ annotated_.back().systems.push_back(offset);
+ } else {
+ annotated_.resize(annotated_.size() + 1);
+ annotated_.back().systems.push_back(offset);
+ annotated_.back().line = line;
+ last_ = line;
+ }
+ }
+
+ template <class Output> void Flush(Output &output) {
+ for (std::vector<Annotated>::const_iterator i = annotated_.begin(); i != annotated_.end(); ++i) {
+ if (i->systems.empty()) {
+ output.AddNGram(i->line);
+ } else {
+ for (std::vector<size_t>::const_iterator j = i->systems.begin(); j != i->systems.end(); ++j) {
+ output.SingleAddNGram(*j, i->line);
+ }
+ }
+ }
+ annotated_.clear();
+ }
+
+ private:
+ struct Annotated {
+ // If this is empty, send to all systems.
+ // A filter should never send to all systems and send to a single one.
+ std::vector<size_t> systems;
+ StringPiece line;
+ };
+
+ StringPiece last_;
+
+ std::vector<Annotated> annotated_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_FORMAT_H__
diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc
new file mode 100644
index 00000000..1bef2a3f
--- /dev/null
+++ b/klm/lm/filter/phrase.cc
@@ -0,0 +1,281 @@
+#include "lm/filter/phrase.hh"
+
+#include "lm/filter/format.hh"
+
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include <ctype.h>
+
+namespace lm {
+namespace phrase {
+
+unsigned int ReadMultiple(std::istream &in, Substrings &out) {
+ bool sentence_content = false;
+ unsigned int sentence_id = 0;
+ std::vector<Hash> phrase;
+ std::string word;
+ while (in) {
+ char c;
+ // Gather a word.
+ while (!isspace(c = in.get()) && in) word += c;
+ // Treat EOF like a newline.
+ if (!in) c = '\n';
+ // Add the word to the phrase.
+ if (!word.empty()) {
+ phrase.push_back(util::MurmurHashNative(word.data(), word.size()));
+ word.clear();
+ }
+ if (c == ' ') continue;
+ // It's more than just a space. Close out the phrase.
+ if (!phrase.empty()) {
+ sentence_content = true;
+ out.AddPhrase(sentence_id, phrase.begin(), phrase.end());
+ phrase.clear();
+ }
+ if (c == '\t' || c == '\v') continue;
+ // It's more than a space or tab: a newline.
+ if (sentence_content) {
+ ++sentence_id;
+ sentence_content = false;
+ }
+ }
+ if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit);
+ return sentence_id + sentence_content;
+}
+
+namespace detail { const StringPiece kEndSentence("</s>"); }
+
+namespace {
+
+typedef unsigned int Sentence;
+typedef std::vector<Sentence> Sentences;
+
+class Vertex;
+
+class Arc {
+ public:
+ Arc() {}
+
+ // For arcs from one vertex to another.
+ void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) {
+ Set(to, intersect);
+ from_ = &from;
+ }
+
+ /* For arcs from before the n-gram begins to somewhere in the n-gram (right
+ * aligned). These have no from_ vertex; it implictly matches every
+ * sentence. This also handles when the n-gram is a substring of a phrase.
+ */
+ void SetRight(Vertex &to, const Sentences &complete) {
+ Set(to, complete);
+ from_ = NULL;
+ }
+
+ Sentence Current() const {
+ return *current_;
+ }
+
+ bool Empty() const {
+ return current_ == last_;
+ }
+
+ /* When this function returns:
+ * If Empty() then there's nothing left from this intersection.
+ *
+ * If Current() == to then to is part of the intersection.
+ *
+ * Otherwise, Current() > to. In this case, to is not part of the
+ * intersection and neither is anything < Current(). To determine if
+ * any value >= Current() is in the intersection, call LowerBound again
+ * with the value.
+ */
+ void LowerBound(const Sentence to);
+
+ private:
+ void Set(Vertex &to, const Sentences &sentences);
+
+ const Sentence *current_;
+ const Sentence *last_;
+ Vertex *from_;
+};
+
+struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> {
+ bool operator()(const Arc *first, const Arc *second) const {
+ return first->Current() > second->Current();
+ }
+};
+
+class Vertex {
+ public:
+ Vertex() : current_(0) {}
+
+ Sentence Current() const {
+ return current_;
+ }
+
+ bool Empty() const {
+ return incoming_.empty();
+ }
+
+ void LowerBound(const Sentence to);
+
+ private:
+ friend class Arc;
+
+ void AddIncoming(Arc *arc) {
+ if (!arc->Empty()) incoming_.push(arc);
+ }
+
+ unsigned int current_;
+ std::priority_queue<Arc*, std::vector<Arc*>, ArcGreater> incoming_;
+};
+
+void Arc::LowerBound(const Sentence to) {
+ current_ = std::lower_bound(current_, last_, to);
+ // If *current_ > to, don't advance from_. The intervening values of
+ // from_ may be useful for another one of its outgoing arcs.
+ if (!from_ || Empty() || (Current() > to)) return;
+ assert(Current() == to);
+ from_->LowerBound(to);
+ if (from_->Empty()) {
+ current_ = last_;
+ return;
+ }
+ assert(from_->Current() >= to);
+ if (from_->Current() > to) {
+ current_ = std::lower_bound(current_ + 1, last_, from_->Current());
+ }
+}
+
+void Arc::Set(Vertex &to, const Sentences &sentences) {
+ current_ = &*sentences.begin();
+ last_ = &*sentences.end();
+ to.AddIncoming(this);
+}
+
+void Vertex::LowerBound(const Sentence to) {
+ if (Empty()) return;
+ // Union lower bound.
+ while (true) {
+ Arc *top = incoming_.top();
+ if (top->Current() > to) {
+ current_ = top->Current();
+ return;
+ }
+ // If top->Current() == to, we still need to verify that's an actual
+ // element and not just a bound.
+ incoming_.pop();
+ top->LowerBound(to);
+ if (!top->Empty()) {
+ incoming_.push(top);
+ if (top->Current() == to) {
+ current_ = to;
+ return;
+ }
+ } else if (Empty()) {
+ return;
+ }
+ }
+}
+
+void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) {
+ assert(!hashes.empty());
+
+ const Hash *const first_word = &*hashes.begin();
+ const Hash *const last_word = &*hashes.end() - 1;
+
+ Hash hash = 0;
+ const Sentences *found;
+ // Phrases starting at or before the first word in the n-gram.
+ {
+ Vertex *vertex = vertices;
+ for (const Hash *word = first_word; ; ++word, ++vertex) {
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word);
+ // Now hash is [hashes.begin(), word].
+ if (word == last_word) {
+ if (phrase.FindSubstring(hash, found))
+ (free_arc++)->SetRight(*vertex, *found);
+ break;
+ }
+ if (!phrase.FindRight(hash, found)) break;
+ (free_arc++)->SetRight(*vertex, *found);
+ }
+ }
+
+ // Phrases starting at the second or later word in the n-gram.
+ Vertex *vertex_from = vertices;
+ for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) {
+ hash = 0;
+ Vertex *vertex_to = vertex_from + 1;
+ for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) {
+ // Notice that word_to and vertex_to have the same index.
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to);
+ // Now hash covers [word_from, word_to].
+ if (word_to == last_word) {
+ if (phrase.FindLeft(hash, found))
+ (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
+ break;
+ }
+ if (!phrase.FindPhrase(hash, found)) break;
+ (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
+ }
+ }
+}
+
+} // namespace
+
+namespace detail {
+
+} // namespace detail
+
+bool Union::Evaluate() {
+ assert(!hashes_.empty());
+ // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
+ Vertex vertices[hashes_.size()];
+ // One for every substring.
+ Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
+ BuildGraph(substrings_, hashes_, vertices, arcs);
+ Vertex &last_vertex = vertices[hashes_.size() - 1];
+
+ unsigned int lower = 0;
+ while (true) {
+ last_vertex.LowerBound(lower);
+ if (last_vertex.Empty()) return false;
+ if (last_vertex.Current() == lower) return true;
+ lower = last_vertex.Current();
+ }
+}
+
+template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) {
+ assert(!hashes_.empty());
+ // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
+ Vertex vertices[hashes_.size()];
+ // One for every substring.
+ Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
+ BuildGraph(substrings_, hashes_, vertices, arcs);
+ Vertex &last_vertex = vertices[hashes_.size() - 1];
+
+ unsigned int lower = 0;
+ while (true) {
+ last_vertex.LowerBound(lower);
+ if (last_vertex.Empty()) return;
+ if (last_vertex.Current() == lower) {
+ output.SingleAddNGram(lower, line);
+ ++lower;
+ } else {
+ lower = last_vertex.Current();
+ }
+ }
+}
+
+template void Multiple::Evaluate<CountFormat::Multiple>(const StringPiece &line, CountFormat::Multiple &output);
+template void Multiple::Evaluate<ARPAFormat::Multiple>(const StringPiece &line, ARPAFormat::Multiple &output);
+template void Multiple::Evaluate<MultipleOutputBuffer>(const StringPiece &line, MultipleOutputBuffer &output);
+
+} // namespace phrase
+} // namespace lm
diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh
new file mode 100644
index 00000000..b4edff41
--- /dev/null
+++ b/klm/lm/filter/phrase.hh
@@ -0,0 +1,154 @@
+#ifndef LM_FILTER_PHRASE_H__
+#define LM_FILTER_PHRASE_H__
+
+#include "util/murmur_hash.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_map.hpp>
+
+#include <iosfwd>
+#include <vector>
+
+#define LM_FILTER_PHRASE_METHOD(caps, lower) \
+bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\
+ Table::const_iterator i(table_.find(key));\
+ if (i==table_.end()) return false; \
+ out = &i->second.lower; \
+ return true; \
+}
+
+namespace lm {
+namespace phrase {
+
+typedef uint64_t Hash;
+
+class Substrings {
+ private:
+ /* This is the value in a hash table where the key is a string. It indicates
+ * four sets of sentences:
+ * substring is sentences with a phrase containing the key as a substring.
+ * left is sentencess with a phrase that begins with the key (left aligned).
+ * right is sentences with a phrase that ends with the key (right aligned).
+ * phrase is sentences where the key is a phrase.
+ * Each set is encoded as a vector of sentence ids in increasing order.
+ */
+ struct SentenceRelation {
+ std::vector<unsigned int> substring, left, right, phrase;
+ };
+ /* Most of the CPU is hash table lookups, so let's not complicate it with
+ * vector equality comparisons. If a collision happens, the SentenceRelation
+ * structure will contain the union of sentence ids over the colliding strings.
+ * In that case, the filter will be slightly more permissive.
+ * The key here is the same as boost's hash of std::vector<std::string>.
+ */
+ typedef boost::unordered_map<Hash, SentenceRelation> Table;
+
+ public:
+ Substrings() {}
+
+ /* If the string isn't a substring of any phrase, return NULL. Otherwise,
+ * return a pointer to std::vector<unsigned int> listing sentences with
+ * matching phrases. This set may be empty for Left, Right, or Phrase.
+ * Example: const std::vector<unsigned int> *FindSubstring(Hash key)
+ */
+ LM_FILTER_PHRASE_METHOD(Substring, substring)
+ LM_FILTER_PHRASE_METHOD(Left, left)
+ LM_FILTER_PHRASE_METHOD(Right, right)
+ LM_FILTER_PHRASE_METHOD(Phrase, phrase)
+
+#pragma GCC diagnostic ignored "-Wuninitialized" // end != finish so there's always an initialization
+ // sentence_id must be non-decreasing. Iterators are over words in the phrase.
+ template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) {
+ // Iterate over all substrings.
+ for (Iterator start = begin; start != end; ++start) {
+ Hash hash = 0;
+ SentenceRelation *relation;
+ for (Iterator finish = start; finish != end; ++finish) {
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish);
+ // Now hash is of [start, finish].
+ relation = &table_[hash];
+ AppendSentence(relation->substring, sentence_id);
+ if (start == begin) AppendSentence(relation->left, sentence_id);
+ }
+ AppendSentence(relation->right, sentence_id);
+ if (start == begin) AppendSentence(relation->phrase, sentence_id);
+ }
+ }
+
+ private:
+ void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) {
+ if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id);
+ }
+
+ Table table_;
+};
+
+// Read a file with one sentence per line containing tab-delimited phrases of
+// space-separated words.
+unsigned int ReadMultiple(std::istream &in, Substrings &out);
+
+namespace detail {
+extern const StringPiece kEndSentence;
+
+template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) {
+ hashes.clear();
+ if (i == end) return;
+ // TODO: check strict phrase boundaries after <s> and before </s>. For now, just skip tags.
+ if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) {
+ ++i;
+ }
+ for (; i != end && (*i != kEndSentence); ++i) {
+ hashes.push_back(util::MurmurHashNative(i->data(), i->size()));
+ }
+}
+
+} // namespace detail
+
+class Union {
+ public:
+ explicit Union(const Substrings &substrings) : substrings_(substrings) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ detail::MakeHashes(begin, end, hashes_);
+ return hashes_.empty() || Evaluate();
+ }
+
+ private:
+ bool Evaluate();
+
+ std::vector<Hash> hashes_;
+
+ const Substrings &substrings_;
+};
+
+class Multiple {
+ public:
+ explicit Multiple(const Substrings &substrings) : substrings_(substrings) {}
+
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ detail::MakeHashes(begin, end, hashes_);
+ if (hashes_.empty()) {
+ output.AddNGram(line);
+ return;
+ }
+ Evaluate(line, output);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ template <class Output> void Evaluate(const StringPiece &line, Output &output);
+
+ std::vector<Hash> hashes_;
+
+ const Substrings &substrings_;
+};
+
+} // namespace phrase
+} // namespace lm
+#endif // LM_FILTER_PHRASE_H__
diff --git a/klm/lm/filter/thread.hh b/klm/lm/filter/thread.hh
new file mode 100644
index 00000000..e785b263
--- /dev/null
+++ b/klm/lm/filter/thread.hh
@@ -0,0 +1,167 @@
+#ifndef LM_FILTER_THREAD_H__
+#define LM_FILTER_THREAD_H__
+
+#include "util/thread_pool.hh"
+
+#include <boost/utility/in_place_factory.hpp>
+
+#include <deque>
+#include <stack>
+
+namespace lm {
+
+template <class OutputBuffer> class ThreadBatch {
+ public:
+ ThreadBatch() {}
+
+ void Reserve(size_t size) {
+ input_.Reserve(size);
+ output_.Reserve(size);
+ }
+
+ // File reading thread.
+ InputBuffer &Fill(uint64_t sequence) {
+ sequence_ = sequence;
+ // Why wait until now to clear instead of after output? free in the same
+ // thread as allocated.
+ input_.Clear();
+ return input_;
+ }
+
+ // Filter worker thread.
+ template <class Filter> void CallFilter(Filter &filter) {
+ input_.CallFilter(filter, output_);
+ }
+
+ uint64_t Sequence() const { return sequence_; }
+
+ // File writing thread.
+ template <class RealOutput> void Flush(RealOutput &output) {
+ output_.Flush(output);
+ }
+
+ private:
+ InputBuffer input_;
+ OutputBuffer output_;
+
+ uint64_t sequence_;
+};
+
+template <class Batch, class Filter> class FilterWorker {
+ public:
+ typedef Batch *Request;
+
+ FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {}
+
+ void operator()(Request request) {
+ request->CallFilter(filter_);
+ done_.Produce(request);
+ }
+
+ private:
+ Filter filter_;
+
+ util::PCQueue<Request> &done_;
+};
+
+// There should only be one OutputWorker.
+template <class Batch, class Output> class OutputWorker {
+ public:
+ typedef Batch *Request;
+
+ OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {}
+
+ void operator()(Request request) {
+ assert(request->Sequence() >= base_sequence_);
+ // Assemble the output in order.
+ uint64_t pos = request->Sequence() - base_sequence_;
+ if (pos >= ordering_.size()) {
+ ordering_.resize(pos + 1, NULL);
+ }
+ ordering_[pos] = request;
+ while (!ordering_.empty() && ordering_.front()) {
+ ordering_.front()->Flush(output_);
+ done_.Produce(ordering_.front());
+ ordering_.pop_front();
+ ++base_sequence_;
+ }
+ }
+
+ private:
+ Output &output_;
+
+ util::PCQueue<Request> &done_;
+
+ std::deque<Request> ordering_;
+
+ uint64_t base_sequence_;
+};
+
+template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
+ private:
+ typedef ThreadBatch<OutputBuffer> Batch;
+
+ public:
+ Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
+ : batch_size_(batch_size), queue_size_(queue),
+ batches_(queue),
+ to_read_(queue),
+ output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
+ filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
+ sequence_(0) {
+ for (size_t i = 0; i < queue; ++i) {
+ batches_[i].Reserve(batch_size);
+ local_read_.push(&batches_[i]);
+ }
+ NewInput();
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
+ input_->AddNGram(ngram, line, output);
+ if (input_->Size() == batch_size_) {
+ FlushInput();
+ NewInput();
+ }
+ }
+
+ void Flush() {
+ FlushInput();
+ while (local_read_.size() < queue_size_) {
+ MoveRead();
+ }
+ NewInput();
+ }
+
+ private:
+ void FlushInput() {
+ if (input_->Empty()) return;
+ filter_.Produce(local_read_.top());
+ local_read_.pop();
+ if (local_read_.empty()) MoveRead();
+ }
+
+ void NewInput() {
+ input_ = &local_read_.top()->Fill(sequence_++);
+ }
+
+ void MoveRead() {
+ local_read_.push(to_read_.Consume());
+ }
+
+ const size_t batch_size_;
+ const size_t queue_size_;
+
+ std::vector<Batch> batches_;
+
+ util::PCQueue<Batch*> to_read_;
+ std::stack<Batch*> local_read_;
+ util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
+ util::ThreadPool<FilterWorker<Batch, Filter> > filter_;
+
+ uint64_t sequence_;
+ InputBuffer *input_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_THREAD_H__
diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc
new file mode 100644
index 00000000..7ee4e84b
--- /dev/null
+++ b/klm/lm/filter/vocab.cc
@@ -0,0 +1,54 @@
+#include "lm/filter/vocab.hh"
+
+#include <istream>
+#include <iostream>
+
+#include <ctype.h>
+#include <err.h>
+
+namespace lm {
+namespace vocab {
+
+void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out) {
+ in.exceptions(std::istream::badbit);
+ std::string word;
+ while (in >> word) {
+ out.insert(word);
+ }
+}
+
+namespace {
+bool IsLineEnd(std::istream &in) {
+ int got;
+ do {
+ got = in.get();
+ if (!in) return true;
+ if (got == '\n') return true;
+ } while (isspace(got));
+ in.unget();
+ return false;
+}
+}// namespace
+
+// Read space separated words in enter separated lines. These lines can be
+// very long, so don't read an entire line at a time.
+unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out) {
+ in.exceptions(std::istream::badbit);
+ unsigned int sentence = 0;
+ bool used_id = false;
+ std::string word;
+ while (in >> word) {
+ used_id = true;
+ std::vector<unsigned int> &posting = out[word];
+ if (posting.empty() || (posting.back() != sentence))
+ posting.push_back(sentence);
+ if (IsLineEnd(in)) {
+ ++sentence;
+ used_id = false;
+ }
+ }
+ return sentence + used_id;
+}
+
+} // namespace vocab
+} // namespace lm
diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh
new file mode 100644
index 00000000..7f0fadaa
--- /dev/null
+++ b/klm/lm/filter/vocab.hh
@@ -0,0 +1,133 @@
+#ifndef LM_FILTER_VOCAB_H__
+#define LM_FILTER_VOCAB_H__
+
+// Vocabulary-based filters for language models.
+
+#include "util/multi_intersection.hh"
+#include "util/string_piece.hh"
+#include "util/string_piece_hash.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/range/iterator_range.hpp>
+#include <boost/unordered/unordered_map.hpp>
+#include <boost/unordered/unordered_set.hpp>
+
+#include <string>
+#include <vector>
+
+namespace lm {
+namespace vocab {
+
+void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out);
+
+// Read one sentence vocabulary per line. Return the number of sentences.
+unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out);
+
+/* Is this a special tag like <s> or <UNK>? This actually includes anything
+ * surrounded with < and >, which most tokenizers separate for real words, so
+ * this should not catch real words as it looks at a single token.
+ */
+inline bool IsTag(const StringPiece &value) {
+ // The parser should never give an empty string.
+ assert(!value.empty());
+ return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>');
+}
+
+class Single {
+ public:
+ typedef boost::unordered_set<std::string> Words;
+
+ explicit Single(const Words &vocab) : vocab_(vocab) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ for (Iterator i = begin; i != end; ++i) {
+ if (IsTag(*i)) continue;
+ if (FindStringPiece(vocab_, *i) == vocab_.end()) return false;
+ }
+ return true;
+ }
+
+ private:
+ const Words &vocab_;
+};
+
+class Union {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ explicit Union(const Words &vocabs) : vocabs_(vocabs) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ sets_.clear();
+
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return false;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ return (sets_.empty() || util::FirstIntersection(sets_));
+ }
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+class Multiple {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ Multiple(const Words &vocabs) : vocabs_(vocabs) {}
+
+ private:
+ // Callback from AllIntersection that does AddNGram.
+ template <class Output> class Callback {
+ public:
+ Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {}
+
+ void operator()(unsigned int index) {
+ out_.SingleAddNGram(index, line_);
+ }
+
+ private:
+ Output &out_;
+ const StringPiece &line_;
+ };
+
+ public:
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ sets_.clear();
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ if (sets_.empty()) {
+ output.AddNGram(line);
+ return;
+ }
+
+ Callback<Output> cb(output, line);
+ util::AllIntersection(sets_, cb);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+} // namespace vocab
+} // namespace lm
+
+#endif // LM_FILTER_VOCAB_H__
diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh
new file mode 100644
index 00000000..90b07a08
--- /dev/null
+++ b/klm/lm/filter/wrapper.hh
@@ -0,0 +1,58 @@
+#ifndef LM_FILTER_WRAPPER_H__
+#define LM_FILTER_WRAPPER_H__
+
+#include "util/string_piece.hh"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+namespace lm {
+
+// Provide a single-output filter with the same interface as a
+// multiple-output filter so clients code against one interface.
+template <class Binary> class BinaryFilter {
+ public:
+ // Binary modes are just references (and a set) and it makes the API cleaner to copy them.
+ explicit BinaryFilter(Binary binary) : binary_(binary) {}
+
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ if (binary_.PassNGram(begin, end))
+ output.AddNGram(line);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ Binary binary_;
+};
+
+// Wrap another filter to pay attention only to context words
+template <class FilterT> class ContextFilter {
+ public:
+ typedef FilterT Filter;
+
+ explicit ContextFilter(Filter &backend) : backend_(backend) {}
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ pieces_.clear();
+ // TODO: this copy could be avoided by a lookahead iterator.
+ std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_));
+ backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ std::vector<StringPiece> pieces_;
+
+ Filter backend_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_WRAPPER_H__
diff --git a/klm/lm/fragment_main.cc b/klm/lm/fragment_main.cc
new file mode 100644
index 00000000..0267cd4e
--- /dev/null
+++ b/klm/lm/fragment_main.cc
@@ -0,0 +1,37 @@
+#include "lm/binary_format.hh"
+#include "lm/model.hh"
+#include "lm/left.hh"
+#include "util/tokenize_piece.hh"
+
+template <class Model> void Query(const char *name) {
+ Model model(name);
+ std::string line;
+ lm::ngram::ChartState ignored;
+ while (getline(std::cin, line)) {
+ lm::ngram::RuleScore<Model> scorer(model, ignored);
+ for (util::TokenIter<util::SingleCharacter, true> i(line, ' '); i; ++i) {
+ scorer.Terminal(model.GetVocabulary().Index(*i));
+ }
+ std::cout << scorer.Finish() << '\n';
+ }
+}
+
+int main(int argc, char *argv[]) {
+ if (argc != 2) {
+ std::cerr << "Expected model file name." << std::endl;
+ return 1;
+ }
+ const char *name = argv[1];
+ lm::ngram::ModelType model_type = lm::ngram::PROBING;
+ lm::ngram::RecognizeBinary(name, model_type);
+ switch (model_type) {
+ case lm::ngram::PROBING:
+ Query<lm::ngram::ProbingModel>(name);
+ break;
+ case lm::ngram::REST_PROBING:
+ Query<lm::ngram::RestProbingModel>(name);
+ break;
+ default:
+ std::cerr << "Model type not supported yet." << std::endl;
+ }
+}
diff --git a/klm/lm/max_order.cc b/klm/lm/kenlm_max_order_main.cc
index 94221201..94221201 100644
--- a/klm/lm/max_order.cc
+++ b/klm/lm/kenlm_max_order_main.cc
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index 8c27232e..85c1ea37 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -51,36 +51,36 @@ namespace ngram {
template <class M> class RuleScore {
public:
- explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) {
+ explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {
out.left.length = 0;
out.right.length = 0;
}
void BeginSentence() {
- out_.right = model_.BeginSentenceState();
- // out_.left is empty.
+ out_->right = model_.BeginSentenceState();
+ // out_->left is empty.
left_done_ = true;
}
void Terminal(WordIndex word) {
- State copy(out_.right);
- FullScoreReturn ret(model_.FullScore(copy, word, out_.right));
+ State copy(out_->right);
+ FullScoreReturn ret(model_.FullScore(copy, word, out_->right));
if (left_done_) { prob_ += ret.prob; return; }
if (ret.independent_left) {
prob_ += ret.prob;
left_done_ = true;
return;
}
- out_.left.pointers[out_.left.length++] = ret.extend_left;
+ out_->left.pointers[out_->left.length++] = ret.extend_left;
prob_ += ret.rest;
- if (out_.right.length != copy.length + 1)
+ if (out_->right.length != copy.length + 1)
left_done_ = true;
}
// Faster version of NonTerminal for the case where the rule begins with a non-terminal.
void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
prob_ = prob;
- out_ = in;
+ *out_ = in;
left_done_ = in.left.full;
}
@@ -89,23 +89,23 @@ template <class M> class RuleScore {
if (!in.left.length) {
if (in.left.full) {
- for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i;
+ for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;
left_done_ = true;
- out_.right = in.right;
+ out_->right = in.right;
}
return;
}
- if (!out_.right.length) {
- out_.right = in.right;
+ if (!out_->right.length) {
+ out_->right = in.right;
if (left_done_) {
prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
return;
}
- if (out_.left.length) {
+ if (out_->left.length) {
left_done_ = true;
} else {
- out_.left = in.left;
+ out_->left = in.left;
left_done_ = in.left.full;
}
return;
@@ -113,10 +113,10 @@ template <class M> class RuleScore {
float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
float *back = backoffs, *back2 = backoffs2;
- unsigned char next_use = out_.right.length;
+ unsigned char next_use = out_->right.length;
// First word
- if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return;
+ if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;
// Words after the first, so extending a bigram to begin with
for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
@@ -127,54 +127,58 @@ template <class M> class RuleScore {
if (in.left.full) {
for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
left_done_ = true;
- out_.right = in.right;
+ out_->right = in.right;
return;
}
// Right state was minimized, so it's already independent of the new words to the left.
if (in.right.length < in.left.length) {
- out_.right = in.right;
+ out_->right = in.right;
return;
}
// Shift exisiting words down.
- for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) {
+ for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {
*(i + in.right.length) = *i;
}
// Add words from in.right.
- std::copy(in.right.words, in.right.words + in.right.length, out_.right.words);
+ std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);
// Assemble backoff composed on the existing state's backoff followed by the new state's backoff.
- std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff);
- std::copy(back, back + next_use, out_.right.backoff + in.right.length);
- out_.right.length = in.right.length + next_use;
+ std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff);
+ std::copy(back, back + next_use, out_->right.backoff + in.right.length);
+ out_->right.length = in.right.length + next_use;
}
float Finish() {
// A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
- out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1);
+ out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);
return prob_;
}
void Reset() {
prob_ = 0.0;
left_done_ = false;
- out_.left.length = 0;
- out_.right.length = 0;
+ out_->left.length = 0;
+ out_->right.length = 0;
+ }
+ void Reset(ChartState &replacement) {
+ out_ = &replacement;
+ Reset();
}
private:
bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
ProcessRet(model_.ExtendLeft(
- out_.right.words, out_.right.words + next_use, // Words to extend into
+ out_->right.words, out_->right.words + next_use, // Words to extend into
back_in, // Backoffs to use
in.left.pointers[extend_length - 1], extend_length, // Words to be extended
back_out, // Backoffs for the next score
next_use)); // Length of n-gram to use in next scoring.
- if (next_use != out_.right.length) {
+ if (next_use != out_->right.length) {
left_done_ = true;
if (!next_use) {
// Early exit.
- out_.right = in.right;
+ out_->right = in.right;
prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
return true;
}
@@ -193,13 +197,13 @@ template <class M> class RuleScore {
left_done_ = true;
return;
}
- out_.left.pointers[out_.left.length++] = ret.extend_left;
+ out_->left.pointers[out_->left.length++] = ret.extend_left;
prob_ += ret.rest;
}
const M &model_;
- ChartState &out_;
+ ChartState *out_;
bool left_done_;
diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh
index bc8687cd..3eb97ccd 100644
--- a/klm/lm/max_order.hh
+++ b/klm/lm/max_order.hh
@@ -4,9 +4,6 @@
* (kMaxOrder - 1) * sizeof(float) bytes instead of
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
*/
-#ifndef KENLM_MAX_ORDER
-#define KENLM_MAX_ORDER 6
-#endif
#ifndef KENLM_ORDER_MESSAGE
-#define KENLM_ORDER_MESSAGE "Edit klm/lm/max_order.hh."
+#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
#endif
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index b46333a4..a40fd2fb 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -12,6 +12,7 @@
#include <functional>
#include <numeric>
#include <cmath>
+#include <limits>
namespace lm {
namespace ngram {
@@ -19,23 +20,24 @@ namespace detail {
template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
-template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
+template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
+ size_t goal_size = util::CheckOverflow(Size(counts, config));
uint8_t *start = static_cast<uint8_t*>(base);
size_t allocated = VocabularyT::Size(counts[0], config);
vocab_.SetupMemory(start, allocated, counts[0], config);
start += allocated;
start = search_.SetupMemory(start, counts, config);
- if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config));
+ if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
LoadLM(file, config, *this);
- // g++ prints warnings unless these are fully initialized.
+ // g++ prints warnings unless these are fully initialized.
State begin_sentence = State();
begin_sentence.length = 1;
begin_sentence.words[0] = vocab_.BeginSentence();
@@ -49,38 +51,43 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
}
namespace {
-void CheckMaxOrder(size_t order) {
- UTIL_THROW_IF(order > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << order << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
+void CheckCounts(const std::vector<uint64_t> &counts) {
+ UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
+ if (sizeof(uint64_t) > sizeof(std::size_t)) {
+ for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) {
+ UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines.");
+ }
+ }
}
} // namespace
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
- CheckMaxOrder(params.counts.size());
+ CheckCounts(params.counts);
SetupMemory(start, params.counts, config);
vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
search_.LoadedBinary();
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
- // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
- util::FilePiece f(backing_.file.release(), file, config.messages);
+ // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
+ util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
try {
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
ReadARPACounts(f, counts);
- CheckMaxOrder(counts.size());
+ CheckCounts(counts);
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
- std::size_t vocab_size = VocabularyT::Size(counts[0], config);
- // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
+ std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
+ // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
if (config.write_mmap) {
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get());
+ wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
@@ -88,7 +95,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (!vocab_.SawUnk()) {
assert(config.unknown_missing != THROW_UP);
- // Default probabilities for unknown.
+ // Default probabilities for unknown.
search_.UnknownUnigram().backoff = 0.0;
search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
@@ -140,7 +147,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
- // Generate a state from context.
+ // Generate a state from context.
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
if (context_rend == context_rbegin) {
out_state.length = 0;
@@ -184,7 +191,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
ret.rest = ptr.Rest();
ret.prob = ptr.Prob();
ret.extend_left = extend_pointer;
- // If this function is called, then it does depend on left words.
+ // If this function is called, then it does depend on left words.
ret.independent_left = false;
}
float subtract_me = ret.rest;
@@ -192,7 +199,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
next_use = extend_length;
ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);
next_use -= extend_length;
- // Charge backoffs.
+ // Charge backoffs.
for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
ret.prob -= subtract_me;
ret.rest -= subtract_me;
@@ -202,7 +209,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied
// (hence the -1). out_state.length could be zero so I avoided using
-// std::copy.
+// std::copy.
void CopyRemainingHistory(const WordIndex *from, State &out_state) {
WordIndex *out = out_state.words + 1;
const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1;
@@ -210,18 +217,19 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) {
}
} // namespace
-/* Ugly optimized function. Produce a score excluding backoff.
- * The search goes in increasing order of ngram length.
+/* Ugly optimized function. Produce a score excluding backoff.
+ * The search goes in increasing order of ngram length.
* Context goes backward, so context_begin is the word immediately preceeding
- * new_word.
+ * new_word.
*/
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
const WordIndex *const context_rbegin,
const WordIndex *const context_rend,
const WordIndex new_word,
State &out_state) const {
+ assert(new_word < vocab_.Bound());
FullScoreReturn ret;
- // ret.ngram_length contains the last known non-blank ngram length.
+ // ret.ngram_length contains the last known non-blank ngram length.
ret.ngram_length = 1;
typename Search::Node node;
@@ -230,9 +238,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
ret.prob = uni.Prob();
ret.rest = uni.Rest();
- // This is the length of the context that should be used for continuation to the right.
+ // This is the length of the context that should be used for continuation to the right.
out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
- // We'll write the word anyway since it will probably be used and does no harm being there.
+ // We'll write the word anyway since it will probably be used and does no harm being there.
out_state.words[0] = new_word;
if (context_rbegin == context_rend) return ret;
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 6dee9419..13ff864e 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -41,7 +41,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
* does not include small non-mapped control structures, such as this class
* itself.
*/
- static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
+ static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
/* Load the model from a file. It may be an ARPA or binary file. Binary
* files must have the format expected by this class or you'll get an
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 32084b5b..eb159094 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -1,6 +1,7 @@
#include "lm/model.hh"
#include <stdlib.h>
+#include <string.h>
#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>
@@ -22,17 +23,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) {
namespace {
+// Stupid bjam reverses the command line arguments randomly.
const char *TestLocation() {
- if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ if (boost::unit_test::framework::master_test_suite().argc < 3) {
return "test.arpa";
}
- return boost::unit_test::framework::master_test_suite().argv[1];
+ char **argv = boost::unit_test::framework::master_test_suite().argv;
+ return argv[strstr(argv[1], "nounk") ? 2 : 1];
}
const char *TestNoUnkLocation() {
if (boost::unit_test::framework::master_test_suite().argc < 3) {
return "test_nounk.arpa";
}
- return boost::unit_test::framework::master_test_suite().argv[2];
+ char **argv = boost::unit_test::framework::master_test_suite().argv;
+ return argv[strstr(argv[1], "nounk") ? 1 : 2];
}
template <class Model> State GetState(const Model &model, const char *word, const State &in) {
diff --git a/klm/lm/partial.hh b/klm/lm/partial.hh
new file mode 100644
index 00000000..1dede359
--- /dev/null
+++ b/klm/lm/partial.hh
@@ -0,0 +1,167 @@
+#ifndef LM_PARTIAL__
+#define LM_PARTIAL__
+
+#include "lm/return.hh"
+#include "lm/state.hh"
+
+#include <algorithm>
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+
+struct ExtendReturn {
+ float adjust;
+ bool make_full;
+ unsigned char next_use;
+};
+
+template <class Model> ExtendReturn ExtendLoop(
+ const Model &model,
+ unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start,
+ const uint64_t *pointers, const uint64_t *pointers_end,
+ uint64_t *&pointers_write,
+ float *backoff_write) {
+ unsigned char add_length = add_rend - add_rbegin;
+
+ float backoff_buf[2][KENLM_MAX_ORDER - 1];
+ float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1];
+ std::copy(backoff_start, backoff_start + add_length, backoff_in);
+
+ ExtendReturn value;
+ value.make_full = false;
+ value.adjust = 0.0;
+ value.next_use = add_length;
+
+ unsigned char i = 0;
+ unsigned char length = pointers_end - pointers;
+ // pointers_write is NULL means that the existing left state is full, so we should use completed probabilities.
+ if (pointers_write) {
+ // Using full context, writing to new left state.
+ for (; i < length; ++i) {
+ FullScoreReturn ret(model.ExtendLeft(
+ add_rbegin, add_rbegin + value.next_use,
+ backoff_in,
+ pointers[i], i + seen + 1,
+ backoff_out,
+ value.next_use));
+ std::swap(backoff_in, backoff_out);
+ if (ret.independent_left) {
+ value.adjust += ret.prob;
+ value.make_full = true;
+ ++i;
+ break;
+ }
+ value.adjust += ret.rest;
+ *pointers_write++ = ret.extend_left;
+ if (value.next_use != add_length) {
+ value.make_full = true;
+ ++i;
+ break;
+ }
+ }
+ }
+ // Using some of the new context.
+ for (; i < length && value.next_use; ++i) {
+ FullScoreReturn ret(model.ExtendLeft(
+ add_rbegin, add_rbegin + value.next_use,
+ backoff_in,
+ pointers[i], i + seen + 1,
+ backoff_out,
+ value.next_use));
+ std::swap(backoff_in, backoff_out);
+ value.adjust += ret.prob;
+ }
+ float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1);
+ // Using none of the new context.
+ value.adjust += unrest;
+
+ std::copy(backoff_in, backoff_in + value.next_use, backoff_write);
+ return value;
+}
+
+template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) {
+ assert(seen < reveal.length || reveal_full);
+ uint64_t *pointers_write = reveal_full ? NULL : left.pointers;
+ float backoff_buffer[KENLM_MAX_ORDER - 1];
+ ExtendReturn value(ExtendLoop(
+ model,
+ seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen,
+ left.pointers, left.pointers + left.length,
+ pointers_write,
+ left.full ? backoff_buffer : (right.backoff + right.length)));
+ if (reveal_full) {
+ left.length = 0;
+ value.make_full = true;
+ } else {
+ left.length = pointers_write - left.pointers;
+ value.make_full |= (left.length == model.Order() - 1);
+ }
+ if (left.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
+ } else {
+ // If left wasn't full when it came in, put words into right state.
+ std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length);
+ right.length += value.next_use;
+ left.full = value.make_full || (right.length == model.Order() - 1);
+ }
+ return value.adjust;
+}
+
+template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) {
+ assert(seen < reveal.length || reveal.full);
+ uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length);
+ ExtendReturn value(ExtendLoop(
+ model,
+ seen, right.words, right.words + right.length, right.backoff,
+ reveal.pointers + seen, reveal.pointers + reveal.length,
+ pointers_write,
+ right.backoff));
+ if (reveal.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i];
+ right.length = 0;
+ value.make_full = true;
+ } else {
+ right.length = value.next_use;
+ value.make_full |= (right.length == model.Order() - 1);
+ }
+ if (!left.full) {
+ left.length = pointers_write - left.pointers;
+ left.full = value.make_full || (left.length == model.Order() - 1);
+ }
+ return value.adjust;
+}
+
+template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) {
+ assert(first_right.length < KENLM_MAX_ORDER);
+ assert(second_left.length < KENLM_MAX_ORDER);
+ assert(between_length < KENLM_MAX_ORDER - 1);
+ uint64_t *pointers_write = first_left.full ? NULL : (first_left.pointers + first_left.length);
+ float backoff_buffer[KENLM_MAX_ORDER - 1];
+ ExtendReturn value(ExtendLoop(
+ model,
+ between_length, first_right.words, first_right.words + first_right.length, first_right.backoff,
+ second_left.pointers, second_left.pointers + second_left.length,
+ pointers_write,
+ second_left.full ? backoff_buffer : (second_right.backoff + second_right.length)));
+ if (second_left.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
+ } else {
+ std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length);
+ second_right.length += value.next_use;
+ value.make_full |= (second_right.length == model.Order() - 1);
+ }
+ if (!first_left.full) {
+ first_left.length = pointers_write - first_left.pointers;
+ first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1);
+ }
+ assert(first_left.length < KENLM_MAX_ORDER);
+ assert(second_right.length < KENLM_MAX_ORDER);
+ return value.adjust;
+}
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_PARTIAL__
diff --git a/klm/lm/partial_test.cc b/klm/lm/partial_test.cc
new file mode 100644
index 00000000..8d309c85
--- /dev/null
+++ b/klm/lm/partial_test.cc
@@ -0,0 +1,199 @@
+#include "lm/partial.hh"
+
+#include "lm/left.hh"
+#include "lm/model.hh"
+#include "util/tokenize_piece.hh"
+
+#define BOOST_TEST_MODULE PartialTest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+namespace lm {
+namespace ngram {
+namespace {
+
+const char *TestLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ return "test.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[1];
+}
+
+Config SilentConfig() {
+ Config config;
+ config.arpa_complain = Config::NONE;
+ config.messages = NULL;
+ return config;
+}
+
+struct ModelFixture {
+ ModelFixture() : m(TestLocation(), SilentConfig()) {}
+
+ RestProbingModel m;
+};
+
+BOOST_FIXTURE_TEST_SUITE(suite, ModelFixture)
+
+BOOST_AUTO_TEST_CASE(SimpleBefore) {
+ Left left;
+ left.full = false;
+ left.length = 0;
+ Right right;
+ right.length = 0;
+
+ Right reveal;
+ reveal.length = 1;
+ WordIndex period = m.GetVocabulary().Index(".");
+ reveal.words[0] = period;
+ reveal.backoff[0] = -0.845098;
+
+ BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 0, false, left, right), 0.001);
+ BOOST_CHECK_EQUAL(0, left.length);
+ BOOST_CHECK(!left.full);
+ BOOST_CHECK_EQUAL(1, right.length);
+ BOOST_CHECK_EQUAL(period, right.words[0]);
+ BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
+
+ WordIndex more = m.GetVocabulary().Index("more");
+ reveal.words[1] = more;
+ reveal.backoff[1] = -0.4771212;
+ reveal.length = 2;
+ BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 1, false, left, right), 0.001);
+ BOOST_CHECK_EQUAL(0, left.length);
+ BOOST_CHECK(!left.full);
+ BOOST_CHECK_EQUAL(2, right.length);
+ BOOST_CHECK_EQUAL(period, right.words[0]);
+ BOOST_CHECK_EQUAL(more, right.words[1]);
+ BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
+ BOOST_CHECK_CLOSE(-0.4771212, right.backoff[1], 0.001);
+}
+
+BOOST_AUTO_TEST_CASE(AlsoWouldConsider) {
+ WordIndex would = m.GetVocabulary().Index("would");
+ WordIndex consider = m.GetVocabulary().Index("consider");
+
+ ChartState current;
+ current.left.length = 1;
+ current.left.pointers[0] = would;
+ current.left.full = false;
+ current.right.length = 1;
+ current.right.words[0] = would;
+ current.right.backoff[0] = -0.30103;
+
+ Left after;
+ after.full = false;
+ after.length = 1;
+ after.pointers[0] = consider;
+
+ // adjustment for would consider
+ BOOST_CHECK_CLOSE(-1.687872 - -0.2922095 - 0.30103, RevealAfter(m, current.left, current.right, after, 0), 0.001);
+
+ BOOST_CHECK_EQUAL(2, current.left.length);
+ BOOST_CHECK_EQUAL(would, current.left.pointers[0]);
+ BOOST_CHECK_EQUAL(false, current.left.full);
+
+ WordIndex also = m.GetVocabulary().Index("also");
+ Right before;
+ before.length = 1;
+ before.words[0] = also;
+ before.backoff[0] = -0.30103;
+ // r(would) = -0.2922095 [i would], r(would -> consider) = -1.988902 [b(would) + p(consider)]
+ // p(also -> would) = -2, p(also would -> consider) = -3
+ BOOST_CHECK_CLOSE(-2 + 0.2922095 -3 + 1.988902, RevealBefore(m, before, 0, false, current.left, current.right), 0.001);
+ BOOST_CHECK_EQUAL(0, current.left.length);
+ BOOST_CHECK(current.left.full);
+ BOOST_CHECK_EQUAL(2, current.right.length);
+ BOOST_CHECK_EQUAL(would, current.right.words[0]);
+ BOOST_CHECK_EQUAL(also, current.right.words[1]);
+}
+
+BOOST_AUTO_TEST_CASE(EndSentence) {
+ WordIndex loin = m.GetVocabulary().Index("loin");
+ WordIndex period = m.GetVocabulary().Index(".");
+ WordIndex eos = m.GetVocabulary().EndSentence();
+
+ ChartState between;
+ between.left.length = 1;
+ between.left.pointers[0] = eos;
+ between.left.full = true;
+ between.right.length = 0;
+
+ Right before;
+ before.words[0] = period;
+ before.words[1] = loin;
+ before.backoff[0] = -0.845098;
+ before.backoff[1] = 0.0;
+
+ before.length = 1;
+ BOOST_CHECK_CLOSE(-0.0410707, RevealBefore(m, before, 0, true, between.left, between.right), 0.001);
+ BOOST_CHECK_EQUAL(0, between.left.length);
+}
+
+float ScoreFragment(const RestProbingModel &model, unsigned int *begin, unsigned int *end, ChartState &out) {
+ RuleScore<RestProbingModel> scorer(model, out);
+ for (unsigned int *i = begin; i < end; ++i) {
+ scorer.Terminal(*i);
+ }
+ return scorer.Finish();
+}
+
+void CheckAdjustment(const RestProbingModel &model, float expect, const Right &before_in, bool before_full, ChartState between, const Left &after_in) {
+ Right before(before_in);
+ Left after(after_in);
+ after.full = false;
+ float got = 0.0;
+ for (unsigned int i = 1; i < 5; ++i) {
+ if (before_in.length >= i) {
+ before.length = i;
+ got += RevealBefore(model, before, i - 1, false, between.left, between.right);
+ }
+ if (after_in.length >= i) {
+ after.length = i;
+ got += RevealAfter(model, between.left, between.right, after, i - 1);
+ }
+ }
+ if (after_in.full) {
+ after.full = true;
+ got += RevealAfter(model, between.left, between.right, after, after.length);
+ }
+ if (before_full) {
+ got += RevealBefore(model, before, before.length, true, between.left, between.right);
+ }
+ // Sometimes they're zero and BOOST_CHECK_CLOSE fails for this.
+ BOOST_CHECK(fabs(expect - got) < 0.001);
+}
+
+void FullDivide(const RestProbingModel &model, StringPiece str) {
+ std::vector<WordIndex> indices;
+ for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) {
+ indices.push_back(model.GetVocabulary().Index(*i));
+ }
+ ChartState full_state;
+ float full = ScoreFragment(model, &indices.front(), &indices.back() + 1, full_state);
+
+ ChartState before_state;
+ before_state.left.full = false;
+ RuleScore<RestProbingModel> before_scorer(model, before_state);
+ float before_score = 0.0;
+ for (unsigned int before = 0; before < indices.size(); ++before) {
+ for (unsigned int after = before; after <= indices.size(); ++after) {
+ ChartState after_state, between_state;
+ float after_score = ScoreFragment(model, &indices.front() + after, &indices.front() + indices.size(), after_state);
+ float between_score = ScoreFragment(model, &indices.front() + before, &indices.front() + after, between_state);
+ CheckAdjustment(model, full - before_score - after_score - between_score, before_state.right, before_state.left.full, between_state, after_state.left);
+ }
+ before_scorer.Terminal(indices[before]);
+ before_score = before_scorer.Finish();
+ }
+}
+
+BOOST_AUTO_TEST_CASE(Strings) {
+ FullDivide(m, "also would consider");
+ FullDivide(m, "looking on a little more loin . </s>");
+ FullDivide(m, "in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+} // namespace
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh
index abed0112..8ce2378a 100644
--- a/klm/lm/quantize.hh
+++ b/klm/lm/quantize.hh
@@ -24,7 +24,7 @@ class DontQuantize {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
- static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
+ static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
@@ -138,9 +138,9 @@ class SeparatelyQuantize {
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
- static std::size_t Size(uint8_t order, const Config &config) {
- size_t longest_table = (static_cast<size_t>(1) << static_cast<size_t>(config.prob_bits)) * sizeof(float);
- size_t middle_table = (static_cast<size_t>(1) << static_cast<size_t>(config.backoff_bits)) * sizeof(float) + longest_table;
+ static uint64_t Size(uint8_t order, const Config &config) {
+ uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
+ uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table;
// unigrams are currently not quantized so no need for a table.
return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8;
}
diff --git a/klm/lm/ngram_query.cc b/klm/lm/query_main.cc
index 49757d9a..49757d9a 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/query_main.cc
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index 70727e4c..9ea08798 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -1,13 +1,15 @@
#include "lm/read_arpa.hh"
#include "lm/blank.hh"
+#include "util/file.hh"
+#include <cmath>
#include <cstdlib>
#include <iostream>
+#include <sstream>
#include <vector>
#include <ctype.h>
-#include <math.h>
#include <string.h>
#include <stdint.h>
@@ -31,12 +33,27 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) {
const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";
+// strtoull isn't portable enough :-(
+uint64_t ReadCount(const std::string &from) {
+ std::stringstream stream(from);
+ uint64_t ret;
+ stream >> ret;
+ UTIL_THROW_IF(!stream, FormatLoadException, "Bad count " << from);
+ return ret;
+}
+
} // namespace
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
number.clear();
- StringPiece line;
- while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
+ StringPiece line = in.ReadLine();
+ // In general, ARPA files can have arbitrary text before "\data\"
+ // But in KenLM, we require such lines to start with "#", so that
+ // we can do stricter error checking
+ while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) {
+ line = in.ReadLine();
+ }
+
if (line != "\\data\\") {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
@@ -52,15 +69,11 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
// So strtol doesn't go off the end of line.
std::string remaining(line.data() + 6, line.size() - 6);
char *end_ptr;
- unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10);
+ unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10);
if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line);
if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line);
++end_ptr;
- const char *start = end_ptr;
- long int count = std::strtol(start, &end_ptr, 10);
- if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count);
- if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line);
- number.push_back(count);
+ number.push_back(ReadCount(end_ptr));
}
}
@@ -103,7 +116,7 @@ void ReadBackoff(util::FilePiece &in, float &backoff) {
int float_class = _fpclass(backoff);
UTIL_THROW_IF(float_class == _FPCLASS_SNAN || float_class == _FPCLASS_QNAN || float_class == _FPCLASS_NINF || float_class == _FPCLASS_PINF, FormatLoadException, "Bad backoff " << backoff);
#else
- int float_class = fpclassify(backoff);
+ int float_class = std::fpclassify(backoff);
UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff);
#endif
}
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 13942309..2d6f15b2 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -231,27 +231,27 @@ template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char *
template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
NoRestBuild build;
- ApplyBuild(f, counts, config, vocab, warn, build);
+ ApplyBuild(f, counts, vocab, warn, build);
}
-template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
+template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
switch (config.rest_function) {
case Config::REST_MAX:
{
MaxRestBuild build;
- ApplyBuild(f, counts, config, vocab, warn, build);
+ ApplyBuild(f, counts, vocab, warn, build);
}
break;
case Config::REST_LOWER:
{
LowerRestBuild<ProbingModel> build(config, counts.size(), vocab);
- ApplyBuild(f, counts, config, vocab, warn, build);
+ ApplyBuild(f, counts, vocab, warn, build);
}
break;
}
}
-template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {
+template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {
for (WordIndex i = 0; i < counts[0]; ++i) {
build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]);
}
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 7e8c1220..00595796 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -74,8 +74,8 @@ template <class Value> class HashedSearch {
// TODO: move probing_multiplier here with next binary file format update.
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
- static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t ret = Unigram::Size(counts[0]);
+ static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
+ uint64_t ret = Unigram::Size(counts[0]);
for (unsigned char n = 1; n < counts.size() - 1; ++n) {
ret += Middle::Size(counts[n], config.probing_multiplier);
}
@@ -147,7 +147,7 @@ template <class Value> class HashedSearch {
// Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
- template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
+ template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
class Unigram {
public:
@@ -160,8 +160,8 @@ template <class Value> class HashedSearch {
#endif
{}
- static std::size_t Size(uint64_t count) {
- return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
+ static uint64_t Size(uint64_t count) {
+ return (count + 1) * sizeof(typename Value::Weights); // +1 for hallucinate <unk>
}
const typename Value::Weights &Lookup(WordIndex index) const {
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 832cc9f7..1b0d9b26 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -55,7 +55,7 @@ struct ProbPointer {
uint64_t index;
};
-// Array of n-grams and float indices.
+// Array of n-grams and float indices.
class BackoffMessages {
public:
void Init(std::size_t entry_size) {
@@ -89,7 +89,7 @@ class BackoffMessages {
if (!HasExtension(weights.backoff)) {
weights.backoff = kExtensionBackoff;
UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
- WriteOrThrow(unigrams, &weights, sizeof(weights));
+ util::WriteOrThrow(unigrams, &weights, sizeof(weights));
}
const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
base[write_to.array][write_to.index] += weights.backoff;
@@ -100,7 +100,7 @@ class BackoffMessages {
void Apply(float *const *const base, RecordReader &reader) {
FinishedAdding();
if (current_ == allocated_) return;
- // We'll also use the same buffer to record messages to blanks that they extend.
+ // We'll also use the same buffer to record messages to blanks that they extend.
WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);
const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);
for (reader.Rewind(); reader && (current_ != allocated_); ) {
@@ -109,7 +109,7 @@ class BackoffMessages {
++reader;
break;
case 1:
- // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
+ // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;
current_ += entry_size_;
break;
@@ -126,7 +126,7 @@ class BackoffMessages {
break;
}
}
- // Now this is a list of blanks that extend right.
+ // Now this is a list of blanks that extend right.
entry_size_ = sizeof(WordIndex) * order;
Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));
current_ = (uint8_t*)backing_.get();
@@ -153,7 +153,7 @@ class BackoffMessages {
private:
void FinishedAdding() {
Resize(current_ - (uint8_t*)backing_.get());
- // Sort requests in same order as files.
+ // Sort requests in same order as files.
std::sort(
util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)),
util::SizedIterator(util::SizedProxy(current_, entry_size_)),
@@ -220,7 +220,7 @@ class SRISucks {
}
private:
- // This used to be one array. Then I needed to separate it by order for quantization to work.
+ // This used to be one array. Then I needed to separate it by order for quantization to work.
std::vector<float> values_[KENLM_MAX_ORDER - 1];
BackoffMessages messages_[KENLM_MAX_ORDER - 1];
@@ -253,7 +253,7 @@ class FindBlanks {
++counts_.back();
}
- // Unigrams wrote one past.
+ // Unigrams wrote one past.
void Cleanup() {
--counts_[0];
}
@@ -270,15 +270,15 @@ class FindBlanks {
SRISucks &sri_;
};
-// Phase to actually write n-grams to the trie.
+// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
+ WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
contexts_(contexts),
quant_(quant),
unigrams_(unigrams),
middle_(middle),
- longest_(longest),
+ longest_(longest),
bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),
order_(order),
sri_(sri) {}
@@ -328,7 +328,7 @@ struct Gram {
const WordIndex *begin, *end;
- // For queue, this is the direction we want.
+ // For queue, this is the direction we want.
bool operator<(const Gram &other) const {
return std::lexicographical_compare(other.begin, other.end, begin, end);
}
@@ -353,7 +353,7 @@ template <class Doing> class BlankManager {
been_length_ = length;
return;
}
- // There are blanks to insert starting with order blank.
+ // There are blanks to insert starting with order blank.
unsigned char blank = cur - to + 1;
UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");
const float *lower_basis;
@@ -363,7 +363,7 @@ template <class Doing> class BlankManager {
assert(*lower_basis != kBadProb);
doing_.MiddleBlank(blank, to, based_on, *lower_basis);
*pre = *cur;
- // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
+ // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
basis_[blank - 1] = kBadProb;
}
*pre = *cur;
@@ -377,7 +377,7 @@ template <class Doing> class BlankManager {
unsigned char been_length_;
float basis_[KENLM_MAX_ORDER];
-
+
Doing &doing_;
};
@@ -451,7 +451,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re
}
void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
- // Fill unigram probabilities.
+ // Fill unigram probabilities.
try {
rewind(file);
for (WordIndex i = 0; i < unigram_count; ++i) {
@@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
util::scoped_memory unigrams;
MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
- RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);
fixed_counts = finder.Counts();
}
unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
@@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
inputs[i-2].Rewind();
}
if (Quant::kTrain) {
- util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing");
+ util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0),
+ config.ProgressMessages(), "Quantizing");
for (unsigned char i = 2; i < counts.size(); ++i) {
TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
}
@@ -519,13 +520,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
}
- // Fill entries except unigram probabilities.
+ // Fill entries except unigram probabilities.
{
WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
- RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
}
- // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
+ // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
for (unsigned char order = 2; order <= counts.size(); ++order) {
const RecordReader &context = contexts[order - 2];
if (context) {
@@ -541,13 +542,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
}
/* Set ending offsets so the last entry will be sized properly */
- // Last entry for unigrams was already set.
+ // Last entry for unigrams was already set.
if (out.middle_begin_ != out.middle_end_) {
for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
i->FinishedLoading((i+1)->InsertIndex(), config);
}
(out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config);
- }
+ }
}
template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
@@ -595,7 +596,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ
} else {
temporary_prefix = file;
}
- // At least 1MB sorting memory.
+ // At least 1MB sorting memory.
SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);
BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 10b22ab1..1264baf5 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -44,8 +44,8 @@ template <class Quant, class Bhiksha> class TrieSearch {
Bhiksha::UpdateConfigFromBinary(fd, config);
}
- static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
+ static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
+ uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
}
diff --git a/klm/lm/sizes.cc b/klm/lm/sizes.cc
new file mode 100644
index 00000000..55ad586c
--- /dev/null
+++ b/klm/lm/sizes.cc
@@ -0,0 +1,63 @@
+#include "lm/sizes.hh"
+#include "lm/model.hh"
+#include "util/file_piece.hh"
+
+#include <vector>
+#include <iomanip>
+
+namespace lm {
+namespace ngram {
+
+void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config) {
+ uint64_t sizes[6];
+ sizes[0] = ProbingModel::Size(counts, config);
+ sizes[1] = RestProbingModel::Size(counts, config);
+ sizes[2] = TrieModel::Size(counts, config);
+ sizes[3] = QuantTrieModel::Size(counts, config);
+ sizes[4] = ArrayTrieModel::Size(counts, config);
+ sizes[5] = QuantArrayTrieModel::Size(counts, config);
+ uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
+ uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
+ uint64_t divide;
+ char prefix;
+ if (min_length < (1 << 10) * 10) {
+ prefix = ' ';
+ divide = 1;
+ } else if (min_length < (1 << 20) * 10) {
+ prefix = 'k';
+ divide = 1 << 10;
+ } else if (min_length < (1ULL << 30) * 10) {
+ prefix = 'M';
+ divide = 1 << 20;
+ } else {
+ prefix = 'G';
+ divide = 1 << 30;
+ }
+ long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide))));
+ std::cerr << "Memory estimate for binary LM:\ntype ";
+
+ // right align bytes.
+ for (long int i = 0; i < length - 2; ++i) std::cerr << ' ';
+
+ std::cerr << prefix << "B\n"
+ "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
+ "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n"
+ "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
+ "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
+ "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
+ "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
+}
+
+void ShowSizes(const std::vector<uint64_t> &counts) {
+ lm::ngram::Config config;
+ ShowSizes(counts, config);
+}
+
+void ShowSizes(const char *file, const lm::ngram::Config &config) {
+ std::vector<uint64_t> counts;
+ util::FilePiece f(file);
+ lm::ReadARPACounts(f, counts);
+ ShowSizes(counts, config);
+}
+
+}} //namespaces
diff --git a/klm/lm/sizes.hh b/klm/lm/sizes.hh
new file mode 100644
index 00000000..85abade7
--- /dev/null
+++ b/klm/lm/sizes.hh
@@ -0,0 +1,17 @@
+#ifndef LM_SIZES__
+#define LM_SIZES__
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace lm { namespace ngram {
+
+struct Config;
+
+void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config);
+void ShowSizes(const std::vector<uint64_t> &counts);
+void ShowSizes(const char *file, const lm::ngram::Config &config);
+
+}} // namespaces
+#endif // LM_SIZES__
diff --git a/klm/lm/sri_test.cc b/klm/lm/sri_test.cc
deleted file mode 100644
index e697d722..00000000
--- a/klm/lm/sri_test.cc
+++ /dev/null
@@ -1,65 +0,0 @@
-#include "lm/sri.hh"
-
-#include <stdlib.h>
-
-#define BOOST_TEST_MODULE SRITest
-#include <boost/test/unit_test.hpp>
-
-namespace lm {
-namespace sri {
-namespace {
-
-#define StartTest(word, ngram, score) \
- ret = model.FullScore( \
- state, \
- model.GetVocabulary().Index(word), \
- out);\
- BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
- BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
- BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
-
-#define AppendTest(word, ngram, score) \
- StartTest(word, ngram, score) \
- state = out;
-
-template <class M> void Starters(M &model) {
- FullScoreReturn ret;
- Model::State state(model.BeginSentenceState());
- Model::State out;
-
- StartTest("looking", 2, -0.4846522);
-
- // , probability plus <s> backoff
- StartTest(",", 1, -1.383514 + -0.4149733);
- // <unk> probability plus <s> backoff
- StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
-}
-
-template <class M> void Continuation(M &model) {
- FullScoreReturn ret;
- Model::State state(model.BeginSentenceState());
- Model::State out;
-
- AppendTest("looking", 2, -0.484652);
- AppendTest("on", 3, -0.348837);
- AppendTest("a", 4, -0.0155266);
- AppendTest("little", 5, -0.00306122);
- State preserve = state;
- AppendTest("the", 1, -4.04005);
- AppendTest("biarritz", 1, -1.9889);
- AppendTest("not_found", 0, -2.29666);
- AppendTest("more", 1, -1.20632);
- AppendTest(".", 2, -0.51363);
- AppendTest("</s>", 3, -0.0191651);
-
- state = preserve;
- AppendTest("more", 5, -0.00181395);
- AppendTest("loin", 5, -0.0432557);
-}
-
-BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); }
-BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); }
-
-} // namespace
-} // namespace sri
-} // namespace lm
diff --git a/klm/lm/state.hh b/klm/lm/state.hh
index 830e40aa..d8e6c132 100644
--- a/klm/lm/state.hh
+++ b/klm/lm/state.hh
@@ -47,6 +47,8 @@ class State {
unsigned char length;
};
+typedef State Right;
+
inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed);
}
@@ -54,14 +56,14 @@ inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
struct Left {
bool operator==(const Left &other) const {
return
- (length == other.length) &&
- pointers[length - 1] == other.pointers[length - 1] &&
- full == other.full;
+ length == other.length &&
+ (!length || (pointers[length - 1] == other.pointers[length - 1] && full == other.full));
}
int Compare(const Left &other) const {
if (length < other.length) return -1;
if (length > other.length) return 1;
+ if (length == 0) return 0; // Must be full.
if (pointers[length - 1] > other.pointers[length - 1]) return 1;
if (pointers[length - 1] < other.pointers[length - 1]) return -1;
return (int)full - (int)other.full;
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 0f1ca574..d9895f89 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -36,7 +36,7 @@ bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_
}
} // namespace
-std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
+uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits;
// Extra entry for next pointer at the end.
// +7 then / 8 to round up bits and convert to bytes
@@ -57,7 +57,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
max_vocab_ = max_vocab;
}
-template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
+template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index 034a1414..9ea3c546 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -49,7 +49,7 @@ class Unigram {
unigram_ = static_cast<UnigramValue*>(start);
}
- static std::size_t Size(uint64_t count) {
+ static uint64_t Size(uint64_t count) {
// +1 in case unknown doesn't appear. +1 for the final next.
return (count + 2) * sizeof(UnigramValue);
}
@@ -84,7 +84,7 @@ class BitPacked {
}
protected:
- static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
+ static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
@@ -99,7 +99,7 @@ class BitPacked {
template <class Bhiksha> class BitPackedMiddle : public BitPacked {
public:
- static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
+ static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
// next_source need not be initialized.
BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
@@ -128,7 +128,7 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked {
class BitPackedLongest : public BitPacked {
public:
- static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
+ static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
return BaseSize(entries, max_vocab, quant_bits);
}
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index 0d83221e..dc542bb3 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -22,12 +22,6 @@
namespace lm {
namespace ngram {
namespace trie {
-
-void WriteOrThrow(FILE *to, const void *data, size_t size) {
- assert(size);
- if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
-}
-
namespace {
typedef util::SizedIterator NGramIter;
@@ -71,13 +65,13 @@ class PartialViewProxy {
typedef util::ProxyIterator<PartialViewProxy> PartialIter;
-FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) {
- util::scoped_fd file(maker.Make());
+FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) {
+ util::scoped_fd file(util::MakeTemp(temp_prefix));
util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
return util::FDOpenOrThrow(file);
}
-FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) {
+FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) {
const size_t context_size = sizeof(WordIndex) * (order - 1);
// Sort just the contexts using the same memory.
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
@@ -90,17 +84,17 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make
#endif
(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
- util::scoped_FILE out(maker.MakeFile());
+ util::scoped_FILE out(util::FMakeTemp(temp_prefix));
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
if (context_begin == context_end) return out.release();
PartialIter i(context_begin);
- WriteOrThrow(out.get(), i->Data(), context_size);
+ util::WriteOrThrow(out.get(), i->Data(), context_size);
const void *previous = i->Data();
++i;
for (; i != context_end; ++i) {
if (memcmp(previous, i->Data(), context_size)) {
- WriteOrThrow(out.get(), i->Data(), context_size);
+ util::WriteOrThrow(out.get(), i->Data(), context_size);
previous = i->Data();
}
}
@@ -116,23 +110,23 @@ struct ThrowCombine {
// Useful for context files that just contain records with no value.
struct FirstCombine {
void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const {
- WriteOrThrow(out, first, entry_size);
+ util::WriteOrThrow(out, first, entry_size);
}
};
-template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) {
+template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) {
std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
RecordReader first, second;
first.Init(first_file, entry_size);
second.Init(second_file, entry_size);
- util::scoped_FILE out_file(maker.MakeFile());
+ util::scoped_FILE out_file(util::FMakeTemp(temp_prefix));
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
- WriteOrThrow(out_file.get(), first.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), first.Data(), entry_size);
++first;
} else if (less(second.Data(), first.Data())) {
- WriteOrThrow(out_file.get(), second.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), second.Data(), entry_size);
++second;
} else {
combine(entry_size, first.Data(), second.Data(), out_file.get());
@@ -140,7 +134,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
}
}
for (RecordReader &remains = (first ? first : second); remains; ++remains) {
- WriteOrThrow(out_file.get(), remains.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), remains.Data(), entry_size);
}
return out_file.release();
}
@@ -164,7 +158,7 @@ void RecordReader::Init(FILE *file, std::size_t entry_size) {
void RecordReader::Overwrite(const void *start, std::size_t amount) {
long internal = (uint8_t*)start - (uint8_t*)data_.get();
UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
- WriteOrThrow(file_, start, amount);
+ util::WriteOrThrow(file_, start, amount);
long forward = entry_size_ - internal - amount;
#if !defined(_WIN32) && !defined(_WIN64)
if (forward)
@@ -183,9 +177,8 @@ void RecordReader::Rewind() {
}
SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
- util::TempMaker maker(file_prefix);
PositiveProbWarn warn(config.positive_log_probability);
- unigram_.reset(maker.Make());
+ unigram_.reset(util::MakeTemp(file_prefix));
{
// In case <unk> appears.
size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff);
@@ -208,7 +201,7 @@ SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<u
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
for (unsigned char order = 2; order <= counts.size(); ++order) {
- ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer);
+ ConvertToSorted(f, vocab, counts, file_prefix, order, warn, mem.get(), buffer);
}
ReadEnd(f);
}
@@ -233,7 +226,7 @@ class Closer {
};
} // namespace
-void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
+void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of weights. Does it include backoff?
@@ -267,8 +260,8 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
std::sort
#endif
(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
- files.push_back(DiskFlush(begin, out_end, maker));
- contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order));
+ files.push_back(DiskFlush(begin, out_end, file_prefix));
+ contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order));
done += (out_end - begin) / entry_size;
}
@@ -276,10 +269,10 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
// All individual files created. Merge them.
while (files.size() > 1) {
- files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine()));
+ files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine()));
files_closer.PopFront();
files_closer.PopFront();
- contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine()));
+ contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine()));
contexts_closer.PopFront();
contexts_closer.PopFront();
}
diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh
index 1e6fce51..1afd9562 100644
--- a/klm/lm/trie_sort.hh
+++ b/klm/lm/trie_sort.hh
@@ -18,7 +18,6 @@
namespace util {
class FilePiece;
-class TempMaker;
} // namespace util
namespace lm {
@@ -29,8 +28,6 @@ struct Config;
namespace trie {
-void WriteOrThrow(FILE *to, const void *data, size_t size);
-
class EntryCompare : public std::binary_function<const void*, const void*, bool> {
public:
explicit EntryCompare(unsigned char order) : order_(order) {}
@@ -103,7 +100,7 @@ class SortedFiles {
}
private:
- void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size);
+ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size);
util::scoped_fd unigram_;
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 5de68f16..fd7f96dc 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -80,14 +80,14 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
-void WriteWordsWrapper::Write(int fd) {
- util::SeekEnd(fd);
+void WriteWordsWrapper::Write(int fd, uint64_t start) {
+ util::SeekOrThrow(fd, start);
util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
}
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
-std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) {
+uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
// Lead with the number of entries.
return sizeof(uint64_t) + sizeof(uint64_t) * entries;
}
@@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
}
*end_ = hashed;
if (enumerate_) {
- strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size());
+ void *copied = string_backing_.Allocate(str.size());
+ memcpy(copied, str.data(), str.size());
+ strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size());
}
++end_;
// This is 1 + the offset where it was inserted to make room for unk.
@@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
if (enumerate_) {
if (!strings_to_enumerate_.empty()) {
- util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
util::JointSort(begin_, end_, values);
}
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
@@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
enumerate_->Add(i + 1, strings_to_enumerate_[i]);
}
strings_to_enumerate_.clear();
+ string_backing_.FreeAll();
} else {
util::JointSort(begin_, end_, reorder_vocab + 1);
}
@@ -165,7 +168,7 @@ struct ProbingVocabularyHeader {
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
-std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
+uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) {
return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);
}
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index a25432f9..3902f117 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -4,6 +4,7 @@
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
+#include "util/pool.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
@@ -35,7 +36,7 @@ class WriteWordsWrapper : public EnumerateVocab {
void Add(WordIndex index, const StringPiece &str);
- void Write(int fd);
+ void Write(int fd, uint64_t start);
private:
EnumerateVocab *inner_;
@@ -62,7 +63,7 @@ class SortedVocabulary : public base::Vocabulary {
}
// Size for purposes of file writing
- static size_t Size(std::size_t entries, const Config &config);
+ static uint64_t Size(uint64_t entries, const Config &config);
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
WordIndex Bound() const { return bound_; }
@@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary {
EnumerateVocab *enumerate_;
// Actual strings. Used only when loading from ARPA and enumerate_ != NULL
- std::vector<std::string> strings_to_enumerate_;
+ util::Pool string_backing_;
+
+ std::vector<StringPiece> strings_to_enumerate_;
};
#pragma pack(push)
@@ -129,7 +132,7 @@ class ProbingVocabulary : public base::Vocabulary {
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
}
- static size_t Size(std::size_t entries, const Config &config);
+ static uint64_t Size(uint64_t entries, const Config &config);
// Vocab words are [0, Bound()).
WordIndex Bound() const { return bound_; }
diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh
index 67841c30..e09557a7 100644
--- a/klm/lm/word_index.hh
+++ b/klm/lm/word_index.hh
@@ -2,8 +2,11 @@
#ifndef LM_WORD_INDEX__
#define LM_WORD_INDEX__
+#include <limits.h>
+
namespace lm {
typedef unsigned int WordIndex;
+const WordIndex kMaxWordIndex = UINT_MAX;
} // namespace lm
typedef lm::WordIndex LMWordIndex;