summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
authorKenneth Heafield <kenlm@kheafield.com>2012-02-28 17:23:55 -0500
committerKenneth Heafield <kenlm@kheafield.com>2012-02-28 17:23:55 -0500
commit00d0b3b462bd9fd230b25b3ab6021c197980bdff (patch)
tree2c9e416b3c6f5579f14d00cf193b6d5768519c7d /klm/lm
parent5c63dae2edca73b2fa1c668d708b8b0c3ff1f7dc (diff)
Subject: where's my kenlm update?? From: Chris Dyer <cdyer@cs.cmu.edu>
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/bhiksha.cc7
-rw-r--r--klm/lm/bhiksha.hh2
-rw-r--r--klm/lm/binary_format.cc139
-rw-r--r--klm/lm/binary_format.hh14
-rw-r--r--klm/lm/blank.hh2
-rw-r--r--klm/lm/build_binary.cc35
-rw-r--r--klm/lm/config.cc1
-rw-r--r--klm/lm/config.hh8
-rw-r--r--klm/lm/left_test.cc11
-rw-r--r--klm/lm/model.cc17
-rw-r--r--klm/lm/model.hh11
-rw-r--r--klm/lm/model_test.cc24
-rw-r--r--klm/lm/ngram_query.cc145
-rw-r--r--klm/lm/ngram_query.hh103
-rw-r--r--klm/lm/quantize.cc38
-rw-r--r--klm/lm/quantize.hh2
-rw-r--r--klm/lm/read_arpa.cc2
-rw-r--r--klm/lm/return.hh2
-rw-r--r--klm/lm/search_hashed.cc22
-rw-r--r--klm/lm/search_hashed.hh63
-rw-r--r--klm/lm/search_trie.cc88
-rw-r--r--klm/lm/search_trie.hh10
-rw-r--r--klm/lm/trie.hh2
-rw-r--r--klm/lm/trie_sort.cc217
-rw-r--r--klm/lm/trie_sort.hh55
-rw-r--r--klm/lm/vocab.cc53
-rw-r--r--klm/lm/vocab.hh36
27 files changed, 648 insertions, 461 deletions
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc
index bf86fd4b..cdeafb47 100644
--- a/klm/lm/bhiksha.cc
+++ b/klm/lm/bhiksha.cc
@@ -1,5 +1,6 @@
#include "lm/bhiksha.hh"
#include "lm/config.hh"
+#include "util/file.hh"
#include <limits>
@@ -12,12 +13,12 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_
const uint8_t kArrayBhikshaVersion = 0;
+// TODO: put this in binary file header instead when I change the binary file format again.
void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) {
uint8_t version;
uint8_t configured_bits;
- if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) {
- UTIL_THROW(util::ErrnoException, "Could not read from binary file");
- }
+ util::ReadOrThrow(fd, &version, 1);
+ util::ReadOrThrow(fd, &configured_bits, 1);
if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);
config.pointer_bhiksha_bits = configured_bits;
}
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
index 3df43dda..5182ee2e 100644
--- a/klm/lm/bhiksha.hh
+++ b/klm/lm/bhiksha.hh
@@ -13,7 +13,7 @@
#ifndef LM_BHIKSHA__
#define LM_BHIKSHA__
-#include <inttypes.h>
+#include <stdint.h>
#include <assert.h>
#include "lm/model_type.hh"
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 27cada13..4796f6d1 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -1,19 +1,15 @@
#include "lm/binary_format.hh"
#include "lm/lm_exception.hh"
+#include "util/file.hh"
#include "util/file_piece.hh"
+#include <cstddef>
+#include <cstring>
#include <limits>
#include <string>
-#include <fcntl.h>
-#include <errno.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/mman.h>
-#include <sys/types.h>
-#include <sys/stat.h>
-#include <unistd.h>
+#include <stdint.h>
namespace lm {
namespace ngram {
@@ -24,14 +20,16 @@ const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 5;
-// Test values.
-struct Sanity {
+// Old binary files built on 32-bit machines have this header.
+// TODO: eliminate with next binary release.
+struct OldSanity {
char magic[sizeof(kMagicBytes)];
float zero_f, one_f, minus_half_f;
WordIndex one_word_index, max_word_index;
uint64_t one_uint64;
void SetToReference() {
+ std::memset(this, 0, sizeof(OldSanity));
std::memcpy(magic, kMagicBytes, sizeof(magic));
zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
one_word_index = 1;
@@ -40,27 +38,35 @@ struct Sanity {
}
};
-const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
-std::size_t TotalHeaderSize(unsigned char order) {
- return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
-}
+// Test values aligned to 8 bytes.
+struct Sanity {
+ char magic[ALIGN8(sizeof(kMagicBytes))];
+ float zero_f, one_f, minus_half_f;
+ WordIndex one_word_index, max_word_index, padding_to_8;
+ uint64_t one_uint64;
-void ReadLoop(int fd, void *to_void, std::size_t size) {
- uint8_t *to = static_cast<uint8_t*>(to_void);
- while (size) {
- ssize_t ret = read(fd, to, size);
- if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file");
- if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short");
- to += ret;
- size -= ret;
+ void SetToReference() {
+ std::memset(this, 0, sizeof(Sanity));
+ std::memcpy(magic, kMagicBytes, sizeof(kMagicBytes));
+ zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
+ one_word_index = 1;
+ max_word_index = std::numeric_limits<WordIndex>::max();
+ padding_to_8 = 0;
+ one_uint64 = 1;
}
+};
+
+const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
+
+std::size_t TotalHeaderSize(unsigned char order) {
+ return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
void WriteHeader(void *to, const Parameters &params) {
Sanity header = Sanity();
header.SetToReference();
- memcpy(to, &header, sizeof(Sanity));
+ std::memcpy(to, &header, sizeof(Sanity));
char *out = reinterpret_cast<char*>(to) + sizeof(Sanity);
*reinterpret_cast<FixedWidthParameters*>(out) = params.fixed;
@@ -74,14 +80,6 @@ void WriteHeader(void *to, const Parameters &params) {
} // namespace
-void SeekOrThrow(int fd, off_t off) {
- if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed");
-}
-
-void AdvanceOrThrow(int fd, off_t off) {
- if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed");
-}
-
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
@@ -89,7 +87,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
- backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ util::MapAnonymous(memory_size, backing.vocab);
return reinterpret_cast<uint8_t*>(backing.vocab.get());
}
}
@@ -98,42 +96,58 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
// Grow the file to accomodate the search, using zeros.
- if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size))
- UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed");
+ try {
+ util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
+ } catch (util::ErrnoException &e) {
+ e << " for file " << config.write_mmap;
+ throw e;
+ }
+ if (config.write_method == Config::WRITE_AFTER) {
+ util::MapAnonymous(memory_size, backing.search);
+ return reinterpret_cast<uint8_t*>(backing.search.get());
+ }
+ // mmap it now.
// We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
- off_t page_size = sysconf(_SC_PAGE_SIZE);
- off_t alignment_cruft = adjusted_vocab % page_size;
+ std::size_t page_size = util::SizePage();
+ std::size_t alignment_cruft = adjusted_vocab % page_size;
backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
-
return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
} else {
- backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ util::MapAnonymous(memory_size, backing.search);
return reinterpret_cast<uint8_t*>(backing.search.get());
}
}
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing) {
- if (config.write_mmap) {
- if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC))
- UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap);
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params;
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- params.fixed.search_version = search_version;
- WriteHeader(backing.vocab.get(), params);
+void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
+ if (!config.write_mmap) return;
+ util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
+ switch (config.write_method) {
+ case Config::WRITE_MMAP:
+ util::SyncOrThrow(backing.search.get(), backing.search.size());
+ break;
+ case Config::WRITE_AFTER:
+ util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
+ util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
+ util::FSyncOrThrow(backing.file.get());
+ break;
}
+ // header and vocab share the same mmap. The header is written here because we know the counts.
+ Parameters params = Parameters();
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ params.fixed.search_version = search_version;
+ WriteHeader(backing.vocab.get(), params);
}
namespace detail {
bool IsBinaryFormat(int fd) {
- const off_t size = util::SizeFile(fd);
- if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(Sanity)))) return false;
+ const uint64_t size = util::SizeFile(fd);
+ if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
// Try reading the header.
util::scoped_memory memory;
try {
@@ -154,19 +168,23 @@ bool IsBinaryFormat(int fd) {
if ((end_ptr != begin_version) && version != kMagicVersion) {
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
+
+ OldSanity old_sanity = OldSanity();
+ old_sanity.SetToReference();
+ UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;
}
void ReadHeader(int fd, Parameters &out) {
- SeekOrThrow(fd, sizeof(Sanity));
- ReadLoop(fd, &out.fixed, sizeof(out.fixed));
+ util::SeekOrThrow(fd, sizeof(Sanity));
+ util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed));
if (out.fixed.probing_multiplier < 1.0)
UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
out.counts.resize(static_cast<std::size_t>(out.fixed.order));
- ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
+ if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
}
void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params) {
@@ -179,11 +197,11 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
}
void SeekPastHeader(int fd, const Parameters &params) {
- SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
+ util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
}
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
- const off_t file_size = util::SizeFile(backing.file.get());
+ const uint64_t file_size = util::SizeFile(backing.file.get());
// The header is smaller than a page, so we have to map the whole header as well.
std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size;
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
@@ -194,9 +212,8 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t
if (config.enumerate_vocab && !params.fixed.has_vocabulary)
UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
- if (config.enumerate_vocab) {
- SeekOrThrow(backing.file.get(), total_map);
- }
+ // Seek to vocabulary words
+ util::SeekOrThrow(backing.file.get(), total_map);
return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
}
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index e9df0892..dd795f62 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -12,7 +12,7 @@
#include <cstddef>
#include <vector>
-#include <inttypes.h>
+#include <stdint.h>
namespace lm {
namespace ngram {
@@ -33,10 +33,8 @@ struct FixedWidthParameters {
unsigned int search_version;
};
-inline std::size_t Align8(std::size_t in) {
- std::size_t off = in % 8;
- return off ? (in + 8 - off) : in;
-}
+// This is a macro instead of an inline function so constants can be assigned using it.
+#define ALIGN8(a) ((std::ptrdiff_t(((a)-1)/8)+1)*8)
// Parameters stored in the header of a binary file.
struct Parameters {
@@ -53,10 +51,6 @@ struct Backing {
util::scoped_memory search;
};
-void SeekOrThrow(int fd, off_t off);
-// Seek forward
-void AdvanceOrThrow(int fd, off_t off);
-
// Create just enough of a binary file to write vocabulary to it.
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
@@ -64,7 +58,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
// Write header to binary file. This is done last to prevent incomplete files
// from loading.
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing);
+void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing);
namespace detail {
diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh
index 2fb64cd0..4da81209 100644
--- a/klm/lm/blank.hh
+++ b/klm/lm/blank.hh
@@ -3,7 +3,7 @@
#include <limits>
-#include <inttypes.h>
+#include <stdint.h>
#include <math.h>
namespace lm {
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index fdb62a71..8cbb69d0 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -8,18 +8,24 @@
#include <math.h>
#include <stdlib.h>
-#include <unistd.h>
+
+#ifdef WIN32
+#include "util/getopt.hh"
+#endif
namespace lm {
namespace ngram {
namespace {
void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
-"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n"
+"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
+"-w mmap|after determines how writing is done.\n"
+" mmap maps the binary file and writes to it. Default for trie.\n"
+" after allocates anonymous memory, builds, and writes. Default for probing.\n\n"
"type is either probing or trie. Default is probing.\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
@@ -55,7 +61,7 @@ uint8_t ParseBitCount(const char *from) {
unsigned long val = ParseUInt(from);
if (val > 25) {
util::ParseNumberException e(from);
- e << " bit counts are limited to 256.";
+ e << " bit counts are limited to 25.";
}
return val;
}
@@ -87,7 +93,7 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
prefix = 'G';
divide = 1 << 30;
}
- long int length = std::max<long int>(2, lrint(ceil(log10(max_length / divide))));
+ long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide))));
std::cout << "Memory estimate:\ntype ";
// right align bytes.
for (long int i = 0; i < length - 2; ++i) std::cout << ' ';
@@ -112,10 +118,10 @@ int main(int argc, char *argv[]) {
using namespace lm::ngram;
try {
- bool quantize = false, set_backoff_bits = false, bhiksha = false;
+ bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false;
lm::ngram::Config config;
int opt;
- while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) {
+ while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:si")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -129,6 +135,7 @@ int main(int argc, char *argv[]) {
case 'a':
config.pointer_bhiksha_bits = ParseBitCount(optarg);
bhiksha = true;
+ break;
case 'u':
config.unknown_missing_logprob = ParseFloat(optarg);
break;
@@ -141,6 +148,16 @@ int main(int argc, char *argv[]) {
case 'm':
config.building_memory = ParseUInt(optarg) * 1048576;
break;
+ case 'w':
+ set_write_method = true;
+ if (!strcmp(optarg, "mmap")) {
+ config.write_method = Config::WRITE_MMAP;
+ } else if (!strcmp(optarg, "after")) {
+ config.write_method = Config::WRITE_AFTER;
+ } else {
+ Usage(argv[0]);
+ }
+ break;
case 's':
config.sentence_marker_missing = lm::SILENT;
break;
@@ -166,9 +183,11 @@ int main(int argc, char *argv[]) {
const char *from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
if (!strcmp(model_type, "probing")) {
+ if (!set_write_method) config.write_method = Config::WRITE_AFTER;
if (quantize || set_backoff_bits) ProbingQuantizationUnsupported();
ProbingModel(from_file, config);
} else if (!strcmp(model_type, "trie")) {
+ if (!set_write_method) config.write_method = Config::WRITE_MMAP;
if (quantize) {
if (bhiksha) {
QuantArrayTrieModel(from_file, config);
@@ -191,7 +210,9 @@ int main(int argc, char *argv[]) {
}
catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
+ std::cerr << "ERROR" << std::endl;
return 1;
}
+ std::cerr << "SUCCESS" << std::endl;
return 0;
}
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index 297589a4..dbe762b3 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -17,6 +17,7 @@ Config::Config() :
temporary_directory_prefix(NULL),
arpa_complain(ALL),
write_mmap(NULL),
+ write_method(WRITE_AFTER),
include_vocab(true),
prob_bits(8),
backoff_bits(8),
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index 8564661b..01b75632 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -70,9 +70,17 @@ struct Config {
// to NULL to disable.
const char *write_mmap;
+ typedef enum {
+ WRITE_MMAP, // Map the file directly.
+ WRITE_AFTER // Write after we're done.
+ } WriteMethod;
+ WriteMethod write_method;
+
// Include the vocab in the binary file? Only effective if write_mmap != NULL.
bool include_vocab;
+
+
// Quantization options. Only effective for QuantTrieModel. One value is
// reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
// to quantize (and one of the remaining backoffs will be 0).
diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc
index 8bb91cb3..c85e5efa 100644
--- a/klm/lm/left_test.cc
+++ b/klm/lm/left_test.cc
@@ -142,7 +142,7 @@ template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &wo
template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vector<WordIndex> &out) {
out.clear();
- for (util::PieceIterator<' '> i(str); i; ++i) {
+ for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) {
out.push_back(m.GetVocabulary().Index(*i));
}
}
@@ -326,10 +326,17 @@ template <class M> void FullGrow(const M &m) {
}
}
+const char *FileLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ return "test.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[1];
+}
+
template <class M> void Everything() {
Config config;
config.messages = NULL;
- M m("test.arpa", config);
+ M m(FileLocation(), config);
Short(m);
Charge(m);
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index e4c1ec1d..478ebed1 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -46,7 +46,7 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
SetupMemory(start, params.counts, config);
- vocab_.LoadedBinary(fd, config.enumerate_vocab);
+ vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
search_.LoadedBinary();
}
@@ -82,13 +82,18 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.unigram.Unknown().backoff = 0.0;
search_.unigram.Unknown().prob = config.unknown_missing_logprob;
}
- FinishFile(config, kModelType, kVersion, counts, backing_);
+ FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
} catch (util::Exception &e) {
e << " Byte: " << f.Offset();
throw;
}
}
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
+ util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
+ Search::UpdateConfigFromBinary(fd, counts, config);
+}
+
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) {
@@ -114,7 +119,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
float backoff;
// i is the order of the backoff we're looking for.
- const Middle *mid_iter = search_.MiddleBegin() + start - 2;
+ typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2;
for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) {
if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break;
ret.prob += backoff;
@@ -134,7 +139,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored);
out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
float *backoff_out = out_state.backoff + 1;
- const typename Search::Middle *mid = search_.MiddleBegin();
+ typename Search::MiddleIter mid(search_.MiddleBegin());
for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
@@ -161,7 +166,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// If this function is called, then it does depend on left words.
ret.independent_left = false;
ret.extend_left = extend_pointer;
- const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1;
+ typename Search::MiddleIter mid_iter(search_.MiddleBegin() + extend_length - 1);
const WordIndex *i = add_rbegin;
for (; ; ++i, ++backoff_out, ++mid_iter) {
if (i == add_rend) {
@@ -230,7 +235,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// Ok start by looking up the bigram.
const WordIndex *hist_iter = context_rbegin;
- const typename Search::Middle *mid_iter = search_.MiddleBegin();
+ typename Search::MiddleIter mid_iter(search_.MiddleBegin());
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == context_rend) {
// Ran out of history. Typically no backoff, but this could be a blank.
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index c278acd6..6ea62a78 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -90,7 +90,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
* TrieModel. To classify binary files, call RecognizeBinary in
* lm/binary_format.hh.
*/
- GenericModel(const char *file, const Config &config = Config());
+ explicit GenericModel(const char *file, const Config &config = Config());
/* Score p(new_word | in_state) and incorporate new_word into out_state.
* Note that in_state and out_state must be different references:
@@ -137,14 +137,9 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
unsigned char &next_use) const;
private:
- friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
+ friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
- Search::UpdateConfigFromBinary(fd, counts, config);
- }
-
- float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;
+ static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 2654071f..461704d4 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -19,6 +19,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) {
namespace {
+const char *TestLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ return "test.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[1];
+}
+const char *TestNoUnkLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 3) {
+ return "test_nounk.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[2];
+
+}
+
#define StartTest(word, ngram, score, indep_left) \
ret = model.FullScore( \
state, \
@@ -307,7 +321,7 @@ template <class ModelT> void LoadingTest() {
{
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
- ModelT m("test.arpa", config);
+ ModelT m(TestLocation(), config);
enumerate.Check(m.GetVocabulary());
BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
Everything(m);
@@ -315,7 +329,7 @@ template <class ModelT> void LoadingTest() {
{
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
- ModelT m("test_nounk.arpa", config);
+ ModelT m(TestNoUnkLocation(), config);
enumerate.Check(m.GetVocabulary());
BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
NoUnkCheck(m);
@@ -346,7 +360,7 @@ template <class ModelT> void BinaryTest() {
config.enumerate_vocab = &enumerate;
{
- ModelT copy_model("test.arpa", config);
+ ModelT copy_model(TestLocation(), config);
enumerate.Check(copy_model.GetVocabulary());
enumerate.Clear();
Everything(copy_model);
@@ -370,14 +384,14 @@ template <class ModelT> void BinaryTest() {
config.messages = NULL;
enumerate.Clear();
{
- ModelT copy_model("test_nounk.arpa", config);
+ ModelT copy_model(TestNoUnkLocation(), config);
enumerate.Check(copy_model.GetVocabulary());
enumerate.Clear();
NoUnkCheck(copy_model);
}
config.write_mmap = NULL;
{
- ModelT binary("test_nounk.binary", config);
+ ModelT binary(TestNoUnkLocation(), config);
enumerate.Check(binary.GetVocabulary());
NoUnkCheck(binary);
}
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index d9db4aa2..8f7a0e1c 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -1,87 +1,4 @@
-#include "lm/enumerate_vocab.hh"
-#include "lm/model.hh"
-
-#include <cstdlib>
-#include <fstream>
-#include <iostream>
-#include <string>
-
-#include <ctype.h>
-
-#include <sys/resource.h>
-#include <sys/time.h>
-
-float FloatSec(const struct timeval &tv) {
- return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000000.0);
-}
-
-void PrintUsage(const char *message) {
- struct rusage usage;
- if (getrusage(RUSAGE_SELF, &usage)) {
- perror("getrusage");
- return;
- }
- std::cerr << message;
- std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n';
-
- // Linux doesn't set memory usage :-(.
- std::ifstream status("/proc/self/status", std::ios::in);
- std::string line;
- while (getline(status, line)) {
- if (!strncmp(line.c_str(), "VmRSS:\t", 7)) {
- std::cerr << "rss " << (line.c_str() + 7) << '\n';
- break;
- }
- }
-}
-
-template <class Model> void Query(const Model &model, bool sentence_context) {
- PrintUsage("Loading statistics:\n");
- typename Model::State state, out;
- lm::FullScoreReturn ret;
- std::string word;
-
- while (std::cin) {
- state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
- float total = 0.0;
- bool got = false;
- unsigned int oov = 0;
- while (std::cin >> word) {
- got = true;
- lm::WordIndex vocab = model.GetVocabulary().Index(word);
- if (vocab == 0) ++oov;
- ret = model.FullScore(state, vocab, out);
- total += ret.prob;
- std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
- state = out;
- char c;
- while (true) {
- c = std::cin.get();
- if (!std::cin) break;
- if (c == '\n') break;
- if (!isspace(c)) {
- std::cin.unget();
- break;
- }
- }
- if (c == '\n') break;
- }
- if (!got && !std::cin) break;
- if (sentence_context) {
- ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
- total += ret.prob;
- std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
- }
- std::cout << "Total: " << total << " OOV: " << oov << '\n';
- }
- PrintUsage("After queries:\n");
-}
-
-template <class Model> void Query(const char *name) {
- lm::ngram::Config config;
- Model model(name, config);
- Query(model);
-}
+#include "lm/ngram_query.hh"
int main(int argc, char *argv[]) {
if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) {
@@ -89,34 +6,40 @@ int main(int argc, char *argv[]) {
std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl;
return 1;
}
- bool sentence_context = (argc == 2);
- lm::ngram::ModelType model_type;
- if (lm::ngram::RecognizeBinary(argv[1], model_type)) {
- switch(model_type) {
- case lm::ngram::HASH_PROBING:
- Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
- break;
- case lm::ngram::TRIE_SORTED:
- Query<lm::ngram::TrieModel>(argv[1], sentence_context);
- break;
- case lm::ngram::QUANT_TRIE_SORTED:
- Query<lm::ngram::QuantTrieModel>(argv[1], sentence_context);
- break;
- case lm::ngram::ARRAY_TRIE_SORTED:
- Query<lm::ngram::ArrayTrieModel>(argv[1], sentence_context);
- break;
- case lm::ngram::QUANT_ARRAY_TRIE_SORTED:
- Query<lm::ngram::QuantArrayTrieModel>(argv[1], sentence_context);
- break;
- case lm::ngram::HASH_SORTED:
- default:
- std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
- abort();
+ try {
+ bool sentence_context = (argc == 2);
+ using namespace lm::ngram;
+ ModelType model_type;
+ if (RecognizeBinary(argv[1], model_type)) {
+ switch(model_type) {
+ case HASH_PROBING:
+ Query<lm::ngram::ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ break;
+ case TRIE_SORTED:
+ Query<TrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ break;
+ case QUANT_TRIE_SORTED:
+ Query<QuantTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ break;
+ case ARRAY_TRIE_SORTED:
+ Query<ArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ break;
+ case QUANT_ARRAY_TRIE_SORTED:
+ Query<QuantArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ break;
+ case HASH_SORTED:
+ default:
+ std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
+ abort();
+ }
+ } else {
+ Query<ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
}
- } else {
- Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
- }
- PrintUsage("Total time including destruction:\n");
+ PrintUsage("Total time including destruction:\n");
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
+ }
return 0;
}
diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh
new file mode 100644
index 00000000..4990df22
--- /dev/null
+++ b/klm/lm/ngram_query.hh
@@ -0,0 +1,103 @@
+#ifndef LM_NGRAM_QUERY__
+#define LM_NGRAM_QUERY__
+
+#include "lm/enumerate_vocab.hh"
+#include "lm/model.hh"
+
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include <ctype.h>
+#if !defined(_WIN32) && !defined(_WIN64)
+#include <sys/resource.h>
+#include <sys/time.h>
+#endif
+
+namespace lm {
+namespace ngram {
+
+#if !defined(_WIN32) && !defined(_WIN64)
+float FloatSec(const struct timeval &tv) {
+ return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000000.0);
+}
+#endif
+
+void PrintUsage(const char *message) {
+#if !defined(_WIN32) && !defined(_WIN64)
+ struct rusage usage;
+ if (getrusage(RUSAGE_SELF, &usage)) {
+ perror("getrusage");
+ return;
+ }
+ std::cerr << message;
+ std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n';
+
+ // Linux doesn't set memory usage :-(.
+ std::ifstream status("/proc/self/status", std::ios::in);
+ std::string line;
+ while (getline(status, line)) {
+ if (!strncmp(line.c_str(), "VmRSS:\t", 7)) {
+ std::cerr << "rss " << (line.c_str() + 7) << '\n';
+ break;
+ }
+ }
+#endif
+}
+
+template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
+ PrintUsage("Loading statistics:\n");
+ typename Model::State state, out;
+ lm::FullScoreReturn ret;
+ std::string word;
+
+ while (in_stream) {
+ state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
+ float total = 0.0;
+ bool got = false;
+ unsigned int oov = 0;
+ while (in_stream >> word) {
+ got = true;
+ lm::WordIndex vocab = model.GetVocabulary().Index(word);
+ if (vocab == 0) ++oov;
+ ret = model.FullScore(state, vocab, out);
+ total += ret.prob;
+ out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
+ state = out;
+ char c;
+ while (true) {
+ c = in_stream.get();
+ if (!in_stream) break;
+ if (c == '\n') break;
+ if (!isspace(c)) {
+ in_stream.unget();
+ break;
+ }
+ }
+ if (c == '\n') break;
+ }
+ if (!got && !in_stream) break;
+ if (sentence_context) {
+ ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
+ total += ret.prob;
+ out_stream << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
+ }
+ out_stream << "Total: " << total << " OOV: " << oov << '\n';
+ }
+ PrintUsage("After queries:\n");
+}
+
+template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
+ Config config;
+// config.load_method = util::LAZY;
+ M model(file, config);
+ Query(model, sentence_context, in_stream, out_stream);
+}
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_NGRAM_QUERY__
+
+
diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc
index 98a5d048..a8e0cb21 100644
--- a/klm/lm/quantize.cc
+++ b/klm/lm/quantize.cc
@@ -1,31 +1,30 @@
+/* Quantize into bins of equal size as described in
+ * M. Federico and N. Bertoldi. 2006. How many bits are needed
+ * to store probabilities for phrase-based translation? In Proc.
+ * of the Workshop on Statistical Machine Translation, pages
+ * 94–101, New York City, June. Association for Computa-
+ * tional Linguistics.
+ */
+
#include "lm/quantize.hh"
#include "lm/binary_format.hh"
#include "lm/lm_exception.hh"
+#include "util/file.hh"
#include <algorithm>
#include <numeric>
-#include <unistd.h>
-
namespace lm {
namespace ngram {
-/* Quantize into bins of equal size as described in
- * M. Federico and N. Bertoldi. 2006. How many bits are needed
- * to store probabilities for phrase-based translation? In Proc.
- * of the Workshop on Statistical Machine Translation, pages
- * 94–101, New York City, June. Association for Computa-
- * tional Linguistics.
- */
-
namespace {
-void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) {
- std::sort(values, values_end);
- const float *start = values, *finish;
+void MakeBins(std::vector<float> &values, float *centers, uint32_t bins) {
+ std::sort(values.begin(), values.end());
+ std::vector<float>::const_iterator start = values.begin(), finish;
for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) {
- finish = values + (((values_end - values) * static_cast<uint64_t>(i + 1)) / bins);
+ finish = values.begin() + ((values.size() * static_cast<uint64_t>(i + 1)) / bins);
if (finish == start) {
// zero length bucket.
*centers = i ? *(centers - 1) : -std::numeric_limits<float>::infinity();
@@ -41,10 +40,11 @@ const char kSeparatelyQuantizeVersion = 2;
void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) {
char version;
- if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1)
- UTIL_THROW(util::ErrnoException, "Failed to read header for quantization.");
+ util::ReadOrThrow(fd, &version, 1);
+ util::ReadOrThrow(fd, &config.prob_bits, 1);
+ util::ReadOrThrow(fd, &config.backoff_bits, 1);
if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion);
- AdvanceOrThrow(fd, -3);
+ util::AdvanceOrThrow(fd, -3);
}
void SeparatelyQuantize::SetupMemory(void *start, const Config &config) {
@@ -66,12 +66,12 @@ void SeparatelyQuantize::Train(uint8_t order, std::vector<float> &prob, std::vec
float *centers = start_ + TableStart(order) + ProbTableLength();
*(centers++) = kNoExtensionBackoff;
*(centers++) = kExtensionBackoff;
- MakeBins(&*backoff.begin(), &*backoff.end(), centers, (1ULL << backoff_bits_) - 2);
+ MakeBins(backoff, centers, (1ULL << backoff_bits_) - 2);
}
void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) {
float *centers = start_ + TableStart(order);
- MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_));
+ MakeBins(prob, centers, (1ULL << prob_bits_));
}
void SeparatelyQuantize::FinishedLoading(const Config &config) {
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh
index 4cf4236e..6d130a57 100644
--- a/klm/lm/quantize.hh
+++ b/klm/lm/quantize.hh
@@ -9,7 +9,7 @@
#include <algorithm>
#include <vector>
-#include <inttypes.h>
+#include <stdint.h>
#include <iostream>
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index dce73f77..05f761be 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -8,7 +8,7 @@
#include <ctype.h>
#include <string.h>
-#include <inttypes.h>
+#include <stdint.h>
namespace lm {
diff --git a/klm/lm/return.hh b/klm/lm/return.hh
index 15571960..1b55091b 100644
--- a/klm/lm/return.hh
+++ b/klm/lm/return.hh
@@ -1,7 +1,7 @@
#ifndef LM_RETURN__
#define LM_RETURN__
-#include <inttypes.h>
+#include <stdint.h>
namespace lm {
/* Structure returned by scoring routines. */
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 247832b0..1d6fb5be 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -30,7 +30,7 @@ template <class Middle> class ActivateLowerMiddle {
// TODO: somehow get text of n-gram for this error message.
if (!modify_.UnsafeMutableFind(hash, i))
UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram");
- SetExtension(i->MutableValue().backoff);
+ SetExtension(i->value.backoff);
}
private:
@@ -65,7 +65,7 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign
blank.prob -= unigrams[vocab_ids[1]].backoff;
SetExtension(unigrams[vocab_ids[1]].backoff);
// Bigram including a unigram's backoff
- middle[0].Insert(Middle::Packing::Make(keys[0], blank));
+ middle[0].Insert(detail::ProbBackoffEntry::Make(keys[0], blank));
fix = 1;
} else {
for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]);
@@ -74,22 +74,24 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign
for (; fix <= n - 3; ++fix) {
typename Middle::MutableIterator gotit;
if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) {
- float &backoff = gotit->MutableValue().backoff;
+ float &backoff = gotit->value.backoff;
SetExtension(backoff);
blank.prob -= backoff;
}
- middle[fix].Insert(Middle::Packing::Make(keys[fix], blank));
+ middle[fix].Insert(detail::ProbBackoffEntry::Make(keys[fix], blank));
backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]);
}
}
template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) {
+ assert(n >= 2);
ReadNGramHeader(f, n);
- // vocab ids of words in reverse order
+ // Both vocab_ids and keys are non-empty because n >= 2.
+ // vocab ids of words in reverse order.
std::vector<WordIndex> vocab_ids(n);
std::vector<uint64_t> keys(n-1);
- typename Store::Packing::Value value;
+ typename Store::Entry::Value value;
typename Middle::MutableIterator found;
for (size_t i = 0; i < count; ++i) {
ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn);
@@ -100,7 +102,7 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
}
// Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
util::SetSign(value.prob);
- store.Insert(Store::Packing::Make(keys[n-2], value));
+ store.Insert(Store::Entry::Make(keys[n-2], value));
// Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
int lower;
util::FloatEnc fix_prob;
@@ -113,9 +115,9 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
}
if (middle[lower].UnsafeMutableFind(keys[lower], found)) {
// Turn off sign bit to indicate that it extends left.
- fix_prob.f = found->MutableValue().prob;
+ fix_prob.f = found->value.prob;
fix_prob.i &= ~util::kSignBit;
- found->MutableValue().prob = fix_prob.f;
+ found->value.prob = fix_prob.f;
// We don't need to recurse further down because this entry already set the bits for lower entries.
break;
}
@@ -147,7 +149,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
// TODO: fix sorted.
- SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config);
+ SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index e289fd11..4352c72d 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -8,7 +8,6 @@
#include "lm/weights.hh"
#include "util/bit_packing.hh"
-#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include <algorithm>
@@ -92,8 +91,10 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
- const Middle *MiddleBegin() const { return &*middle_.begin(); }
- const Middle *MiddleEnd() const { return &*middle_.end(); }
+ typedef typename std::vector<Middle>::const_iterator MiddleIter;
+
+ MiddleIter MiddleBegin() const { return middle_.begin(); }
+ MiddleIter MiddleEnd() const { return middle_.end(); }
Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
util::FloatEnc val;
@@ -105,7 +106,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl;
abort();
}
- val.f = found->GetValue().prob;
+ val.f = found->value.prob;
}
val.i |= util::kSignBit;
prob = val.f;
@@ -117,12 +118,12 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
typename Middle::ConstIterator found;
if (!middle.Find(node, found)) return false;
util::FloatEnc enc;
- enc.f = found->GetValue().prob;
+ enc.f = found->value.prob;
ret.independent_left = (enc.i & util::kSignBit);
ret.extend_left = node;
enc.i |= util::kSignBit;
ret.prob = enc.f;
- backoff = found->GetValue().backoff;
+ backoff = found->value.backoff;
return true;
}
@@ -132,7 +133,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
node = CombineWordHash(node, word);
typename Middle::ConstIterator found;
if (!middle.Find(node, found)) return false;
- backoff = found->GetValue().backoff;
+ backoff = found->value.backoff;
return true;
}
@@ -141,7 +142,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
node = CombineWordHash(node, word);
typename Longest::ConstIterator found;
if (!longest.Find(node, found)) return false;
- prob = found->GetValue().prob;
+ prob = found->value.prob;
return true;
}
@@ -160,14 +161,50 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
std::vector<Middle> middle_;
};
-// std::identity is an SGI extension :-(
-struct IdentityHash : public std::unary_function<uint64_t, size_t> {
- size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
+/* These look like perfect candidates for a template, right? Ancient gcc (4.1
+ * on RedHat stale linux) doesn't pack templates correctly. ProbBackoffEntry
+ * is a multiple of 8 bytes anyway. ProbEntry is 12 bytes so it's set to pack.
+ */
+struct ProbBackoffEntry {
+ uint64_t key;
+ ProbBackoff value;
+ typedef uint64_t Key;
+ typedef ProbBackoff Value;
+ uint64_t GetKey() const {
+ return key;
+ }
+ static ProbBackoffEntry Make(uint64_t key, ProbBackoff value) {
+ ProbBackoffEntry ret;
+ ret.key = key;
+ ret.value = value;
+ return ret;
+ }
};
+#pragma pack(push)
+#pragma pack(4)
+struct ProbEntry {
+ uint64_t key;
+ Prob value;
+ typedef uint64_t Key;
+ typedef Prob Value;
+ uint64_t GetKey() const {
+ return key;
+ }
+ static ProbEntry Make(uint64_t key, Prob value) {
+ ProbEntry ret;
+ ret.key = key;
+ ret.value = value;
+ return ret;
+ }
+};
+
+#pragma pack(pop)
+
+
struct ProbingHashedSearch : public TemplateHashedSearch<
- util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
- util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {
+ util::ProbingHashTable<ProbBackoffEntry, util::IdentityHash>,
+ util::ProbingHashTable<ProbEntry, util::IdentityHash> > {
static const ModelType kModelType = HASH_PROBING;
};
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 4bd3f4ee..ffadfa94 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -13,6 +13,7 @@
#include "lm/weights.hh"
#include "lm/word_index.hh"
#include "util/ersatz_progress.hh"
+#include "util/mmap.hh"
#include "util/proxy_iterator.hh"
#include "util/scoped.hh"
#include "util/sized_iterator.hh"
@@ -20,14 +21,15 @@
#include <algorithm>
#include <cstring>
#include <cstdio>
+#include <cstdlib>
#include <queue>
#include <limits>
#include <numeric>
#include <vector>
-#include <sys/mman.h>
-#include <sys/types.h>
-#include <sys/stat.h>
+#if defined(_WIN32) || defined(_WIN64)
+#include <windows.h>
+#endif
namespace lm {
namespace ngram {
@@ -195,7 +197,7 @@ class SRISucks {
void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) {
for (unsigned char i = 0; i < kMaxOrder - 1; ++i) {
- it_[i] = &*values_[i].begin();
+ it_[i] = values_[i].empty() ? NULL : &*values_[i].begin();
}
messages_[0].Apply(it_, unigram_file);
BackoffMessages *messages = messages_ + 1;
@@ -227,8 +229,8 @@ class SRISucks {
class FindBlanks {
public:
- FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
- : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {}
+ FindBlanks(unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
+ : counts_(order), unigrams_(unigrams), sri_(messages) {}
float UnigramProb(WordIndex index) const {
return unigrams_[index].prob;
@@ -248,7 +250,7 @@ class FindBlanks {
}
void Longest(const void * /*data*/) {
- ++*longest_counts_;
+ ++counts_.back();
}
// Unigrams wrote one past.
@@ -256,8 +258,12 @@ class FindBlanks {
--counts_[0];
}
+ const std::vector<uint64_t> &Counts() const {
+ return counts_;
+ }
+
private:
- uint64_t *const counts_, *const longest_counts_;
+ std::vector<uint64_t> counts_;
const ProbBackoff *unigrams_;
@@ -375,7 +381,7 @@ template <class Doing> class BlankManager {
template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) {
util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
- unsigned int unigram = 0;
+ WordIndex unigram = 0;
std::priority_queue<Gram> grams;
grams.push(Gram(&unigram, 1));
for (unsigned char i = 2; i <= total_order; ++i) {
@@ -461,42 +467,33 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c
} // namespace
-template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
RecordReader inputs[kMaxOrder - 1];
RecordReader contexts[kMaxOrder - 1];
for (unsigned char i = 2; i <= counts.size(); ++i) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
- inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff)));
- util::RemoveOrThrow(assembled.str().c_str());
- assembled << kContextSuffix;
- contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex));
- util::RemoveOrThrow(assembled.str().c_str());
+ inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff)));
+ contexts[i-2].Init(files.Context(i), (i-1) * sizeof(WordIndex));
}
SRISucks sri;
- std::vector<uint64_t> fixed_counts(counts.size());
+ std::vector<uint64_t> fixed_counts;
+ util::scoped_FILE unigram_file;
+ util::scoped_fd unigram_fd(files.StealUnigram());
{
- std::string temp(file_prefix); temp += "unigrams";
- util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str()));
util::scoped_memory unigrams;
- MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
- FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
+ MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
+ FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
+ fixed_counts = finder.Counts();
}
+ unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) {
if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
}
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
- util::scoped_FILE unigram_file;
- {
- std::string name(file_prefix + "unigrams");
- unigram_file.reset(OpenOrThrow(name.c_str(), "r+"));
- util::RemoveOrThrow(name.c_str());
- }
sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
@@ -587,42 +584,19 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBin
longest.LoadedBinary();
}
-namespace {
-bool IsDirectory(const char *path) {
- struct stat info;
- if (0 != stat(path, &info)) return false;
- return S_ISDIR(info.st_mode);
-}
-} // namespace
-
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
- std::string temporary_directory;
+ std::string temporary_prefix;
if (config.temporary_directory_prefix) {
- temporary_directory = config.temporary_directory_prefix;
- if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str()))
- temporary_directory += '/';
+ temporary_prefix = config.temporary_directory_prefix;
} else if (config.write_mmap) {
- temporary_directory = config.write_mmap;
+ temporary_prefix = config.write_mmap;
} else {
- temporary_directory = file;
- }
- // Null on end is kludge to ensure null termination.
- temporary_directory += "_trie_tmp_XXXXXX";
- temporary_directory += '\0';
- if (!mkdtemp(&temporary_directory[0])) {
- UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
+ temporary_prefix = file;
}
- // Chop off null kludge.
- temporary_directory.resize(strlen(temporary_directory.c_str()));
- // Add directory delimiter. Assumes a real operating system.
- temporary_directory += '/';
// At least 1MB sorting memory.
- ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
+ SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);
- BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing);
- if (rmdir(temporary_directory.c_str()) && config.messages) {
- *config.messages << "Failed to delete " << temporary_directory << std::endl;
- }
+ BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);
}
template class TrieSearch<DontQuantize, DontBhiksha>;
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 33ae8cff..5155ca02 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -7,6 +7,7 @@
#include "lm/trie.hh"
#include "lm/weights.hh"
+#include "util/file.hh"
#include "util/file_piece.hh"
#include <vector>
@@ -20,7 +21,8 @@ class SortedVocabulary;
namespace trie {
template <class Quant, class Bhiksha> class TrieSearch;
-template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+class SortedFiles;
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
template <class Quant, class Bhiksha> class TrieSearch {
public:
@@ -40,7 +42,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
Quant::UpdateConfigFromBinary(fd, counts, config);
- AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
Bhiksha::UpdateConfigFromBinary(fd, config);
}
@@ -60,6 +62,8 @@ template <class Quant, class Bhiksha> class TrieSearch {
void LoadedBinary();
+ typedef const Middle *MiddleIter;
+
const Middle *MiddleBegin() const { return middle_begin_; }
const Middle *MiddleEnd() const { return middle_end_; }
@@ -108,7 +112,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
}
private:
- friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index 06cc96ac..ebe9910f 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -1,7 +1,7 @@
#ifndef LM_TRIE__
#define LM_TRIE__
-#include <inttypes.h>
+#include <stdint.h>
#include <cstddef>
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index bb126f18..b80fed02 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -14,6 +14,7 @@
#include <algorithm>
#include <cstring>
#include <cstdio>
+#include <cstdlib>
#include <deque>
#include <limits>
#include <vector>
@@ -22,14 +23,6 @@ namespace lm {
namespace ngram {
namespace trie {
-const char *kContextSuffix = "_contexts";
-
-FILE *OpenOrThrow(const char *name, const char *mode) {
- FILE *ret = fopen(name, mode);
- if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode);
- return ret;
-}
-
void WriteOrThrow(FILE *to, const void *data, size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
@@ -78,28 +71,29 @@ class PartialViewProxy {
typedef util::ProxyIterator<PartialViewProxy> PartialIter;
-std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
- std::string ret(assembled.str());
- util::scoped_fd out(util::CreateOrThrow(ret.c_str()));
- util::WriteOrThrow(out.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
- return ret;
+FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) {
+ util::scoped_fd file(maker.Make());
+ util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
+ return util::FDOpenOrThrow(file);
}
-void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) {
+FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) {
const size_t context_size = sizeof(WordIndex) * (order - 1);
// Sort just the contexts using the same memory.
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
- std::sort(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
+#if defined(_WIN32) || defined(_WIN64)
+ std::stable_sort
+#else
+ std::sort
+#endif
+ (context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
- std::string name(ngram_file_name + kContextSuffix);
- util::scoped_FILE out(OpenOrThrow(name.c_str(), "w"));
+ util::scoped_FILE out(maker.MakeFile());
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
- if (context_begin == context_end) return;
+ if (context_begin == context_end) return out.release();
PartialIter i(context_begin);
WriteOrThrow(out.get(), i->Data(), context_size);
const void *previous = i->Data();
@@ -110,6 +104,7 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil
previous = i->Data();
}
}
+ return out.release();
}
struct ThrowCombine {
@@ -125,14 +120,12 @@ struct FirstCombine {
}
};
-template <class Combine> void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order, const Combine &combine = ThrowCombine()) {
+template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) {
std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
RecordReader first, second;
- first.Init(first_name.c_str(), entry_size);
- util::RemoveOrThrow(first_name.c_str());
- second.Init(second_name.c_str(), entry_size);
- util::RemoveOrThrow(second_name.c_str());
- util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w"));
+ first.Init(first_file, entry_size);
+ second.Init(second_file, entry_size);
+ util::scoped_FILE out_file(maker.MakeFile());
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
@@ -149,67 +142,14 @@ template <class Combine> void MergeSortedFiles(const std::string &first_name, co
for (RecordReader &remains = (first ? first : second); remains; ++remains) {
WriteOrThrow(out_file.get(), remains.Data(), entry_size);
}
-}
-
-void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) {
- ReadNGramHeader(f, order);
- const size_t count = counts[order - 1];
- // Size of weights. Does it include backoff?
- const size_t words_size = sizeof(WordIndex) * order;
- const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float));
- const size_t entry_size = words_size + weights_size;
- const size_t batch_size = std::min(count, mem.size() / entry_size);
- uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get());
- std::deque<std::string> files;
- for (std::size_t batch = 0, done = 0; done < count; ++batch) {
- uint8_t *out = begin;
- uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
- if (order == counts.size()) {
- for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn);
- }
- } else {
- for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
- }
- }
- // Sort full records by full n-gram.
- util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
- // parallel_sort uses too much RAM
- std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
- files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order));
- WriteContextFile(begin, out_end, files.back(), entry_size, order);
-
- done += (out_end - begin) / entry_size;
- }
-
- // All individual files created. Merge them.
-
- std::size_t merge_count = 0;
- while (files.size() > 1) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
- files.push_back(assembled.str());
- MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine());
- MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine());
- files.pop_front();
- files.pop_front();
- }
- if (!files.empty()) {
- std::stringstream assembled;
- assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";
- std::string merged_name(assembled.str());
- if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
- std::string context_name = files[0] + kContextSuffix;
- merged_name += kContextSuffix;
- if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str());
- }
+ return out_file.release();
}
} // namespace
-void RecordReader::Init(const std::string &name, std::size_t entry_size) {
- file_.reset(OpenOrThrow(name.c_str(), "r+"));
+void RecordReader::Init(FILE *file, std::size_t entry_size) {
+ rewind(file);
+ file_ = file;
data_.reset(malloc(entry_size));
UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer");
remains_ = true;
@@ -219,20 +159,29 @@ void RecordReader::Init(const std::string &name, std::size_t entry_size) {
void RecordReader::Overwrite(const void *start, std::size_t amount) {
long internal = (uint8_t*)start - (uint8_t*)data_.get();
- UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
- WriteOrThrow(file_.get(), start, amount);
+ UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
+ WriteOrThrow(file_, start, amount);
long forward = entry_size_ - internal - amount;
- if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision");
+#if !defined(_WIN32) && !defined(_WIN64)
+ if (forward)
+#endif
+ UTIL_THROW_IF(fseek(file_, forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision");
}
-void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+void RecordReader::Rewind() {
+ rewind(file_);
+ remains_ = true;
+ ++*this;
+}
+
+SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+ util::TempMaker maker(file_prefix);
PositiveProbWarn warn(config.positive_log_probability);
+ unigram_.reset(maker.Make());
{
- std::string unigram_name = file_prefix + "unigrams";
- util::scoped_fd unigram_file;
// In case <unk> appears.
- size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff);
- util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out);
+ size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff);
+ util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_.get(), size_out), size_out);
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
CheckSpecials(config, vocab);
if (!vocab.SawUnk()) ++counts[0];
@@ -246,16 +195,96 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin
buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
buffer = std::min<size_t>(buffer, buffer_use);
- util::scoped_memory mem;
- mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
+ util::scoped_malloc mem;
+ mem.reset(malloc(buffer));
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
for (unsigned char order = 2; order <= counts.size(); ++order) {
- ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn);
+ ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer);
}
ReadEnd(f);
}
+namespace {
+class Closer {
+ public:
+ explicit Closer(std::deque<FILE*> &files) : files_(files) {}
+
+ ~Closer() {
+ for (std::deque<FILE*>::iterator i = files_.begin(); i != files_.end(); ++i) {
+ util::scoped_FILE deleter(*i);
+ }
+ }
+
+ void PopFront() {
+ util::scoped_FILE deleter(files_.front());
+ files_.pop_front();
+ }
+ private:
+ std::deque<FILE*> &files_;
+};
+} // namespace
+
+void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
+ ReadNGramHeader(f, order);
+ const size_t count = counts[order - 1];
+ // Size of weights. Does it include backoff?
+ const size_t words_size = sizeof(WordIndex) * order;
+ const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float));
+ const size_t entry_size = words_size + weights_size;
+ const size_t batch_size = std::min(count, mem_size / entry_size);
+ uint8_t *const begin = reinterpret_cast<uint8_t*>(mem);
+
+ std::deque<FILE*> files, contexts;
+ Closer files_closer(files), contexts_closer(contexts);
+
+ for (std::size_t batch = 0, done = 0; done < count; ++batch) {
+ uint8_t *out = begin;
+ uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
+ if (order == counts.size()) {
+ for (; out != out_end; out += entry_size) {
+ ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn);
+ }
+ } else {
+ for (; out != out_end; out += entry_size) {
+ ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
+ }
+ }
+ // Sort full records by full n-gram.
+ util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
+ // parallel_sort uses too much RAM. TODO: figure out why windows sort doesn't like my proxies.
+#if defined(_WIN32) || defined(_WIN64)
+ std::stable_sort
+#else
+ std::sort
+#endif
+ (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
+ files.push_back(DiskFlush(begin, out_end, maker));
+ contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order));
+
+ done += (out_end - begin) / entry_size;
+ }
+
+ // All individual files created. Merge them.
+
+ while (files.size() > 1) {
+ files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine()));
+ files_closer.PopFront();
+ files_closer.PopFront();
+ contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine()));
+ contexts_closer.PopFront();
+ contexts_closer.PopFront();
+ }
+
+ if (!files.empty()) {
+ // Steal from closers.
+ full_[order - 2].reset(files.front());
+ files.pop_front();
+ context_[order - 2].reset(contexts.front());
+ contexts.pop_front();
+ }
+}
+
} // namespace trie
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh
index a6916483..3036319d 100644
--- a/klm/lm/trie_sort.hh
+++ b/klm/lm/trie_sort.hh
@@ -1,6 +1,9 @@
+// Step of trie builder: create sorted files.
+
#ifndef LM_TRIE_SORT__
#define LM_TRIE_SORT__
+#include "lm/max_order.hh"
#include "lm/word_index.hh"
#include "util/file.hh"
@@ -11,20 +14,21 @@
#include <string>
#include <vector>
-#include <inttypes.h>
+#include <stdint.h>
-namespace util { class FilePiece; }
+namespace util {
+class FilePiece;
+class TempMaker;
+} // namespace util
-// Step of trie builder: create sorted files.
namespace lm {
+class PositiveProbWarn;
namespace ngram {
class SortedVocabulary;
class Config;
namespace trie {
-extern const char *kContextSuffix;
-FILE *OpenOrThrow(const char *name, const char *mode);
void WriteOrThrow(FILE *to, const void *data, size_t size);
class EntryCompare : public std::binary_function<const void*, const void*, bool> {
@@ -49,15 +53,15 @@ class RecordReader {
public:
RecordReader() : remains_(true) {}
- void Init(const std::string &name, std::size_t entry_size);
+ void Init(FILE *file, std::size_t entry_size);
void *Data() { return data_.get(); }
const void *Data() const { return data_.get(); }
RecordReader &operator++() {
- std::size_t ret = fread(data_.get(), entry_size_, 1, file_.get());
+ std::size_t ret = fread(data_.get(), entry_size_, 1, file_);
if (!ret) {
- UTIL_THROW_IF(!feof(file_.get()), util::ErrnoException, "Error reading temporary file");
+ UTIL_THROW_IF(!feof(file_), util::ErrnoException, "Error reading temporary file");
remains_ = false;
}
return *this;
@@ -65,27 +69,46 @@ class RecordReader {
operator bool() const { return remains_; }
- void Rewind() {
- rewind(file_.get());
- remains_ = true;
- ++*this;
- }
+ void Rewind();
std::size_t EntrySize() const { return entry_size_; }
void Overwrite(const void *start, std::size_t amount);
private:
+ FILE *file_;
+
util::scoped_malloc data_;
bool remains_;
std::size_t entry_size_;
-
- util::scoped_FILE file_;
};
-void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab);
+class SortedFiles {
+ public:
+ // Build from ARPA
+ SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab);
+
+ int StealUnigram() {
+ return unigram_.release();
+ }
+
+ FILE *Full(unsigned char order) {
+ return full_[order - 2].get();
+ }
+
+ FILE *Context(unsigned char of_order) {
+ return context_[of_order - 2].get();
+ }
+
+ private:
+ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size);
+
+ util::scoped_fd unigram_;
+
+ util::scoped_FILE full_[kMaxOrder - 1], context_[kMaxOrder - 1];
+};
} // namespace trie
} // namespace ngram
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index ffec41ca..9fd698bb 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -6,12 +6,15 @@
#include "lm/config.hh"
#include "lm/weights.hh"
#include "util/exception.hh"
+#include "util/file.hh"
#include "util/joint_sort.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
#include <string>
+#include <string.h>
+
namespace lm {
namespace ngram {
@@ -29,23 +32,30 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
-WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
- if (!enumerate) return std::numeric_limits<WordIndex>::max();
+void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) {
+ // Check that we're at the right place by reading <unk> which is always first.
+ char check_unk[6];
+ util::ReadOrThrow(fd, check_unk, 6);
+ UTIL_THROW_IF(
+ memcmp(check_unk, "<unk>", 6),
+ FormatLoadException,
+ "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure.");
+ if (!enumerate) return;
+ enumerate->Add(0, "<unk>");
+
+ // Read all the words after unk.
const std::size_t kInitialRead = 16384;
std::string buf;
buf.reserve(kInitialRead + 100);
buf.resize(kInitialRead);
- WordIndex index = 0;
+ WordIndex index = 1; // Read <unk> already.
while (true) {
- ssize_t got = read(fd, &buf[0], kInitialRead);
- UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words");
- if (got == 0) return index;
+ std::size_t got = util::ReadOrEOF(fd, &buf[0], kInitialRead);
+ if (got == 0) break;
buf.resize(got);
while (buf[buf.size() - 1]) {
char next_char;
- ssize_t ret = read(fd, &next_char, 1);
- UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words");
- UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word.");
+ util::ReadOrThrow(fd, &next_char, 1);
buf.push_back(next_char);
}
// Ok now we have null terminated strings.
@@ -55,6 +65,8 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
i += length + 1 /* null byte */;
}
}
+
+ UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file.");
}
} // namespace
@@ -69,8 +81,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
}
void WriteWordsWrapper::Write(int fd) {
- if ((off_t)-1 == lseek(fd, 0, SEEK_END))
- UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words");
+ util::SeekEnd(fd);
util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
}
@@ -114,8 +125,10 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
if (enumerate_) {
- util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
- util::JointSort(begin_, end_, values);
+ if (!strings_to_enumerate_.empty()) {
+ util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::JointSort(begin_, end_, values);
+ }
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
// <unk> strikes again: +1 here.
enumerate_->Add(i + 1, strings_to_enumerate_[i]);
@@ -131,11 +144,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
-void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
- ReadWords(fd, to);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
+ if (have_words) ReadWords(fd, to, bound_);
}
namespace {
@@ -153,12 +166,12 @@ struct ProbingVocabularyHeader {
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
- return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);
+ return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);
}
void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
header_ = static_cast<detail::ProbingVocabularyHeader*>(start);
- lookup_ = Lookup(static_cast<uint8_t*>(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated);
+ lookup_ = Lookup(static_cast<uint8_t*>(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated);
bound_ = 1;
saw_unk_ = false;
}
@@ -178,7 +191,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
return 0;
} else {
if (enumerate_) enumerate_->Add(bound_, str);
- lookup_.Insert(Lookup::Packing::Make(hashed, bound_));
+ lookup_.Insert(ProbingVocabuaryEntry::Make(hashed, bound_));
return bound_++;
}
}
@@ -190,12 +203,12 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
-void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
lookup_.LoadedBinary();
- ReadWords(fd, to);
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
+ if (have_words) ReadWords(fd, to, bound_);
}
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 3c3414fb..06fdefe4 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -4,7 +4,6 @@
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
-#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
@@ -83,7 +82,7 @@ class SortedVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
private:
uint64_t *begin_, *end_;
@@ -100,6 +99,26 @@ class SortedVocabulary : public base::Vocabulary {
std::vector<std::string> strings_to_enumerate_;
};
+#pragma pack(push)
+#pragma pack(4)
+struct ProbingVocabuaryEntry {
+ uint64_t key;
+ WordIndex value;
+
+ typedef uint64_t Key;
+ uint64_t GetKey() const {
+ return key;
+ }
+
+ static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) {
+ ProbingVocabuaryEntry ret;
+ ret.key = key;
+ ret.value = value;
+ return ret;
+ }
+};
+#pragma pack(pop)
+
// Vocabulary storing a map from uint64_t to WordIndex.
class ProbingVocabulary : public base::Vocabulary {
public:
@@ -107,7 +126,7 @@ class ProbingVocabulary : public base::Vocabulary {
WordIndex Index(const StringPiece &str) const {
Lookup::ConstIterator i;
- return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0;
+ return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
}
static size_t Size(std::size_t entries, const Config &config);
@@ -124,17 +143,14 @@ class ProbingVocabulary : public base::Vocabulary {
void FinishedLoading(ProbBackoff *reorder_vocab);
+ std::size_t UnkCountChangePadding() const { return 0; }
+
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
private:
- // std::identity is an SGI extension :-(
- struct IdentityHash : public std::unary_function<uint64_t, std::size_t> {
- std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); }
- };
-
- typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup;
+ typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup;
Lookup lookup_;