summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
Diffstat (limited to 'klm')
-rw-r--r--klm/lm/Jamfile14
-rw-r--r--klm/lm/bhiksha.cc3
-rw-r--r--klm/lm/bhiksha.hh6
-rw-r--r--klm/lm/binary_format.cc20
-rw-r--r--klm/lm/binary_format.hh4
-rw-r--r--klm/lm/build_binary.cc18
-rw-r--r--klm/lm/fragment.cc37
-rw-r--r--klm/lm/left.hh2
-rw-r--r--klm/lm/max_order.cc6
-rw-r--r--klm/lm/max_order.hh26
-rw-r--r--klm/lm/model.cc26
-rw-r--r--klm/lm/model.hh3
-rw-r--r--klm/lm/partial.hh167
-rw-r--r--klm/lm/partial_test.cc199
-rw-r--r--klm/lm/quantize.hh12
-rw-r--r--klm/lm/read_arpa.cc34
-rw-r--r--klm/lm/search_hashed.cc2
-rw-r--r--klm/lm/search_hashed.hh8
-rw-r--r--klm/lm/search_trie.cc22
-rw-r--r--klm/lm/search_trie.hh4
-rw-r--r--klm/lm/sri_test.cc65
-rw-r--r--klm/lm/state.hh12
-rw-r--r--klm/lm/trie.cc4
-rw-r--r--klm/lm/trie.hh10
-rw-r--r--klm/lm/trie_sort.cc44
-rw-r--r--klm/lm/trie_sort.hh6
-rw-r--r--klm/lm/value.hh2
-rw-r--r--klm/lm/value_build.hh6
-rw-r--r--klm/lm/vocab.cc8
-rw-r--r--klm/lm/vocab.hh10
-rw-r--r--klm/lm/word_index.hh3
-rw-r--r--klm/search/Jamfile5
-rw-r--r--klm/search/Makefile.am11
-rw-r--r--klm/search/config.hh25
-rw-r--r--klm/search/context.hh65
-rw-r--r--klm/search/edge.hh54
-rw-r--r--klm/search/edge_generator.cc110
-rw-r--r--klm/search/edge_generator.hh57
-rw-r--r--klm/search/final.hh36
-rw-r--r--klm/search/header.hh57
-rw-r--r--klm/search/note.hh12
-rw-r--r--klm/search/rule.cc43
-rw-r--r--klm/search/rule.hh20
-rw-r--r--klm/search/types.hh14
-rw-r--r--klm/search/vertex.cc42
-rw-r--r--klm/search/vertex.hh159
-rw-r--r--klm/search/vertex_generator.cc94
-rw-r--r--klm/search/vertex_generator.hh46
-rw-r--r--klm/search/weights.cc71
-rw-r--r--klm/search/weights.hh52
-rw-r--r--klm/search/weights_test.cc38
-rw-r--r--klm/util/Jamfile10
-rw-r--r--klm/util/Makefile.am2
-rw-r--r--klm/util/ersatz_progress.cc10
-rw-r--r--klm/util/ersatz_progress.hh10
-rw-r--r--klm/util/exception.cc3
-rw-r--r--klm/util/exception.hh22
-rw-r--r--klm/util/file.cc44
-rw-r--r--klm/util/file.hh8
-rw-r--r--klm/util/file_piece.cc4
-rw-r--r--klm/util/mmap.cc16
-rw-r--r--klm/util/pool.cc35
-rw-r--r--klm/util/pool.hh45
-rw-r--r--klm/util/probing_hash_table.hh5
-rw-r--r--klm/util/string_piece.cc192
-rw-r--r--klm/util/string_piece.hh5
-rw-r--r--klm/util/tokenize_piece.hh12
67 files changed, 1973 insertions, 244 deletions
diff --git a/klm/lm/Jamfile b/klm/lm/Jamfile
deleted file mode 100644
index b1971d88..00000000
--- a/klm/lm/Jamfile
+++ /dev/null
@@ -1,14 +0,0 @@
-lib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quantize.cc read_arpa.cc search_hashed.cc search_trie.cc trie.cc trie_sort.cc value_build.cc virtual_interface.cc vocab.cc ../util//kenutil : <include>.. : : <include>.. <library>../util//kenutil ;
-
-import testing ;
-
-run left_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa ;
-run model_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa test_nounk.arpa ;
-
-exe query : ngram_query.cc kenlm ../util//kenutil ;
-exe build_binary : build_binary.cc kenlm ../util//kenutil ;
-
-install legacy : build_binary query
- : <location>$(TOP)/klm/lm <install-type>EXE <install-dependencies>on <link>shared:<dll-path>$(TOP)/klm/lm <link>shared:<install-type>LIB ;
-
-alias programs : build_binary query ;
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc
index cdeafb47..088ea98d 100644
--- a/klm/lm/bhiksha.cc
+++ b/klm/lm/bhiksha.cc
@@ -1,6 +1,7 @@
#include "lm/bhiksha.hh"
#include "lm/config.hh"
#include "util/file.hh"
+#include "util/exception.hh"
#include <limits>
@@ -49,7 +50,7 @@ std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &con
}
} // namespace
-std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
+uint64_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */;
}
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
index 5182ee2e..8ff88654 100644
--- a/klm/lm/bhiksha.hh
+++ b/klm/lm/bhiksha.hh
@@ -23,7 +23,7 @@
namespace lm {
namespace ngram {
-class Config;
+struct Config;
namespace trie {
@@ -33,7 +33,7 @@ class DontBhiksha {
static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
- static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
+ static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) {
return util::RequiredBits(max_next);
@@ -67,7 +67,7 @@ class ArrayBhiksha {
static void UpdateConfigFromBinary(int fd, Config &config);
- static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
+ static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config);
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index a56e998e..efa67056 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -83,7 +83,13 @@ void WriteHeader(void *to, const Parameters &params) {
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
- backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED);
+ backing.file.reset(util::CreateOrThrow(config.write_mmap));
+ if (config.write_method == Config::WRITE_MMAP) {
+ backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
+ } else {
+ util::ResizeOrThrow(backing.file.get(), 0);
+ util::MapAnonymous(total, backing.vocab);
+ }
strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
@@ -121,12 +127,14 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
if (!config.write_mmap) return;
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
switch (config.write_method) {
case Config::WRITE_MMAP:
+ util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
util::SyncOrThrow(backing.search.get(), backing.search.size());
break;
case Config::WRITE_AFTER:
+ util::SeekOrThrow(backing.file.get(), 0);
+ util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
util::FSyncOrThrow(backing.file.get());
@@ -141,6 +149,10 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_
params.fixed.has_vocabulary = config.include_vocab;
params.fixed.search_version = search_version;
WriteHeader(backing.vocab.get(), params);
+ if (config.write_method == Config::WRITE_AFTER) {
+ util::SeekOrThrow(backing.file.get(), 0);
+ util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
+ }
}
namespace detail {
@@ -200,10 +212,10 @@ void SeekPastHeader(int fd, const Parameters &params) {
util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
}
-uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
+uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
const uint64_t file_size = util::SizeFile(backing.file.get());
// The header is smaller than a page, so we have to map the whole header as well.
- std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size;
+ std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index dd795f62..bf699d5f 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -70,7 +70,7 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
void SeekPastHeader(int fd, const Parameters &params);
-uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing);
+uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing);
void ComplainAboutARPA(const Config &config, ModelType model_type);
@@ -90,7 +90,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
new_config.probing_multiplier = params.fixed.probing_multiplier;
detail::SeekPastHeader(backing.file.get(), params);
To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config);
- std::size_t memory_size = To::Size(params.counts, new_config);
+ uint64_t memory_size = To::Size(params.counts, new_config);
uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
to.InitializeFromBinary(start, params, new_config, backing.file.get());
} else {
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index c4a01cb4..2b8c9d5b 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -11,6 +11,8 @@
#ifdef WIN32
#include "util/getopt.hh"
+#else
+#include <unistd.h>
#endif
namespace lm {
@@ -25,7 +27,11 @@ void Usage(const char *name) {
"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
"-w mmap|after determines how writing is done.\n"
" mmap maps the binary file and writes to it. Default for trie.\n"
-" after allocates anonymous memory, builds, and writes. Default for probing.\n\n"
+" after allocates anonymous memory, builds, and writes. Default for probing.\n"
+"-r \"order1.arpa order2 order3 order4\" adds lower-order rest costs from these\n"
+" model files. order1.arpa must be an ARPA file. All others may be ARPA or\n"
+" the same data structure as being built. All files must have the same\n"
+" vocabulary. For probing, the unigrams must be in the same order.\n\n"
"type is either probing or trie. Default is probing.\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
@@ -81,16 +87,16 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::vector<uint64_t> counts;
util::FilePiece f(file);
lm::ReadARPACounts(f, counts);
- std::size_t sizes[6];
+ uint64_t sizes[6];
sizes[0] = ProbingModel::Size(counts, config);
sizes[1] = RestProbingModel::Size(counts, config);
sizes[2] = TrieModel::Size(counts, config);
sizes[3] = QuantTrieModel::Size(counts, config);
sizes[4] = ArrayTrieModel::Size(counts, config);
sizes[5] = QuantArrayTrieModel::Size(counts, config);
- std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
- std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
- std::size_t divide;
+ uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
+ uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
+ uint64_t divide;
char prefix;
if (min_length < (1 << 10) * 10) {
prefix = ' ';
@@ -111,7 +117,7 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
for (long int i = 0; i < length - 2; ++i) std::cout << ' ';
std::cout << prefix << "B\n"
"probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
- "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r -p " << config.probing_multiplier << "\n"
+ "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n"
"trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
"trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
"trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc
new file mode 100644
index 00000000..0267cd4e
--- /dev/null
+++ b/klm/lm/fragment.cc
@@ -0,0 +1,37 @@
+#include "lm/binary_format.hh"
+#include "lm/model.hh"
+#include "lm/left.hh"
+#include "util/tokenize_piece.hh"
+
+template <class Model> void Query(const char *name) {
+ Model model(name);
+ std::string line;
+ lm::ngram::ChartState ignored;
+ while (getline(std::cin, line)) {
+ lm::ngram::RuleScore<Model> scorer(model, ignored);
+ for (util::TokenIter<util::SingleCharacter, true> i(line, ' '); i; ++i) {
+ scorer.Terminal(model.GetVocabulary().Index(*i));
+ }
+ std::cout << scorer.Finish() << '\n';
+ }
+}
+
+int main(int argc, char *argv[]) {
+ if (argc != 2) {
+ std::cerr << "Expected model file name." << std::endl;
+ return 1;
+ }
+ const char *name = argv[1];
+ lm::ngram::ModelType model_type = lm::ngram::PROBING;
+ lm::ngram::RecognizeBinary(name, model_type);
+ switch (model_type) {
+ case lm::ngram::PROBING:
+ Query<lm::ngram::ProbingModel>(name);
+ break;
+ case lm::ngram::REST_PROBING:
+ Query<lm::ngram::RestProbingModel>(name);
+ break;
+ default:
+ std::cerr << "Model type not supported yet." << std::endl;
+ }
+}
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index c00af88a..8c27232e 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -111,7 +111,7 @@ template <class M> class RuleScore {
return;
}
- float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1];
+ float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
float *back = backoffs, *back2 = backoffs2;
unsigned char next_use = out_.right.length;
diff --git a/klm/lm/max_order.cc b/klm/lm/max_order.cc
new file mode 100644
index 00000000..94221201
--- /dev/null
+++ b/klm/lm/max_order.cc
@@ -0,0 +1,6 @@
+#include "lm/max_order.hh"
+#include <iostream>
+
+int main(int argc, char *argv[]) {
+ std::cerr << "KenLM was compiled with a maximum supported n-gram order set to " << KENLM_MAX_ORDER << "." << std::endl;
+}
diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh
index aff9de27..989f8324 100644
--- a/klm/lm/max_order.hh
+++ b/klm/lm/max_order.hh
@@ -1,14 +1,12 @@
-#ifndef LM_MAX_ORDER__
-#define LM_MAX_ORDER__
-namespace lm {
-namespace ngram {
-// If you need higher order, change this and recompile.
-// Having this limit means that State can be
-// (kMaxOrder - 1) * sizeof(float) bytes instead of
-// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
-const unsigned char kMaxOrder = 5;
-
-} // namespace ngram
-} // namespace lm
-
-#endif // LM_MAX_ORDER__
+/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM.
+ * If not, this is the default maximum order.
+ * Having this limit means that State can be
+ * (kMaxOrder - 1) * sizeof(float) bytes instead of
+ * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
+ */
+#ifndef KENLM_MAX_ORDER
+#define KENLM_MAX_ORDER 6
+#endif
+#ifndef KENLM_ORDER_MESSAGE
+#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh."
+#endif
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index a2d31ce0..2fd20481 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -5,12 +5,14 @@
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/read_arpa.hh"
+#include "util/have.hh"
#include "util/murmur_hash.hh"
#include <algorithm>
#include <functional>
#include <numeric>
#include <cmath>
+#include <limits>
namespace lm {
namespace ngram {
@@ -18,17 +20,18 @@ namespace detail {
template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
-template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
+template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
+ size_t goal_size = util::CheckOverflow(Size(counts, config));
uint8_t *start = static_cast<uint8_t*>(base);
size_t allocated = VocabularyT::Size(counts[0], config);
vocab_.SetupMemory(start, allocated, counts[0], config);
start += allocated;
start = search_.SetupMemory(start, counts, config);
- if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config));
+ if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
@@ -47,7 +50,19 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
+namespace {
+void CheckCounts(const std::vector<uint64_t> &counts) {
+ UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
+ if (sizeof(uint64_t) > sizeof(std::size_t)) {
+ for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) {
+ UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines.");
+ }
+ }
+}
+} // namespace
+
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
+ CheckCounts(params.counts);
SetupMemory(start, params.counts, config);
vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
search_.LoadedBinary();
@@ -60,12 +75,11 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
ReadARPACounts(f, counts);
-
- if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
+ CheckCounts(counts);
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
- std::size_t vocab_size = VocabularyT::Size(counts[0], config);
+ std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
// Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
@@ -73,7 +87,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get());
+ wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size());
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index be872178..13ff864e 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -5,7 +5,6 @@
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
-#include "lm/max_order.hh"
#include "lm/quantize.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
@@ -42,7 +41,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
* does not include small non-mapped control structures, such as this class
* itself.
*/
- static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
+ static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
/* Load the model from a file. It may be an ARPA or binary file. Binary
* files must have the format expected by this class or you'll get an
diff --git a/klm/lm/partial.hh b/klm/lm/partial.hh
new file mode 100644
index 00000000..1dede359
--- /dev/null
+++ b/klm/lm/partial.hh
@@ -0,0 +1,167 @@
+#ifndef LM_PARTIAL__
+#define LM_PARTIAL__
+
+#include "lm/return.hh"
+#include "lm/state.hh"
+
+#include <algorithm>
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+
+struct ExtendReturn {
+ float adjust;
+ bool make_full;
+ unsigned char next_use;
+};
+
+template <class Model> ExtendReturn ExtendLoop(
+ const Model &model,
+ unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start,
+ const uint64_t *pointers, const uint64_t *pointers_end,
+ uint64_t *&pointers_write,
+ float *backoff_write) {
+ unsigned char add_length = add_rend - add_rbegin;
+
+ float backoff_buf[2][KENLM_MAX_ORDER - 1];
+ float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1];
+ std::copy(backoff_start, backoff_start + add_length, backoff_in);
+
+ ExtendReturn value;
+ value.make_full = false;
+ value.adjust = 0.0;
+ value.next_use = add_length;
+
+ unsigned char i = 0;
+ unsigned char length = pointers_end - pointers;
+ // pointers_write is NULL means that the existing left state is full, so we should use completed probabilities.
+ if (pointers_write) {
+ // Using full context, writing to new left state.
+ for (; i < length; ++i) {
+ FullScoreReturn ret(model.ExtendLeft(
+ add_rbegin, add_rbegin + value.next_use,
+ backoff_in,
+ pointers[i], i + seen + 1,
+ backoff_out,
+ value.next_use));
+ std::swap(backoff_in, backoff_out);
+ if (ret.independent_left) {
+ value.adjust += ret.prob;
+ value.make_full = true;
+ ++i;
+ break;
+ }
+ value.adjust += ret.rest;
+ *pointers_write++ = ret.extend_left;
+ if (value.next_use != add_length) {
+ value.make_full = true;
+ ++i;
+ break;
+ }
+ }
+ }
+ // Using some of the new context.
+ for (; i < length && value.next_use; ++i) {
+ FullScoreReturn ret(model.ExtendLeft(
+ add_rbegin, add_rbegin + value.next_use,
+ backoff_in,
+ pointers[i], i + seen + 1,
+ backoff_out,
+ value.next_use));
+ std::swap(backoff_in, backoff_out);
+ value.adjust += ret.prob;
+ }
+ float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1);
+ // Using none of the new context.
+ value.adjust += unrest;
+
+ std::copy(backoff_in, backoff_in + value.next_use, backoff_write);
+ return value;
+}
+
+template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) {
+ assert(seen < reveal.length || reveal_full);
+ uint64_t *pointers_write = reveal_full ? NULL : left.pointers;
+ float backoff_buffer[KENLM_MAX_ORDER - 1];
+ ExtendReturn value(ExtendLoop(
+ model,
+ seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen,
+ left.pointers, left.pointers + left.length,
+ pointers_write,
+ left.full ? backoff_buffer : (right.backoff + right.length)));
+ if (reveal_full) {
+ left.length = 0;
+ value.make_full = true;
+ } else {
+ left.length = pointers_write - left.pointers;
+ value.make_full |= (left.length == model.Order() - 1);
+ }
+ if (left.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
+ } else {
+ // If left wasn't full when it came in, put words into right state.
+ std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length);
+ right.length += value.next_use;
+ left.full = value.make_full || (right.length == model.Order() - 1);
+ }
+ return value.adjust;
+}
+
+template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) {
+ assert(seen < reveal.length || reveal.full);
+ uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length);
+ ExtendReturn value(ExtendLoop(
+ model,
+ seen, right.words, right.words + right.length, right.backoff,
+ reveal.pointers + seen, reveal.pointers + reveal.length,
+ pointers_write,
+ right.backoff));
+ if (reveal.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i];
+ right.length = 0;
+ value.make_full = true;
+ } else {
+ right.length = value.next_use;
+ value.make_full |= (right.length == model.Order() - 1);
+ }
+ if (!left.full) {
+ left.length = pointers_write - left.pointers;
+ left.full = value.make_full || (left.length == model.Order() - 1);
+ }
+ return value.adjust;
+}
+
+template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) {
+ assert(first_right.length < KENLM_MAX_ORDER);
+ assert(second_left.length < KENLM_MAX_ORDER);
+ assert(between_length < KENLM_MAX_ORDER - 1);
+ uint64_t *pointers_write = first_left.full ? NULL : (first_left.pointers + first_left.length);
+ float backoff_buffer[KENLM_MAX_ORDER - 1];
+ ExtendReturn value(ExtendLoop(
+ model,
+ between_length, first_right.words, first_right.words + first_right.length, first_right.backoff,
+ second_left.pointers, second_left.pointers + second_left.length,
+ pointers_write,
+ second_left.full ? backoff_buffer : (second_right.backoff + second_right.length)));
+ if (second_left.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
+ } else {
+ std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length);
+ second_right.length += value.next_use;
+ value.make_full |= (second_right.length == model.Order() - 1);
+ }
+ if (!first_left.full) {
+ first_left.length = pointers_write - first_left.pointers;
+ first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1);
+ }
+ assert(first_left.length < KENLM_MAX_ORDER);
+ assert(second_right.length < KENLM_MAX_ORDER);
+ return value.adjust;
+}
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_PARTIAL__
diff --git a/klm/lm/partial_test.cc b/klm/lm/partial_test.cc
new file mode 100644
index 00000000..8d309c85
--- /dev/null
+++ b/klm/lm/partial_test.cc
@@ -0,0 +1,199 @@
+#include "lm/partial.hh"
+
+#include "lm/left.hh"
+#include "lm/model.hh"
+#include "util/tokenize_piece.hh"
+
+#define BOOST_TEST_MODULE PartialTest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+namespace lm {
+namespace ngram {
+namespace {
+
+const char *TestLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ return "test.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[1];
+}
+
+Config SilentConfig() {
+ Config config;
+ config.arpa_complain = Config::NONE;
+ config.messages = NULL;
+ return config;
+}
+
+struct ModelFixture {
+ ModelFixture() : m(TestLocation(), SilentConfig()) {}
+
+ RestProbingModel m;
+};
+
+BOOST_FIXTURE_TEST_SUITE(suite, ModelFixture)
+
+BOOST_AUTO_TEST_CASE(SimpleBefore) {
+ Left left;
+ left.full = false;
+ left.length = 0;
+ Right right;
+ right.length = 0;
+
+ Right reveal;
+ reveal.length = 1;
+ WordIndex period = m.GetVocabulary().Index(".");
+ reveal.words[0] = period;
+ reveal.backoff[0] = -0.845098;
+
+ BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 0, false, left, right), 0.001);
+ BOOST_CHECK_EQUAL(0, left.length);
+ BOOST_CHECK(!left.full);
+ BOOST_CHECK_EQUAL(1, right.length);
+ BOOST_CHECK_EQUAL(period, right.words[0]);
+ BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
+
+ WordIndex more = m.GetVocabulary().Index("more");
+ reveal.words[1] = more;
+ reveal.backoff[1] = -0.4771212;
+ reveal.length = 2;
+ BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 1, false, left, right), 0.001);
+ BOOST_CHECK_EQUAL(0, left.length);
+ BOOST_CHECK(!left.full);
+ BOOST_CHECK_EQUAL(2, right.length);
+ BOOST_CHECK_EQUAL(period, right.words[0]);
+ BOOST_CHECK_EQUAL(more, right.words[1]);
+ BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
+ BOOST_CHECK_CLOSE(-0.4771212, right.backoff[1], 0.001);
+}
+
+BOOST_AUTO_TEST_CASE(AlsoWouldConsider) {
+ WordIndex would = m.GetVocabulary().Index("would");
+ WordIndex consider = m.GetVocabulary().Index("consider");
+
+ ChartState current;
+ current.left.length = 1;
+ current.left.pointers[0] = would;
+ current.left.full = false;
+ current.right.length = 1;
+ current.right.words[0] = would;
+ current.right.backoff[0] = -0.30103;
+
+ Left after;
+ after.full = false;
+ after.length = 1;
+ after.pointers[0] = consider;
+
+ // adjustment for would consider
+ BOOST_CHECK_CLOSE(-1.687872 - -0.2922095 - 0.30103, RevealAfter(m, current.left, current.right, after, 0), 0.001);
+
+ BOOST_CHECK_EQUAL(2, current.left.length);
+ BOOST_CHECK_EQUAL(would, current.left.pointers[0]);
+ BOOST_CHECK_EQUAL(false, current.left.full);
+
+ WordIndex also = m.GetVocabulary().Index("also");
+ Right before;
+ before.length = 1;
+ before.words[0] = also;
+ before.backoff[0] = -0.30103;
+ // r(would) = -0.2922095 [i would], r(would -> consider) = -1.988902 [b(would) + p(consider)]
+ // p(also -> would) = -2, p(also would -> consider) = -3
+ BOOST_CHECK_CLOSE(-2 + 0.2922095 -3 + 1.988902, RevealBefore(m, before, 0, false, current.left, current.right), 0.001);
+ BOOST_CHECK_EQUAL(0, current.left.length);
+ BOOST_CHECK(current.left.full);
+ BOOST_CHECK_EQUAL(2, current.right.length);
+ BOOST_CHECK_EQUAL(would, current.right.words[0]);
+ BOOST_CHECK_EQUAL(also, current.right.words[1]);
+}
+
+BOOST_AUTO_TEST_CASE(EndSentence) {
+ WordIndex loin = m.GetVocabulary().Index("loin");
+ WordIndex period = m.GetVocabulary().Index(".");
+ WordIndex eos = m.GetVocabulary().EndSentence();
+
+ ChartState between;
+ between.left.length = 1;
+ between.left.pointers[0] = eos;
+ between.left.full = true;
+ between.right.length = 0;
+
+ Right before;
+ before.words[0] = period;
+ before.words[1] = loin;
+ before.backoff[0] = -0.845098;
+ before.backoff[1] = 0.0;
+
+ before.length = 1;
+ BOOST_CHECK_CLOSE(-0.0410707, RevealBefore(m, before, 0, true, between.left, between.right), 0.001);
+ BOOST_CHECK_EQUAL(0, between.left.length);
+}
+
+float ScoreFragment(const RestProbingModel &model, unsigned int *begin, unsigned int *end, ChartState &out) {
+ RuleScore<RestProbingModel> scorer(model, out);
+ for (unsigned int *i = begin; i < end; ++i) {
+ scorer.Terminal(*i);
+ }
+ return scorer.Finish();
+}
+
+void CheckAdjustment(const RestProbingModel &model, float expect, const Right &before_in, bool before_full, ChartState between, const Left &after_in) {
+ Right before(before_in);
+ Left after(after_in);
+ after.full = false;
+ float got = 0.0;
+ for (unsigned int i = 1; i < 5; ++i) {
+ if (before_in.length >= i) {
+ before.length = i;
+ got += RevealBefore(model, before, i - 1, false, between.left, between.right);
+ }
+ if (after_in.length >= i) {
+ after.length = i;
+ got += RevealAfter(model, between.left, between.right, after, i - 1);
+ }
+ }
+ if (after_in.full) {
+ after.full = true;
+ got += RevealAfter(model, between.left, between.right, after, after.length);
+ }
+ if (before_full) {
+ got += RevealBefore(model, before, before.length, true, between.left, between.right);
+ }
+ // Sometimes they're zero and BOOST_CHECK_CLOSE fails for this.
+ BOOST_CHECK(fabs(expect - got) < 0.001);
+}
+
+void FullDivide(const RestProbingModel &model, StringPiece str) {
+ std::vector<WordIndex> indices;
+ for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) {
+ indices.push_back(model.GetVocabulary().Index(*i));
+ }
+ ChartState full_state;
+ float full = ScoreFragment(model, &indices.front(), &indices.back() + 1, full_state);
+
+ ChartState before_state;
+ before_state.left.full = false;
+ RuleScore<RestProbingModel> before_scorer(model, before_state);
+ float before_score = 0.0;
+ for (unsigned int before = 0; before < indices.size(); ++before) {
+ for (unsigned int after = before; after <= indices.size(); ++after) {
+ ChartState after_state, between_state;
+ float after_score = ScoreFragment(model, &indices.front() + after, &indices.front() + indices.size(), after_state);
+ float between_score = ScoreFragment(model, &indices.front() + before, &indices.front() + after, between_state);
+ CheckAdjustment(model, full - before_score - after_score - between_score, before_state.right, before_state.left.full, between_state, after_state.left);
+ }
+ before_scorer.Terminal(indices[before]);
+ before_score = before_scorer.Finish();
+ }
+}
+
+BOOST_AUTO_TEST_CASE(Strings) {
+ FullDivide(m, "also would consider");
+ FullDivide(m, "looking on a little more loin . </s>");
+ FullDivide(m, "in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+} // namespace
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh
index 3e9153e3..8ce2378a 100644
--- a/klm/lm/quantize.hh
+++ b/klm/lm/quantize.hh
@@ -17,14 +17,14 @@
namespace lm {
namespace ngram {
-class Config;
+struct Config;
/* Store values directly and don't quantize. */
class DontQuantize {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
- static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
+ static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
@@ -138,9 +138,9 @@ class SeparatelyQuantize {
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
- static std::size_t Size(uint8_t order, const Config &config) {
- size_t longest_table = (static_cast<size_t>(1) << static_cast<size_t>(config.prob_bits)) * sizeof(float);
- size_t middle_table = (static_cast<size_t>(1) << static_cast<size_t>(config.backoff_bits)) * sizeof(float) + longest_table;
+ static uint64_t Size(uint8_t order, const Config &config) {
+ uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
+ uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table;
// unigrams are currently not quantized so no need for a table.
return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8;
}
@@ -217,7 +217,7 @@ class SeparatelyQuantize {
const Bins &LongestTable() const { return longest_; }
private:
- Bins tables_[kMaxOrder - 1][2];
+ Bins tables_[KENLM_MAX_ORDER - 1][2];
Bins longest_;
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index 2d9a337d..b709fef9 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -2,14 +2,20 @@
#include "lm/blank.hh"
+#include <cmath>
#include <cstdlib>
#include <iostream>
+#include <sstream>
#include <vector>
#include <ctype.h>
#include <string.h>
#include <stdint.h>
+#ifdef WIN32
+#include <float.h>
+#endif
+
namespace lm {
// 1 for '\t', '\n', and ' '. This is stricter than isspace.
@@ -26,6 +32,15 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) {
const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";
+// strtoull isn't portable enough :-(
+uint64_t ReadCount(const std::string &from) {
+ std::stringstream stream(from);
+ uint64_t ret;
+ stream >> ret;
+ UTIL_THROW_IF(!stream, FormatLoadException, "Bad count " << from);
+ return ret;
+}
+
} // namespace
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
@@ -47,15 +62,11 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
// So strtol doesn't go off the end of line.
std::string remaining(line.data() + 6, line.size() - 6);
char *end_ptr;
- unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10);
+ unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10);
if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line);
if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line);
++end_ptr;
- const char *start = end_ptr;
- long int count = std::strtol(start, &end_ptr, 10);
- if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count);
- if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line);
- number.push_back(count);
+ number.push_back(ReadCount(end_ptr));
}
}
@@ -93,7 +104,16 @@ void ReadBackoff(util::FilePiece &in, float &backoff) {
case '\t':
backoff = in.ReadFloat();
if (backoff == ngram::kExtensionBackoff) backoff = ngram::kNoExtensionBackoff;
- if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
+ {
+#ifdef WIN32
+ int float_class = _fpclass(backoff);
+ UTIL_THROW_IF(float_class == _FPCLASS_SNAN || float_class == _FPCLASS_QNAN || float_class == _FPCLASS_NINF || float_class == _FPCLASS_PINF, FormatLoadException, "Bad backoff " << backoff);
+#else
+ int float_class = std::fpclassify(backoff);
+ UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff);
+#endif
+ }
+ UTIL_THROW_IF(in.get() != '\n', FormatLoadException, "Expected newline after backoff");
break;
case '\n':
backoff = ngram::kNoExtensionBackoff;
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 13942309..a1623834 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -234,7 +234,7 @@ template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, c
ApplyBuild(f, counts, config, vocab, warn, build);
}
-template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
+template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
switch (config.rest_function) {
case Config::REST_MAX:
{
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 7e8c1220..a52f107b 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -74,8 +74,8 @@ template <class Value> class HashedSearch {
// TODO: move probing_multiplier here with next binary file format update.
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
- static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t ret = Unigram::Size(counts[0]);
+ static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
+ uint64_t ret = Unigram::Size(counts[0]);
for (unsigned char n = 1; n < counts.size() - 1; ++n) {
ret += Middle::Size(counts[n], config.probing_multiplier);
}
@@ -160,8 +160,8 @@ template <class Value> class HashedSearch {
#endif
{}
- static std::size_t Size(uint64_t count) {
- return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
+ static uint64_t Size(uint64_t count) {
+ return (count + 1) * sizeof(typename Value::Weights); // +1 for hallucinate <unk>
}
const typename Value::Weights &Lookup(WordIndex index) const {
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 18e80d5a..debcfd07 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -89,7 +89,7 @@ class BackoffMessages {
if (!HasExtension(weights.backoff)) {
weights.backoff = kExtensionBackoff;
UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
- WriteOrThrow(unigrams, &weights, sizeof(weights));
+ util::WriteOrThrow(unigrams, &weights, sizeof(weights));
}
const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
base[write_to.array][write_to.index] += weights.backoff;
@@ -180,7 +180,7 @@ const float kBadProb = std::numeric_limits<float>::infinity();
class SRISucks {
public:
SRISucks() {
- for (BackoffMessages *i = messages_; i != messages_ + kMaxOrder - 1; ++i)
+ for (BackoffMessages *i = messages_; i != messages_ + KENLM_MAX_ORDER - 1; ++i)
i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1));
}
@@ -196,7 +196,7 @@ class SRISucks {
}
void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) {
- for (unsigned char i = 0; i < kMaxOrder - 1; ++i) {
+ for (unsigned char i = 0; i < KENLM_MAX_ORDER - 1; ++i) {
it_[i] = values_[i].empty() ? NULL : &*values_[i].begin();
}
messages_[0].Apply(it_, unigram_file);
@@ -221,10 +221,10 @@ class SRISucks {
private:
// This used to be one array. Then I needed to separate it by order for quantization to work.
- std::vector<float> values_[kMaxOrder - 1];
- BackoffMessages messages_[kMaxOrder - 1];
+ std::vector<float> values_[KENLM_MAX_ORDER - 1];
+ BackoffMessages messages_[KENLM_MAX_ORDER - 1];
- float *it_[kMaxOrder - 1];
+ float *it_[KENLM_MAX_ORDER - 1];
};
class FindBlanks {
@@ -337,7 +337,7 @@ struct Gram {
template <class Doing> class BlankManager {
public:
BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) {
- for (float *i = basis_; i != basis_ + kMaxOrder - 1; ++i) *i = kBadProb;
+ for (float *i = basis_; i != basis_ + KENLM_MAX_ORDER - 1; ++i) *i = kBadProb;
}
void Visit(const WordIndex *to, unsigned char length, float prob) {
@@ -373,10 +373,10 @@ template <class Doing> class BlankManager {
private:
const unsigned char total_order_;
- WordIndex been_[kMaxOrder];
+ WordIndex been_[KENLM_MAX_ORDER];
unsigned char been_length_;
- float basis_[kMaxOrder];
+ float basis_[KENLM_MAX_ORDER];
Doing &doing_;
};
@@ -470,8 +470,8 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c
} // namespace
template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
- RecordReader inputs[kMaxOrder - 1];
- RecordReader contexts[kMaxOrder - 1];
+ RecordReader inputs[KENLM_MAX_ORDER - 1];
+ RecordReader contexts[KENLM_MAX_ORDER - 1];
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff)));
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 10b22ab1..1264baf5 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -44,8 +44,8 @@ template <class Quant, class Bhiksha> class TrieSearch {
Bhiksha::UpdateConfigFromBinary(fd, config);
}
- static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
+ static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
+ uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
}
diff --git a/klm/lm/sri_test.cc b/klm/lm/sri_test.cc
deleted file mode 100644
index e697d722..00000000
--- a/klm/lm/sri_test.cc
+++ /dev/null
@@ -1,65 +0,0 @@
-#include "lm/sri.hh"
-
-#include <stdlib.h>
-
-#define BOOST_TEST_MODULE SRITest
-#include <boost/test/unit_test.hpp>
-
-namespace lm {
-namespace sri {
-namespace {
-
-#define StartTest(word, ngram, score) \
- ret = model.FullScore( \
- state, \
- model.GetVocabulary().Index(word), \
- out);\
- BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
- BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
- BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
-
-#define AppendTest(word, ngram, score) \
- StartTest(word, ngram, score) \
- state = out;
-
-template <class M> void Starters(M &model) {
- FullScoreReturn ret;
- Model::State state(model.BeginSentenceState());
- Model::State out;
-
- StartTest("looking", 2, -0.4846522);
-
- // , probability plus <s> backoff
- StartTest(",", 1, -1.383514 + -0.4149733);
- // <unk> probability plus <s> backoff
- StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
-}
-
-template <class M> void Continuation(M &model) {
- FullScoreReturn ret;
- Model::State state(model.BeginSentenceState());
- Model::State out;
-
- AppendTest("looking", 2, -0.484652);
- AppendTest("on", 3, -0.348837);
- AppendTest("a", 4, -0.0155266);
- AppendTest("little", 5, -0.00306122);
- State preserve = state;
- AppendTest("the", 1, -4.04005);
- AppendTest("biarritz", 1, -1.9889);
- AppendTest("not_found", 0, -2.29666);
- AppendTest("more", 1, -1.20632);
- AppendTest(".", 2, -0.51363);
- AppendTest("</s>", 3, -0.0191651);
-
- state = preserve;
- AppendTest("more", 5, -0.00181395);
- AppendTest("loin", 5, -0.0432557);
-}
-
-BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); }
-BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); }
-
-} // namespace
-} // namespace sri
-} // namespace lm
diff --git a/klm/lm/state.hh b/klm/lm/state.hh
index c7438414..551510a8 100644
--- a/klm/lm/state.hh
+++ b/klm/lm/state.hh
@@ -32,7 +32,7 @@ class State {
// Call this before using raw memcmp.
void ZeroRemaining() {
- for (unsigned char i = length; i < kMaxOrder - 1; ++i) {
+ for (unsigned char i = length; i < KENLM_MAX_ORDER - 1; ++i) {
words[i] = 0;
backoff[i] = 0.0;
}
@@ -42,11 +42,13 @@ class State {
// You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
// This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
- WordIndex words[kMaxOrder - 1];
- float backoff[kMaxOrder - 1];
+ WordIndex words[KENLM_MAX_ORDER - 1];
+ float backoff[KENLM_MAX_ORDER - 1];
unsigned char length;
};
+typedef State Right;
+
inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed);
}
@@ -72,11 +74,11 @@ struct Left {
}
void ZeroRemaining() {
- for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)
+ for (uint64_t * i = pointers + length; i < pointers + KENLM_MAX_ORDER - 1; ++i)
*i = 0;
}
- uint64_t pointers[kMaxOrder - 1];
+ uint64_t pointers[KENLM_MAX_ORDER - 1];
unsigned char length;
bool full;
};
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 0f1ca574..d9895f89 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -36,7 +36,7 @@ bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_
}
} // namespace
-std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
+uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits;
// Extra entry for next pointer at the end.
// +7 then / 8 to round up bits and convert to bytes
@@ -57,7 +57,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
max_vocab_ = max_vocab;
}
-template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
+template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index eff93292..9ea3c546 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -11,7 +11,7 @@
namespace lm {
namespace ngram {
-class Config;
+struct Config;
namespace trie {
struct NodeRange {
@@ -49,7 +49,7 @@ class Unigram {
unigram_ = static_cast<UnigramValue*>(start);
}
- static std::size_t Size(uint64_t count) {
+ static uint64_t Size(uint64_t count) {
// +1 in case unknown doesn't appear. +1 for the final next.
return (count + 2) * sizeof(UnigramValue);
}
@@ -84,7 +84,7 @@ class BitPacked {
}
protected:
- static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
+ static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
@@ -99,7 +99,7 @@ class BitPacked {
template <class Bhiksha> class BitPackedMiddle : public BitPacked {
public:
- static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
+ static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
// next_source need not be initialized.
BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
@@ -128,7 +128,7 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked {
class BitPackedLongest : public BitPacked {
public:
- static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
+ static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
return BaseSize(entries, max_vocab, quant_bits);
}
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index b80fed02..8663e94e 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -22,12 +22,6 @@
namespace lm {
namespace ngram {
namespace trie {
-
-void WriteOrThrow(FILE *to, const void *data, size_t size) {
- assert(size);
- if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
-}
-
namespace {
typedef util::SizedIterator NGramIter;
@@ -95,12 +89,12 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
if (context_begin == context_end) return out.release();
PartialIter i(context_begin);
- WriteOrThrow(out.get(), i->Data(), context_size);
+ util::WriteOrThrow(out.get(), i->Data(), context_size);
const void *previous = i->Data();
++i;
for (; i != context_end; ++i) {
if (memcmp(previous, i->Data(), context_size)) {
- WriteOrThrow(out.get(), i->Data(), context_size);
+ util::WriteOrThrow(out.get(), i->Data(), context_size);
previous = i->Data();
}
}
@@ -116,7 +110,7 @@ struct ThrowCombine {
// Useful for context files that just contain records with no value.
struct FirstCombine {
void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const {
- WriteOrThrow(out, first, entry_size);
+ util::WriteOrThrow(out, first, entry_size);
}
};
@@ -129,10 +123,10 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
- WriteOrThrow(out_file.get(), first.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), first.Data(), entry_size);
++first;
} else if (less(second.Data(), first.Data())) {
- WriteOrThrow(out_file.get(), second.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), second.Data(), entry_size);
++second;
} else {
combine(entry_size, first.Data(), second.Data(), out_file.get());
@@ -140,7 +134,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
}
}
for (RecordReader &remains = (first ? first : second); remains; ++remains) {
- WriteOrThrow(out_file.get(), remains.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), remains.Data(), entry_size);
}
return out_file.release();
}
@@ -148,19 +142,23 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
} // namespace
void RecordReader::Init(FILE *file, std::size_t entry_size) {
- rewind(file);
- file_ = file;
+ entry_size_ = entry_size;
data_.reset(malloc(entry_size));
UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer");
- remains_ = true;
- entry_size_ = entry_size;
- ++*this;
+ file_ = file;
+ if (file) {
+ rewind(file);
+ remains_ = true;
+ ++*this;
+ } else {
+ remains_ = false;
+ }
}
void RecordReader::Overwrite(const void *start, std::size_t amount) {
long internal = (uint8_t*)start - (uint8_t*)data_.get();
UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
- WriteOrThrow(file_, start, amount);
+ util::WriteOrThrow(file_, start, amount);
long forward = entry_size_ - internal - amount;
#if !defined(_WIN32) && !defined(_WIN64)
if (forward)
@@ -169,9 +167,13 @@ void RecordReader::Overwrite(const void *start, std::size_t amount) {
}
void RecordReader::Rewind() {
- rewind(file_);
- remains_ = true;
- ++*this;
+ if (file_) {
+ rewind(file_);
+ remains_ = true;
+ ++*this;
+ } else {
+ remains_ = false;
+ }
}
SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh
index 3036319d..2197b80c 100644
--- a/klm/lm/trie_sort.hh
+++ b/klm/lm/trie_sort.hh
@@ -25,12 +25,10 @@ namespace lm {
class PositiveProbWarn;
namespace ngram {
class SortedVocabulary;
-class Config;
+struct Config;
namespace trie {
-void WriteOrThrow(FILE *to, const void *data, size_t size);
-
class EntryCompare : public std::binary_function<const void*, const void*, bool> {
public:
explicit EntryCompare(unsigned char order) : order_(order) {}
@@ -107,7 +105,7 @@ class SortedFiles {
util::scoped_fd unigram_;
- util::scoped_FILE full_[kMaxOrder - 1], context_[kMaxOrder - 1];
+ util::scoped_FILE full_[KENLM_MAX_ORDER - 1], context_[KENLM_MAX_ORDER - 1];
};
} // namespace trie
diff --git a/klm/lm/value.hh b/klm/lm/value.hh
index 85e53f14..ba716713 100644
--- a/klm/lm/value.hh
+++ b/klm/lm/value.hh
@@ -6,7 +6,7 @@
#include "lm/weights.hh"
#include "util/bit_packing.hh"
-#include <inttypes.h>
+#include <stdint.h>
namespace lm {
namespace ngram {
diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh
index 687a41a0..461e6a5c 100644
--- a/klm/lm/value_build.hh
+++ b/klm/lm/value_build.hh
@@ -10,9 +10,9 @@
namespace lm {
namespace ngram {
-class Config;
-class BackoffValue;
-class RestValue;
+struct Config;
+struct BackoffValue;
+struct RestValue;
class NoRestBuild {
public:
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 5de68f16..11c27518 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -80,14 +80,14 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
-void WriteWordsWrapper::Write(int fd) {
- util::SeekEnd(fd);
+void WriteWordsWrapper::Write(int fd, uint64_t start) {
+ util::SeekOrThrow(fd, start);
util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
}
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
-std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) {
+uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
// Lead with the number of entries.
return sizeof(uint64_t) + sizeof(uint64_t) * entries;
}
@@ -165,7 +165,7 @@ struct ProbingVocabularyHeader {
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
-std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
+uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) {
return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);
}
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index c3efcb4a..de54eb06 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -13,11 +13,11 @@
#include <vector>
namespace lm {
-class ProbBackoff;
+struct ProbBackoff;
class EnumerateVocab;
namespace ngram {
-class Config;
+struct Config;
namespace detail {
uint64_t HashForVocab(const char *str, std::size_t len);
@@ -35,7 +35,7 @@ class WriteWordsWrapper : public EnumerateVocab {
void Add(WordIndex index, const StringPiece &str);
- void Write(int fd);
+ void Write(int fd, uint64_t start);
private:
EnumerateVocab *inner_;
@@ -62,7 +62,7 @@ class SortedVocabulary : public base::Vocabulary {
}
// Size for purposes of file writing
- static size_t Size(std::size_t entries, const Config &config);
+ static uint64_t Size(uint64_t entries, const Config &config);
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
WordIndex Bound() const { return bound_; }
@@ -129,7 +129,7 @@ class ProbingVocabulary : public base::Vocabulary {
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
}
- static size_t Size(std::size_t entries, const Config &config);
+ static uint64_t Size(uint64_t entries, const Config &config);
// Vocab words are [0, Bound()).
WordIndex Bound() const { return bound_; }
diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh
index 67841c30..e09557a7 100644
--- a/klm/lm/word_index.hh
+++ b/klm/lm/word_index.hh
@@ -2,8 +2,11 @@
#ifndef LM_WORD_INDEX__
#define LM_WORD_INDEX__
+#include <limits.h>
+
namespace lm {
typedef unsigned int WordIndex;
+const WordIndex kMaxWordIndex = UINT_MAX;
} // namespace lm
typedef lm::WordIndex LMWordIndex;
diff --git a/klm/search/Jamfile b/klm/search/Jamfile
new file mode 100644
index 00000000..bc95c53a
--- /dev/null
+++ b/klm/search/Jamfile
@@ -0,0 +1,5 @@
+lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
+
+import testing ;
+
+unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ;
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am
new file mode 100644
index 00000000..ccc5b7f6
--- /dev/null
+++ b/klm/search/Makefile.am
@@ -0,0 +1,11 @@
+noinst_LIBRARIES = libksearch.a
+
+libksearch_a_SOURCES = \
+ edge_generator.cc \
+ rule.cc \
+ vertex.cc \
+ vertex_generator.cc \
+ weights.cc
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
+
diff --git a/klm/search/config.hh b/klm/search/config.hh
new file mode 100644
index 00000000..ef8e2354
--- /dev/null
+++ b/klm/search/config.hh
@@ -0,0 +1,25 @@
+#ifndef SEARCH_CONFIG__
+#define SEARCH_CONFIG__
+
+#include "search/weights.hh"
+#include "util/string_piece.hh"
+
+namespace search {
+
+class Config {
+ public:
+ Config(const Weights &weights, unsigned int pop_limit) :
+ weights_(weights), pop_limit_(pop_limit) {}
+
+ const Weights &GetWeights() const { return weights_; }
+
+ unsigned int PopLimit() const { return pop_limit_; }
+
+ private:
+ Weights weights_;
+ unsigned int pop_limit_;
+};
+
+} // namespace search
+
+#endif // SEARCH_CONFIG__
diff --git a/klm/search/context.hh b/klm/search/context.hh
new file mode 100644
index 00000000..62163144
--- /dev/null
+++ b/klm/search/context.hh
@@ -0,0 +1,65 @@
+#ifndef SEARCH_CONTEXT__
+#define SEARCH_CONTEXT__
+
+#include "lm/model.hh"
+#include "search/config.hh"
+#include "search/final.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+#include "util/exception.hh"
+#include "util/pool.hh"
+
+#include <boost/pool/object_pool.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <vector>
+
+namespace search {
+
+class Weights;
+
+class ContextBase {
+ public:
+ explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {}
+
+ util::Pool &FinalPool() {
+ return final_pool_;
+ }
+
+ VertexNode *NewVertexNode() {
+ VertexNode *ret = vertex_node_pool_.construct();
+ assert(ret);
+ return ret;
+ }
+
+ void DeleteVertexNode(VertexNode *node) {
+ vertex_node_pool_.destroy(node);
+ }
+
+ unsigned int PopLimit() const { return pop_limit_; }
+
+ const Weights &GetWeights() const { return weights_; }
+
+ private:
+ util::Pool final_pool_;
+
+ boost::object_pool<VertexNode> vertex_node_pool_;
+
+ unsigned int pop_limit_;
+
+ const Weights &weights_;
+};
+
+template <class Model> class Context : public ContextBase {
+ public:
+ Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {}
+
+ const Model &LanguageModel() const { return model_; }
+
+ private:
+ const Model &model_;
+};
+
+} // namespace search
+
+#endif // SEARCH_CONTEXT__
diff --git a/klm/search/edge.hh b/klm/search/edge.hh
new file mode 100644
index 00000000..187904bf
--- /dev/null
+++ b/klm/search/edge.hh
@@ -0,0 +1,54 @@
+#ifndef SEARCH_EDGE__
+#define SEARCH_EDGE__
+
+#include "lm/state.hh"
+#include "search/header.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+#include "util/pool.hh"
+
+#include <functional>
+
+#include <stdint.h>
+
+namespace search {
+
+// Copyable, but the copy will be shallow.
+class PartialEdge : public Header {
+ public:
+ // Allow default construction for STL.
+ PartialEdge() {}
+
+ PartialEdge(util::Pool &pool, Arity arity)
+ : Header(pool.Allocate(Size(arity, arity + 1)), arity) {}
+
+ PartialEdge(util::Pool &pool, Arity arity, Arity chart_states)
+ : Header(pool.Allocate(Size(arity, chart_states)), arity) {}
+
+ // Non-terminals
+ const PartialVertex *NT() const {
+ return reinterpret_cast<const PartialVertex*>(After());
+ }
+ PartialVertex *NT() {
+ return reinterpret_cast<PartialVertex*>(After());
+ }
+
+ const lm::ngram::ChartState &CompletedState() const {
+ return *Between();
+ }
+ const lm::ngram::ChartState *Between() const {
+ return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
+ }
+ lm::ngram::ChartState *Between() {
+ return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
+ }
+
+ private:
+ static std::size_t Size(Arity arity, Arity chart_states) {
+ return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState);
+ }
+};
+
+
+} // namespace search
+#endif // SEARCH_EDGE__
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
new file mode 100644
index 00000000..260159b1
--- /dev/null
+++ b/klm/search/edge_generator.cc
@@ -0,0 +1,110 @@
+#include "search/edge_generator.hh"
+
+#include "lm/left.hh"
+#include "lm/partial.hh"
+#include "search/context.hh"
+#include "search/vertex.hh"
+
+#include <numeric>
+
+namespace search {
+
+namespace {
+
+template <class Model> void FastScore(const Context<Model> &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) {
+ lm::ngram::ChartState *between = update.Between();
+ lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1];
+
+ float adjustment = 0.0;
+ const lm::ngram::ChartState &previous_reveal = previous_vertex.State();
+ const PartialVertex &update_nt = update.NT()[victim];
+ const lm::ngram::ChartState &update_reveal = update_nt.State();
+ if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) {
+ adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
+ }
+ if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) {
+ adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
+ }
+ if (update_nt.Complete()) {
+ if (update_reveal.left.full) {
+ before->left.full = true;
+ } else {
+ assert(update_reveal.left.length == update_reveal.right.length);
+ adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
+ }
+ before->right = after->right;
+ // Shift the others shifted one down, covering after.
+ for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) {
+ *cover = *(cover + 1);
+ }
+ }
+ update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM());
+}
+
+} // namespace
+
+template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
+ assert(!generate_.empty());
+ PartialEdge top = generate_.top();
+ generate_.pop();
+ PartialVertex *const top_nt = top.NT();
+ const Arity arity = top.GetArity();
+
+ Arity victim = 0;
+ Arity victim_completed;
+ Arity incomplete;
+ // Select victim or return if complete.
+ {
+ Arity completed = 0;
+ unsigned char lowest_length = 255;
+ for (Arity i = 0; i != arity; ++i) {
+ if (top_nt[i].Complete()) {
+ ++completed;
+ } else if (top_nt[i].Length() < lowest_length) {
+ lowest_length = top_nt[i].Length();
+ victim = i;
+ victim_completed = completed;
+ }
+ }
+ if (lowest_length == 255) {
+ return top;
+ }
+ incomplete = arity - completed;
+ }
+
+ PartialVertex old_value(top_nt[victim]);
+ PartialVertex alternate_changed;
+ if (top_nt[victim].Split(alternate_changed)) {
+ PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1);
+ alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound());
+
+ alternate.SetNote(top.GetNote());
+
+ PartialVertex *alternate_nt = alternate.NT();
+ for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i];
+ alternate_nt[victim] = alternate_changed;
+ for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i];
+
+ memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1));
+
+ // TODO: dedupe?
+ generate_.push(alternate);
+ }
+
+ // top is now the continuation.
+ FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);
+ // TODO: dedupe?
+ generate_.push(top);
+
+ // Invalid indicates no new hypothesis generated.
+ return PartialEdge();
+}
+
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context);
+
+} // namespace search
diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh
new file mode 100644
index 00000000..582c78b7
--- /dev/null
+++ b/klm/search/edge_generator.hh
@@ -0,0 +1,57 @@
+#ifndef SEARCH_EDGE_GENERATOR__
+#define SEARCH_EDGE_GENERATOR__
+
+#include "search/edge.hh"
+#include "search/note.hh"
+#include "search/types.hh"
+
+#include <queue>
+
+namespace lm {
+namespace ngram {
+class ChartState;
+} // namespace ngram
+} // namespace lm
+
+namespace search {
+
+template <class Model> class Context;
+
+class EdgeGenerator {
+ public:
+ EdgeGenerator() {}
+
+ PartialEdge AllocateEdge(Arity arity) {
+ return PartialEdge(partial_edge_pool_, arity);
+ }
+
+ void AddEdge(PartialEdge edge) {
+ generate_.push(edge);
+ }
+
+ bool Empty() const { return generate_.empty(); }
+
+ // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge.
+ template <class Model> PartialEdge Pop(Context<Model> &context);
+
+ template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
+ unsigned to_pop = context.PopLimit();
+ while (to_pop > 0 && !generate_.empty()) {
+ PartialEdge got(Pop(context));
+ if (got.Valid()) {
+ output.NewHypothesis(got);
+ --to_pop;
+ }
+ }
+ output.FinishedSearch();
+ }
+
+ private:
+ util::Pool partial_edge_pool_;
+
+ typedef std::priority_queue<PartialEdge> Generate;
+ Generate generate_;
+};
+
+} // namespace search
+#endif // SEARCH_EDGE_GENERATOR__
diff --git a/klm/search/final.hh b/klm/search/final.hh
new file mode 100644
index 00000000..50e62cf2
--- /dev/null
+++ b/klm/search/final.hh
@@ -0,0 +1,36 @@
+#ifndef SEARCH_FINAL__
+#define SEARCH_FINAL__
+
+#include "search/header.hh"
+#include "util/pool.hh"
+
+namespace search {
+
+// A full hypothesis with pointers to children.
+class Final : public Header {
+ public:
+ Final() {}
+
+ Final(util::Pool &pool, Score score, Arity arity, Note note)
+ : Header(pool.Allocate(Size(arity)), arity) {
+ SetScore(score);
+ SetNote(note);
+ }
+
+ // These are arrays of length GetArity().
+ Final *Children() {
+ return reinterpret_cast<Final*>(After());
+ }
+ const Final *Children() const {
+ return reinterpret_cast<const Final*>(After());
+ }
+
+ private:
+ static std::size_t Size(Arity arity) {
+ return kHeaderSize + arity * sizeof(const Final);
+ }
+};
+
+} // namespace search
+
+#endif // SEARCH_FINAL__
diff --git a/klm/search/header.hh b/klm/search/header.hh
new file mode 100644
index 00000000..25550dbe
--- /dev/null
+++ b/klm/search/header.hh
@@ -0,0 +1,57 @@
+#ifndef SEARCH_HEADER__
+#define SEARCH_HEADER__
+
+// Header consisting of Score, Arity, and Note
+
+#include "search/note.hh"
+#include "search/types.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+// Copying is shallow.
+class Header {
+ public:
+ bool Valid() const { return base_; }
+
+ Score GetScore() const {
+ return *reinterpret_cast<const float*>(base_);
+ }
+ void SetScore(Score to) {
+ *reinterpret_cast<float*>(base_) = to;
+ }
+ bool operator<(const Header &other) const {
+ return GetScore() < other.GetScore();
+ }
+
+ Arity GetArity() const {
+ return *reinterpret_cast<const Arity*>(base_ + sizeof(Score));
+ }
+
+ Note GetNote() const {
+ return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity));
+ }
+ void SetNote(Note to) {
+ *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;
+ }
+
+ protected:
+ Header() : base_(NULL) {}
+
+ Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {
+ *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;
+ }
+
+ static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note);
+
+ uint8_t *After() { return base_ + kHeaderSize; }
+ const uint8_t *After() const { return base_ + kHeaderSize; }
+
+ private:
+ uint8_t *base_;
+};
+
+} // namespace search
+
+#endif // SEARCH_HEADER__
diff --git a/klm/search/note.hh b/klm/search/note.hh
new file mode 100644
index 00000000..50bed06e
--- /dev/null
+++ b/klm/search/note.hh
@@ -0,0 +1,12 @@
+#ifndef SEARCH_NOTE__
+#define SEARCH_NOTE__
+
+namespace search {
+
+union Note {
+ const void *vp;
+};
+
+} // namespace search
+
+#endif // SEARCH_NOTE__
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
new file mode 100644
index 00000000..5b00207e
--- /dev/null
+++ b/klm/search/rule.cc
@@ -0,0 +1,43 @@
+#include "search/rule.hh"
+
+#include "search/context.hh"
+#include "search/final.hh"
+
+#include <ostream>
+
+#include <math.h>
+
+namespace search {
+
+template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) {
+ unsigned int oov_count = 0;
+ float prob = 0.0;
+ const Model &model = context.LanguageModel();
+ const lm::WordIndex oov = model.GetVocabulary().NotFound();
+ for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
+ lm::ngram::RuleScore<Model> scorer(model, *(writing++));
+ // TODO: optimize
+ if (prepend_bos && (word == words.begin())) {
+ scorer.BeginSentence();
+ }
+ for (; ; ++word) {
+ if (word == words.end()) {
+ prob += scorer.Finish();
+ return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();
+ }
+ if (*word == kNonTerminal) break;
+ if (*word == oov) ++oov_count;
+ scorer.Terminal(*word);
+ }
+ prob += scorer.Finish();
+ }
+}
+
+template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+
+} // namespace search
diff --git a/klm/search/rule.hh b/klm/search/rule.hh
new file mode 100644
index 00000000..0ce2794d
--- /dev/null
+++ b/klm/search/rule.hh
@@ -0,0 +1,20 @@
+#ifndef SEARCH_RULE__
+#define SEARCH_RULE__
+
+#include "lm/left.hh"
+#include "lm/word_index.hh"
+#include "search/types.hh"
+
+#include <vector>
+
+namespace search {
+
+template <class Model> class Context;
+
+const lm::WordIndex kNonTerminal = lm::kMaxWordIndex;
+
+template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out);
+
+} // namespace search
+
+#endif // SEARCH_RULE__
diff --git a/klm/search/types.hh b/klm/search/types.hh
new file mode 100644
index 00000000..06eb5bfa
--- /dev/null
+++ b/klm/search/types.hh
@@ -0,0 +1,14 @@
+#ifndef SEARCH_TYPES__
+#define SEARCH_TYPES__
+
+#include <stdint.h>
+
+namespace search {
+
+typedef float Score;
+
+typedef uint32_t Arity;
+
+} // namespace search
+
+#endif // SEARCH_TYPES__
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc
new file mode 100644
index 00000000..11f4631f
--- /dev/null
+++ b/klm/search/vertex.cc
@@ -0,0 +1,42 @@
+#include "search/vertex.hh"
+
+#include "search/context.hh"
+
+#include <algorithm>
+#include <functional>
+
+#include <assert.h>
+
+namespace search {
+
+namespace {
+
+struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> {
+ bool operator()(const VertexNode *first, const VertexNode *second) const {
+ return first->Bound() > second->Bound();
+ }
+};
+
+} // namespace
+
+void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
+ if (Complete()) {
+ assert(end_.Valid());
+ assert(extend_.empty());
+ bound_ = end_.GetScore();
+ return;
+ }
+ if (extend_.size() == 1 && parent_ptr) {
+ *parent_ptr = extend_[0];
+ extend_[0]->SortAndSet(context, parent_ptr);
+ context.DeleteVertexNode(this);
+ return;
+ }
+ for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ (*i)->SortAndSet(context, &*i);
+ }
+ std::sort(extend_.begin(), extend_.end(), GreaterByBound());
+ bound_ = extend_.front()->Bound();
+}
+
+} // namespace search
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
new file mode 100644
index 00000000..52bc1dfe
--- /dev/null
+++ b/klm/search/vertex.hh
@@ -0,0 +1,159 @@
+#ifndef SEARCH_VERTEX__
+#define SEARCH_VERTEX__
+
+#include "lm/left.hh"
+#include "search/final.hh"
+#include "search/types.hh"
+
+#include <boost/unordered_set.hpp>
+
+#include <queue>
+#include <vector>
+
+#include <stdint.h>
+
+namespace search {
+
+class ContextBase;
+
+class VertexNode {
+ public:
+ VertexNode() {}
+
+ void InitRoot() {
+ extend_.clear();
+ state_.left.full = false;
+ state_.left.length = 0;
+ state_.right.length = 0;
+ right_full_ = false;
+ end_ = Final();
+ }
+
+ lm::ngram::ChartState &MutableState() { return state_; }
+ bool &MutableRightFull() { return right_full_; }
+
+ void AddExtend(VertexNode *next) {
+ extend_.push_back(next);
+ }
+
+ void SetEnd(Final end) {
+ assert(!end_.Valid());
+ end_ = end;
+ }
+
+ void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
+
+ // Should only happen to a root node when the entire vertex is empty.
+ bool Empty() const {
+ return !end_.Valid() && extend_.empty();
+ }
+
+ bool Complete() const {
+ return end_.Valid();
+ }
+
+ const lm::ngram::ChartState &State() const { return state_; }
+ bool RightFull() const { return right_full_; }
+
+ Score Bound() const {
+ return bound_;
+ }
+
+ unsigned char Length() const {
+ return state_.left.length + state_.right.length;
+ }
+
+ // Will be invalid unless this is a leaf.
+ const Final End() const { return end_; }
+
+ const VertexNode &operator[](size_t index) const {
+ return *extend_[index];
+ }
+
+ size_t Size() const {
+ return extend_.size();
+ }
+
+ private:
+ std::vector<VertexNode*> extend_;
+
+ lm::ngram::ChartState state_;
+ bool right_full_;
+
+ Score bound_;
+ Final end_;
+};
+
+class PartialVertex {
+ public:
+ PartialVertex() {}
+
+ explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}
+
+ bool Empty() const { return back_->Empty(); }
+
+ bool Complete() const { return back_->Complete(); }
+
+ const lm::ngram::ChartState &State() const { return back_->State(); }
+ bool RightFull() const { return back_->RightFull(); }
+
+ Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); }
+
+ unsigned char Length() const { return back_->Length(); }
+
+ bool HasAlternative() const {
+ return index_ + 1 < back_->Size();
+ }
+
+ // Split into continuation and alternative, rendering this the continuation.
+ bool Split(PartialVertex &alternative) {
+ assert(!Complete());
+ bool ret;
+ if (index_ + 1 < back_->Size()) {
+ alternative.index_ = index_ + 1;
+ alternative.back_ = back_;
+ ret = true;
+ } else {
+ ret = false;
+ }
+ back_ = &((*back_)[index_]);
+ index_ = 0;
+ return ret;
+ }
+
+ const Final End() const {
+ return back_->End();
+ }
+
+ private:
+ const VertexNode *back_;
+ unsigned int index_;
+};
+
+class Vertex {
+ public:
+ Vertex() {}
+
+ PartialVertex RootPartial() const { return PartialVertex(root_); }
+
+ const Final BestChild() const {
+ PartialVertex top(RootPartial());
+ if (top.Empty()) {
+ return Final();
+ } else {
+ PartialVertex continuation;
+ while (!top.Complete()) {
+ top.Split(continuation);
+ }
+ return top.End();
+ }
+ }
+
+ private:
+ friend class VertexGenerator;
+
+ VertexNode root_;
+};
+
+} // namespace search
+#endif // SEARCH_VERTEX__
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
new file mode 100644
index 00000000..0945fe55
--- /dev/null
+++ b/klm/search/vertex_generator.cc
@@ -0,0 +1,94 @@
+#include "search/vertex_generator.hh"
+
+#include "lm/left.hh"
+#include "search/context.hh"
+#include "search/edge.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {
+ gen.root_.InitRoot();
+}
+
+namespace {
+
+const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
+
+// Parallel structure to VertexNode.
+struct Trie {
+ Trie() : under(NULL) {}
+
+ VertexNode *under;
+ boost::unordered_map<uint64_t, Trie> extend;
+};
+
+Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
+ Trie &next = node.extend[added];
+ if (!next.under) {
+ next.under = context.NewVertexNode();
+ lm::ngram::ChartState &writing = next.under->MutableState();
+ writing = state;
+ writing.left.full &= left_full && state.left.full;
+ next.under->MutableRightFull() = right_full && state.left.full;
+ writing.left.length = left;
+ writing.right.length = right;
+ node.under->AddExtend(next.under);
+ }
+ return next;
+}
+
+void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) {
+ Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote());
+ Final *child_out = final.Children();
+ const PartialVertex *part = partial.NT();
+ const PartialVertex *const part_end_loop = part + partial.GetArity();
+ for (; part != part_end_loop; ++part, ++child_out)
+ *child_out = part->End();
+
+ starter.under->SetEnd(final);
+}
+
+void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
+ const lm::ngram::ChartState &state = partial.CompletedState();
+
+ unsigned char left = 0, right = 0;
+ Trie *node = &root;
+ while (true) {
+ if (left == state.left.length) {
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false);
+ for (; right < state.right.length; ++right) {
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false);
+ }
+ break;
+ }
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false);
+ left++;
+ if (right == state.right.length) {
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true);
+ for (; left < state.left.length; ++left) {
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true);
+ }
+ break;
+ }
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false);
+ right++;
+ }
+
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
+ CompleteTransition(context, *node, partial);
+}
+
+} // namespace
+
+void VertexGenerator::FinishedSearch() {
+ Trie root;
+ root.under = &gen_.root_;
+ for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) {
+ AddHypothesis(context_, root, i->second);
+ }
+ root.under->SortAndSet(context_, NULL);
+}
+
+} // namespace search
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
new file mode 100644
index 00000000..60e86112
--- /dev/null
+++ b/klm/search/vertex_generator.hh
@@ -0,0 +1,46 @@
+#ifndef SEARCH_VERTEX_GENERATOR__
+#define SEARCH_VERTEX_GENERATOR__
+
+#include "search/edge.hh"
+#include "search/vertex.hh"
+
+#include <boost/unordered_map.hpp>
+
+namespace lm {
+namespace ngram {
+class ChartState;
+} // namespace ngram
+} // namespace lm
+
+namespace search {
+
+class ContextBase;
+class Final;
+
+class VertexGenerator {
+ public:
+ VertexGenerator(ContextBase &context, Vertex &gen);
+
+ void NewHypothesis(PartialEdge partial) {
+ const lm::ngram::ChartState &state = partial.CompletedState();
+ std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial)));
+ if (!ret.second && ret.first->second < partial) {
+ ret.first->second = partial;
+ }
+ }
+
+ void FinishedSearch();
+
+ const Vertex &Generating() const { return gen_; }
+
+ private:
+ ContextBase &context_;
+
+ Vertex &gen_;
+
+ typedef boost::unordered_map<uint64_t, PartialEdge> Existing;
+ Existing existing_;
+};
+
+} // namespace search
+#endif // SEARCH_VERTEX_GENERATOR__
diff --git a/klm/search/weights.cc b/klm/search/weights.cc
new file mode 100644
index 00000000..d65471ad
--- /dev/null
+++ b/klm/search/weights.cc
@@ -0,0 +1,71 @@
+#include "search/weights.hh"
+#include "util/tokenize_piece.hh"
+
+#include <cstdlib>
+
+namespace search {
+
+namespace {
+struct Insert {
+ void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const {
+ std::string copy(name.data(), name.size());
+ map[copy] = score;
+ }
+};
+
+struct DotProduct {
+ search::Score total;
+ DotProduct() : total(0.0) {}
+
+ void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) {
+ boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name));
+ if (i != map.end())
+ total += score * i->second;
+ }
+};
+
+template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) {
+ for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) {
+ util::TokenIter<util::SingleCharacter> equals(*spaces, '=');
+ UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces);
+ StringPiece name(*equals);
+ UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces);
+ char *end;
+ // Assumes proper termination.
+ double value = std::strtod(equals->data(), &end);
+ UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals);
+ UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces);
+ op(map, name, value);
+ }
+}
+
+} // namespace
+
+Weights::Weights(StringPiece text) {
+ Insert op;
+ Parse<Map, Insert>(text, map_, op);
+ lm_ = Steal("LanguageModel");
+ oov_ = Steal("OOV");
+ word_penalty_ = Steal("WordPenalty");
+}
+
+Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {}
+
+search::Score Weights::DotNoLM(StringPiece text) const {
+ DotProduct dot;
+ Parse<const Map, DotProduct>(text, map_, dot);
+ return dot.total;
+}
+
+float Weights::Steal(const std::string &str) {
+ Map::iterator i(map_.find(str));
+ if (i == map_.end()) {
+ return 0.0;
+ } else {
+ float ret = i->second;
+ map_.erase(i);
+ return ret;
+ }
+}
+
+} // namespace search
diff --git a/klm/search/weights.hh b/klm/search/weights.hh
new file mode 100644
index 00000000..df1c419f
--- /dev/null
+++ b/klm/search/weights.hh
@@ -0,0 +1,52 @@
+// For now, the individual features are not kept.
+#ifndef SEARCH_WEIGHTS__
+#define SEARCH_WEIGHTS__
+
+#include "search/types.hh"
+#include "util/exception.hh"
+#include "util/string_piece.hh"
+
+#include <boost/unordered_map.hpp>
+
+#include <string>
+
+namespace search {
+
+class WeightParseException : public util::Exception {
+ public:
+ WeightParseException() {}
+ ~WeightParseException() throw() {}
+};
+
+class Weights {
+ public:
+ // Parses weights, sets lm_weight_, removes it from map_.
+ explicit Weights(StringPiece text);
+
+ // Just the three scores we care about adding.
+ Weights(Score lm, Score oov, Score word_penalty);
+
+ Score DotNoLM(StringPiece text) const;
+
+ Score LM() const { return lm_; }
+
+ Score OOV() const { return oov_; }
+
+ Score WordPenalty() const { return word_penalty_; }
+
+ // Mostly for testing.
+ const boost::unordered_map<std::string, Score> &GetMap() const { return map_; }
+
+ private:
+ float Steal(const std::string &str);
+
+ typedef boost::unordered_map<std::string, Score> Map;
+
+ Map map_;
+
+ Score lm_, oov_, word_penalty_;
+};
+
+} // namespace search
+
+#endif // SEARCH_WEIGHTS__
diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc
new file mode 100644
index 00000000..4811ff06
--- /dev/null
+++ b/klm/search/weights_test.cc
@@ -0,0 +1,38 @@
+#include "search/weights.hh"
+
+#define BOOST_TEST_MODULE WeightTest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+namespace search {
+namespace {
+
+#define CHECK_WEIGHT(value, string) \
+ i = parsed.find(string); \
+ BOOST_REQUIRE(i != parsed.end()); \
+ BOOST_CHECK_CLOSE((value), i->second, 0.001);
+
+BOOST_AUTO_TEST_CASE(parse) {
+ // These are not real feature weights.
+ Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
+ const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap();
+ boost::unordered_map<std::string, search::Score>::const_iterator i;
+ CHECK_WEIGHT(0.0, "rarity");
+ CHECK_WEIGHT(0.0, "phrase-SGT");
+ CHECK_WEIGHT(9.45117, "phrase-TGS");
+ CHECK_WEIGHT(2.33833, "lexical-SGT");
+ BOOST_CHECK(parsed.end() == parsed.find("lm"));
+ BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001);
+ CHECK_WEIGHT(-28.3317, "lexical-TGS");
+ CHECK_WEIGHT(5.0, "glue?");
+}
+
+BOOST_AUTO_TEST_CASE(dot) {
+ Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
+ BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001);
+ BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001);
+ BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001);
+}
+
+} // namespace
+} // namespace search
diff --git a/klm/util/Jamfile b/klm/util/Jamfile
deleted file mode 100644
index 3ee2c2c2..00000000
--- a/klm/util/Jamfile
+++ /dev/null
@@ -1,10 +0,0 @@
-lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc usage.cc ../..//z : <include>.. : : <include>.. ;
-
-import testing ;
-
-unit-test bit_packing_test : bit_packing_test.cc kenutil ../..///boost_unit_test_framework ;
-run file_piece_test.cc kenutil ../..///boost_unit_test_framework : : file_piece.cc ;
-unit-test joint_sort_test : joint_sort_test.cc kenutil ../..///boost_unit_test_framework ;
-unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil ../..///boost_unit_test_framework ;
-unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil ../..///boost_unit_test_framework ;
-unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil ../..///boost_unit_test_framework ;
diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am
index 5ceccf2c..5306850f 100644
--- a/klm/util/Makefile.am
+++ b/klm/util/Makefile.am
@@ -26,6 +26,8 @@ libklm_util_a_SOURCES = \
file_piece.cc \
mmap.cc \
murmur_hash.cc \
+ pool.cc \
+ string_piece.cc \
usage.cc
AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc
index 07b14e26..eb635ad8 100644
--- a/klm/util/ersatz_progress.cc
+++ b/klm/util/ersatz_progress.cc
@@ -9,16 +9,16 @@ namespace util {
namespace { const unsigned char kWidth = 100; }
-ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::size_t>::max()), complete_(next_), out_(NULL) {}
+ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<uint64_t>::max()), complete_(next_), out_(NULL) {}
ErsatzProgress::~ErsatzProgress() {
if (out_) Finished();
}
-ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std::string &message)
+ErsatzProgress::ErsatzProgress(uint64_t complete, std::ostream *to, const std::string &message)
: current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) {
if (!out_) {
- next_ = std::numeric_limits<std::size_t>::max();
+ next_ = std::numeric_limits<uint64_t>::max();
return;
}
if (!message.empty()) *out_ << message << '\n';
@@ -28,14 +28,14 @@ ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std
void ErsatzProgress::Milestone() {
if (!out_) { current_ = 0; return; }
if (!complete_) return;
- unsigned char stone = std::min(static_cast<std::size_t>(kWidth), (current_ * kWidth) / complete_);
+ unsigned char stone = std::min(static_cast<uint64_t>(kWidth), (current_ * kWidth) / complete_);
for (; stones_written_ < stone; ++stones_written_) {
(*out_) << '*';
}
if (stone == kWidth) {
(*out_) << std::endl;
- next_ = std::numeric_limits<std::size_t>::max();
+ next_ = std::numeric_limits<uint64_t>::max();
out_ = NULL;
} else {
next_ = std::max(next_, (stone * complete_) / kWidth);
diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh
index f709dc51..9909736d 100644
--- a/klm/util/ersatz_progress.hh
+++ b/klm/util/ersatz_progress.hh
@@ -4,6 +4,8 @@
#include <iostream>
#include <string>
+#include <stdint.h>
+
// Ersatz version of boost::progress so core language model doesn't depend on
// boost. Also adds option to print nothing.
@@ -14,7 +16,7 @@ class ErsatzProgress {
ErsatzProgress();
// Null means no output. The null value is useful for passing along the ostream pointer from another caller.
- explicit ErsatzProgress(std::size_t complete, std::ostream *to = &std::cerr, const std::string &message = "");
+ explicit ErsatzProgress(uint64_t complete, std::ostream *to = &std::cerr, const std::string &message = "");
~ErsatzProgress();
@@ -23,12 +25,12 @@ class ErsatzProgress {
return *this;
}
- ErsatzProgress &operator+=(std::size_t amount) {
+ ErsatzProgress &operator+=(uint64_t amount) {
if ((current_ += amount) >= next_) Milestone();
return *this;
}
- void Set(std::size_t to) {
+ void Set(uint64_t to) {
if ((current_ = to) >= next_) Milestone();
Milestone();
}
@@ -40,7 +42,7 @@ class ErsatzProgress {
private:
void Milestone();
- std::size_t current_, next_, complete_;
+ uint64_t current_, next_, complete_;
unsigned char stones_written_;
std::ostream *out_;
diff --git a/klm/util/exception.cc b/klm/util/exception.cc
index c4f8c04c..3806e6de 100644
--- a/klm/util/exception.cc
+++ b/klm/util/exception.cc
@@ -84,4 +84,7 @@ EndOfFileException::EndOfFileException() throw() {
}
EndOfFileException::~EndOfFileException() throw() {}
+OverflowException::OverflowException() throw() {}
+OverflowException::~OverflowException() throw() {}
+
} // namespace util
diff --git a/klm/util/exception.hh b/klm/util/exception.hh
index 6d6a37cb..053a850b 100644
--- a/klm/util/exception.hh
+++ b/klm/util/exception.hh
@@ -2,9 +2,12 @@
#define UTIL_EXCEPTION__
#include <exception>
+#include <limits>
#include <sstream>
#include <string>
+#include <stdint.h>
+
namespace util {
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
@@ -111,6 +114,25 @@ class EndOfFileException : public Exception {
~EndOfFileException() throw();
};
+class OverflowException : public Exception {
+ public:
+ OverflowException() throw();
+ ~OverflowException() throw();
+};
+
+template <unsigned len> inline std::size_t CheckOverflowInternal(uint64_t value) {
+ UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code.");
+ return value;
+}
+
+template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) {
+ return value;
+}
+
+inline std::size_t CheckOverflow(uint64_t value) {
+ return CheckOverflowInternal<sizeof(std::size_t)>(value);
+}
+
} // namespace util
#endif // UTIL_EXCEPTION__
diff --git a/klm/util/file.cc b/klm/util/file.cc
index 6a3885a7..6bf879ac 100644
--- a/klm/util/file.cc
+++ b/klm/util/file.cc
@@ -6,6 +6,7 @@
#include <cstdio>
#include <iostream>
+#include <assert.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
@@ -44,6 +45,16 @@ int OpenReadOrThrow(const char *name) {
return ret;
}
+int CreateOrThrow(const char *name) {
+ int ret;
+#if defined(_WIN32) || defined(_WIN64)
+ UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name);
+#else
+ UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name);
+#endif
+ return ret;
+}
+
uint64_t SizeFile(int fd) {
#if defined(_WIN32) || defined(_WIN64)
__int64 ret = _filelengthi64(fd);
@@ -101,6 +112,11 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
}
}
+void WriteOrThrow(FILE *to, const void *data, std::size_t size) {
+ assert(size);
+ if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
+}
+
void FSyncOrThrow(int fd) {
// Apparently windows doesn't have fsync?
#if !defined(_WIN32) && !defined(_WIN64)
@@ -109,8 +125,13 @@ void FSyncOrThrow(int fd) {
}
namespace {
-void InternalSeek(int fd, off_t off, int whence) {
+void InternalSeek(int fd, int64_t off, int whence) {
+#if defined(_WIN32) || defined(_WIN64)
+ UTIL_THROW_IF((__int64)-1 == _lseeki64(fd, off, whence), ErrnoException, "Windows seek failed");
+
+#else
UTIL_THROW_IF((off_t)-1 == lseek(fd, off, whence), ErrnoException, "Seek failed");
+#endif
}
} // namespace
@@ -133,6 +154,12 @@ std::FILE *FDOpenOrThrow(scoped_fd &file) {
return ret;
}
+std::FILE *FOpenOrThrow(const char *path, const char *mode) {
+ std::FILE *ret;
+ UTIL_THROW_IF(!(ret = fopen(path, mode)), util::ErrnoException, "Could not fopen " << path << " for " << mode);
+ return ret;
+}
+
TempMaker::TempMaker(const std::string &prefix) : base_(prefix) {
base_ += "XXXXXX";
}
@@ -232,7 +259,9 @@ mkstemp_and_unlink(char *tmpl)
/* Modified for windows and to unlink */
// fd = open (tmpl, O_RDWR | O_CREAT | O_EXCL, _S_IREAD | _S_IWRITE);
- fd = _open (tmpl, _O_RDWR | _O_CREAT | _O_TEMPORARY | _O_EXCL | _O_BINARY, _S_IREAD | _S_IWRITE);
+ int flags = _O_RDWR | _O_CREAT | _O_EXCL | _O_BINARY;
+ flags |= _O_TEMPORARY;
+ fd = _open (tmpl, flags, _S_IREAD | _S_IWRITE);
if (fd >= 0)
{
errno = save_errno;
@@ -250,17 +279,18 @@ mkstemp_and_unlink(char *tmpl)
int
mkstemp_and_unlink(char *tmpl) {
int ret = mkstemp(tmpl);
- if (ret == -1) return -1;
- UTIL_THROW_IF(unlink(tmpl), util::ErrnoException, "Failed to delete " << tmpl);
+ if (ret != -1) {
+ UTIL_THROW_IF(unlink(tmpl), util::ErrnoException, "Failed to delete " << tmpl);
+ }
return ret;
}
#endif
int TempMaker::Make() const {
- std::string copy(base_);
- copy.push_back(0);
+ std::string name(base_);
+ name.push_back(0);
int ret;
- UTIL_THROW_IF(-1 == (ret = mkstemp_and_unlink(&copy[0])), util::ErrnoException, "Failed to make a temporary based on " << base_);
+ UTIL_THROW_IF(-1 == (ret = mkstemp_and_unlink(&name[0])), util::ErrnoException, "Failed to make a temporary based on " << base_);
return ret;
}
diff --git a/klm/util/file.hh b/klm/util/file.hh
index 5c57e2a9..185cb1f3 100644
--- a/klm/util/file.hh
+++ b/klm/util/file.hh
@@ -65,7 +65,10 @@ class scoped_FILE {
std::FILE *file_;
};
+// Open for read only.
int OpenReadOrThrow(const char *name);
+// Create file if it doesn't exist, truncate if it does. Opened for write.
+int CreateOrThrow(const char *name);
// Return value for SizeFile when it can't size properly.
const uint64_t kBadSize = (uint64_t)-1;
@@ -77,6 +80,7 @@ void ReadOrThrow(int fd, void *to, std::size_t size);
std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount);
void WriteOrThrow(int fd, const void *data_void, std::size_t size);
+void WriteOrThrow(FILE *to, const void *data, std::size_t size);
void FSyncOrThrow(int fd);
@@ -87,12 +91,14 @@ void SeekEnd(int fd);
std::FILE *FDOpenOrThrow(scoped_fd &file);
+std::FILE *FOpenOrThrow(const char *path, const char *mode);
+
class TempMaker {
public:
explicit TempMaker(const std::string &prefix);
+ // These will already be unlinked for you.
int Make() const;
-
std::FILE *MakeFile() const;
private:
diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc
index a205995a..280f438c 100644
--- a/klm/util/file_piece.cc
+++ b/klm/util/file_piece.cc
@@ -5,6 +5,8 @@
#include "util/mmap.hh"
#ifdef WIN32
#include <io.h>
+#else
+#include <unistd.h>
#endif // WIN32
#include <iostream>
@@ -27,7 +29,7 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() {
#ifdef HAVE_ZLIB
GZException::GZException(gzFile file) {
int num;
- *this << gzerror( file, &num) << " from zlib";
+ *this << gzerror(file, &num) << " from zlib";
}
#endif // HAVE_ZLIB
diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc
index 576fd4cc..bc9e3f81 100644
--- a/klm/util/mmap.cc
+++ b/klm/util/mmap.cc
@@ -19,8 +19,8 @@
#include <windows.h>
#include <io.h>
#else
-#include <unistd.h>
#include <sys/mman.h>
+#include <unistd.h>
#endif
namespace util {
@@ -171,20 +171,6 @@ void *MapZeroedWrite(int fd, std::size_t size) {
return MapOrThrow(size, true, kFileFlags, false, fd, 0);
}
-namespace {
-
-int CreateOrThrow(const char *name) {
- int ret;
-#if defined(_WIN32) || defined(_WIN64)
- UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name);
-#else
- UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name);
-#endif
- return ret;
-}
-
-} // namespace
-
void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) {
file.reset(CreateOrThrow(name));
try {
diff --git a/klm/util/pool.cc b/klm/util/pool.cc
new file mode 100644
index 00000000..2dffd06f
--- /dev/null
+++ b/klm/util/pool.cc
@@ -0,0 +1,35 @@
+#include "util/pool.hh"
+
+#include <stdlib.h>
+
+namespace util {
+
+Pool::Pool() {
+ current_ = NULL;
+ current_end_ = NULL;
+}
+
+Pool::~Pool() {
+ FreeAll();
+}
+
+void Pool::FreeAll() {
+ for (std::vector<void *>::const_iterator i(free_list_.begin()); i != free_list_.end(); ++i) {
+ free(*i);
+ }
+ free_list_.clear();
+ current_ = NULL;
+ current_end_ = NULL;
+}
+
+void *Pool::More(std::size_t size) {
+ std::size_t amount = std::max(static_cast<size_t>(32) << free_list_.size(), size);
+ uint8_t *ret = static_cast<uint8_t*>(malloc(amount));
+ if (!ret) throw std::bad_alloc();
+ free_list_.push_back(ret);
+ current_ = ret + size;
+ current_end_ = ret + amount;
+ return ret;
+}
+
+} // namespace util
diff --git a/klm/util/pool.hh b/klm/util/pool.hh
new file mode 100644
index 00000000..72f8a0c8
--- /dev/null
+++ b/klm/util/pool.hh
@@ -0,0 +1,45 @@
+// Very simple pool. It can only allocate memory. And all of the memory it
+// allocates must be freed at the same time.
+
+#ifndef UTIL_POOL__
+#define UTIL_POOL__
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace util {
+
+class Pool {
+ public:
+ Pool();
+
+ ~Pool();
+
+ void *Allocate(std::size_t size) {
+ void *ret = current_;
+ current_ += size;
+ if (current_ < current_end_) {
+ return ret;
+ } else {
+ return More(size);
+ }
+ }
+
+ void FreeAll();
+
+ private:
+ void *More(std::size_t size);
+
+ std::vector<void *> free_list_;
+
+ uint8_t *current_, *current_end_;
+
+ // no copying
+ Pool(const Pool &);
+ Pool &operator=(const Pool &);
+};
+
+} // namespace util
+
+#endif // UTIL_POOL__
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 3354b68e..4a8aff35 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -8,6 +8,7 @@
#include <functional>
#include <assert.h>
+#include <stdint.h>
namespace util {
@@ -42,8 +43,8 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
typedef EqualT Equal;
public:
- static std::size_t Size(std::size_t entries, float multiplier) {
- std::size_t buckets = std::max(entries + 1, static_cast<std::size_t>(multiplier * static_cast<float>(entries)));
+ static uint64_t Size(uint64_t entries, float multiplier) {
+ uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries)));
return buckets * sizeof(Entry);
}
diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc
new file mode 100644
index 00000000..b422cefc
--- /dev/null
+++ b/klm/util/string_piece.cc
@@ -0,0 +1,192 @@
+// Copyright 2004 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in string_piece.hh.
+
+#include "util/string_piece.hh"
+
+#include <algorithm>
+
+#include <limits.h>
+
+#ifndef HAVE_ICU
+
+typedef StringPiece::size_type size_type;
+
+void StringPiece::CopyToString(std::string* target) const {
+ target->assign(ptr_, length_);
+}
+
+size_type StringPiece::find(const StringPiece& s, size_type pos) const {
+ if (length_ < 0 || pos > static_cast<size_type>(length_))
+ return npos;
+
+ const char* result = std::search(ptr_ + pos, ptr_ + length_,
+ s.ptr_, s.ptr_ + s.length_);
+ const size_type xpos = result - ptr_;
+ return xpos + s.length_ <= length_ ? xpos : npos;
+}
+
+size_type StringPiece::find(char c, size_type pos) const {
+ if (length_ <= 0 || pos >= static_cast<size_type>(length_)) {
+ return npos;
+ }
+ const char* result = std::find(ptr_ + pos, ptr_ + length_, c);
+ return result != ptr_ + length_ ? result - ptr_ : npos;
+}
+
+size_type StringPiece::rfind(const StringPiece& s, size_type pos) const {
+ if (length_ < s.length_) return npos;
+ const size_t ulen = length_;
+ if (s.length_ == 0) return std::min(ulen, pos);
+
+ const char* last = ptr_ + std::min(ulen - s.length_, pos) + s.length_;
+ const char* result = std::find_end(ptr_, last, s.ptr_, s.ptr_ + s.length_);
+ return result != last ? result - ptr_ : npos;
+}
+
+size_type StringPiece::rfind(char c, size_type pos) const {
+ if (length_ <= 0) return npos;
+ for (int i = std::min(pos, static_cast<size_type>(length_ - 1));
+ i >= 0; --i) {
+ if (ptr_[i] == c) {
+ return i;
+ }
+ }
+ return npos;
+}
+
+// For each character in characters_wanted, sets the index corresponding
+// to the ASCII code of that character to 1 in table. This is used by
+// the find_.*_of methods below to tell whether or not a character is in
+// the lookup table in constant time.
+// The argument `table' must be an array that is large enough to hold all
+// the possible values of an unsigned char. Thus it should be be declared
+// as follows:
+// bool table[UCHAR_MAX + 1]
+static inline void BuildLookupTable(const StringPiece& characters_wanted,
+ bool* table) {
+ const size_type length = characters_wanted.length();
+ const char* const data = characters_wanted.data();
+ for (size_type i = 0; i < length; ++i) {
+ table[static_cast<unsigned char>(data[i])] = true;
+ }
+}
+
+size_type StringPiece::find_first_of(const StringPiece& s,
+ size_type pos) const {
+ if (length_ == 0 || s.length_ == 0)
+ return npos;
+
+ // Avoid the cost of BuildLookupTable() for a single-character search.
+ if (s.length_ == 1)
+ return find_first_of(s.ptr_[0], pos);
+
+ bool lookup[UCHAR_MAX + 1] = { false };
+ BuildLookupTable(s, lookup);
+ for (size_type i = pos; i < length_; ++i) {
+ if (lookup[static_cast<unsigned char>(ptr_[i])]) {
+ return i;
+ }
+ }
+ return npos;
+}
+
+size_type StringPiece::find_first_not_of(const StringPiece& s,
+ size_type pos) const {
+ if (length_ == 0)
+ return npos;
+
+ if (s.length_ == 0)
+ return 0;
+
+ // Avoid the cost of BuildLookupTable() for a single-character search.
+ if (s.length_ == 1)
+ return find_first_not_of(s.ptr_[0], pos);
+
+ bool lookup[UCHAR_MAX + 1] = { false };
+ BuildLookupTable(s, lookup);
+ for (size_type i = pos; i < length_; ++i) {
+ if (!lookup[static_cast<unsigned char>(ptr_[i])]) {
+ return i;
+ }
+ }
+ return npos;
+}
+
+size_type StringPiece::find_first_not_of(char c, size_type pos) const {
+ if (length_ == 0)
+ return npos;
+
+ for (; pos < length_; ++pos) {
+ if (ptr_[pos] != c) {
+ return pos;
+ }
+ }
+ return npos;
+}
+
+size_type StringPiece::find_last_of(const StringPiece& s, size_type pos) const {
+ if (length_ == 0 || s.length_ == 0)
+ return npos;
+
+ // Avoid the cost of BuildLookupTable() for a single-character search.
+ if (s.length_ == 1)
+ return find_last_of(s.ptr_[0], pos);
+
+ bool lookup[UCHAR_MAX + 1] = { false };
+ BuildLookupTable(s, lookup);
+ for (size_type i = std::min(pos, length_ - 1); ; --i) {
+ if (lookup[static_cast<unsigned char>(ptr_[i])])
+ return i;
+ if (i == 0)
+ break;
+ }
+ return npos;
+}
+
+size_type StringPiece::find_last_not_of(const StringPiece& s,
+ size_type pos) const {
+ if (length_ == 0)
+ return npos;
+
+ size_type i = std::min(pos, length_ - 1);
+ if (s.length_ == 0)
+ return i;
+
+ // Avoid the cost of BuildLookupTable() for a single-character search.
+ if (s.length_ == 1)
+ return find_last_not_of(s.ptr_[0], pos);
+
+ bool lookup[UCHAR_MAX + 1] = { false };
+ BuildLookupTable(s, lookup);
+ for (; ; --i) {
+ if (!lookup[static_cast<unsigned char>(ptr_[i])])
+ return i;
+ if (i == 0)
+ break;
+ }
+ return npos;
+}
+
+size_type StringPiece::find_last_not_of(char c, size_type pos) const {
+ if (length_ == 0)
+ return npos;
+
+ for (size_type i = std::min(pos, length_ - 1); ; --i) {
+ if (ptr_[i] != c)
+ return i;
+ if (i == 0)
+ break;
+ }
+ return npos;
+}
+
+StringPiece StringPiece::substr(size_type pos, size_type n) const {
+ if (pos > length_) pos = length_;
+ if (n > length_ - pos) n = length_ - pos;
+ return StringPiece(ptr_ + pos, n);
+}
+
+const size_type StringPiece::npos = size_type(-1);
+
+#endif // !HAVE_ICU
diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh
index 5de053aa..be6a643d 100644
--- a/klm/util/string_piece.hh
+++ b/klm/util/string_piece.hh
@@ -85,6 +85,11 @@ U_NAMESPACE_BEGIN
#include <string>
#include <string.h>
+#ifdef WIN32
+#undef max
+#undef min
+#endif
+
class StringPiece {
public:
typedef size_t size_type;
diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh
index c7e1c863..4a7f5460 100644
--- a/klm/util/tokenize_piece.hh
+++ b/klm/util/tokenize_piece.hh
@@ -54,6 +54,18 @@ class AnyCharacter {
StringPiece chars_;
};
+class AnyCharacterLast {
+ public:
+ explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {}
+
+ StringPiece Find(const StringPiece &in) const {
+ return StringPiece(std::find_end(in.data(), in.data() + in.size(), chars_.data(), chars_.data() + chars_.size()), 1);
+ }
+
+ private:
+ StringPiece chars_;
+};
+
template <class Find, bool SkipEmpty = false> class TokenIter : public boost::iterator_facade<TokenIter<Find, SkipEmpty>, const StringPiece, boost::forward_traversal_tag> {
public:
TokenIter() {}