diff options
author | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-11-05 15:29:46 +0100 |
---|---|---|
committer | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-11-05 15:29:46 +0100 |
commit | 1db70a45d59946560fbd5db6487b55a8674ef973 (patch) | |
tree | 172585dafe4d1462f22d8200e733d52dddb55b1e /klm | |
parent | 4dd5216d3afa9ab72b150e250a3c30a5f223ce53 (diff) | |
parent | 6bbf03ac46bd57400aa9e65a321a304a234af935 (diff) |
merge upstream/master
Diffstat (limited to 'klm')
67 files changed, 1973 insertions, 244 deletions
diff --git a/klm/lm/Jamfile b/klm/lm/Jamfile deleted file mode 100644 index b1971d88..00000000 --- a/klm/lm/Jamfile +++ /dev/null @@ -1,14 +0,0 @@ -lib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quantize.cc read_arpa.cc search_hashed.cc search_trie.cc trie.cc trie_sort.cc value_build.cc virtual_interface.cc vocab.cc ../util//kenutil : <include>.. : : <include>.. <library>../util//kenutil ; - -import testing ; - -run left_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa ; -run model_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa test_nounk.arpa ; - -exe query : ngram_query.cc kenlm ../util//kenutil ; -exe build_binary : build_binary.cc kenlm ../util//kenutil ; - -install legacy : build_binary query - : <location>$(TOP)/klm/lm <install-type>EXE <install-dependencies>on <link>shared:<dll-path>$(TOP)/klm/lm <link>shared:<install-type>LIB ; - -alias programs : build_binary query ; diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc index cdeafb47..088ea98d 100644 --- a/klm/lm/bhiksha.cc +++ b/klm/lm/bhiksha.cc @@ -1,6 +1,7 @@ #include "lm/bhiksha.hh" #include "lm/config.hh" #include "util/file.hh" +#include "util/exception.hh" #include <limits> @@ -49,7 +50,7 @@ std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &con } } // namespace -std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { +uint64_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */; } diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 5182ee2e..8ff88654 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -23,7 +23,7 @@ namespace lm { namespace ngram { -class Config; +struct Config; namespace trie { @@ -33,7 +33,7 @@ class DontBhiksha { static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} - static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } + static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) { return util::RequiredBits(max_next); @@ -67,7 +67,7 @@ class ArrayBhiksha { static void UpdateConfigFromBinary(int fd, Config &config); - static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); + static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index a56e998e..efa67056 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -83,7 +83,13 @@ void WriteHeader(void *to, const Parameters ¶ms) { uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; - backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED); + backing.file.reset(util::CreateOrThrow(config.write_mmap)); + if (config.write_method == Config::WRITE_MMAP) { + backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); + } else { + util::ResizeOrThrow(backing.file.get(), 0); + util::MapAnonymous(total, backing.vocab); + } strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); } else { @@ -121,12 +127,14 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { if (!config.write_mmap) return; - util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); switch (config.write_method) { case Config::WRITE_MMAP: + util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); util::SyncOrThrow(backing.search.get(), backing.search.size()); break; case Config::WRITE_AFTER: + util::SeekOrThrow(backing.file.get(), 0); + util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size()); util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); util::FSyncOrThrow(backing.file.get()); @@ -141,6 +149,10 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_ params.fixed.has_vocabulary = config.include_vocab; params.fixed.search_version = search_version; WriteHeader(backing.vocab.get(), params); + if (config.write_method == Config::WRITE_AFTER) { + util::SeekOrThrow(backing.file.get(), 0); + util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size())); + } } namespace detail { @@ -200,10 +212,10 @@ void SeekPastHeader(int fd, const Parameters ¶ms) { util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); } -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { +uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { const uint64_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. - std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; + std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index dd795f62..bf699d5f 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -70,7 +70,7 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet void SeekPastHeader(int fd, const Parameters ¶ms); -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); +uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing); void ComplainAboutARPA(const Config &config, ModelType model_type); @@ -90,7 +90,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to) new_config.probing_multiplier = params.fixed.probing_multiplier; detail::SeekPastHeader(backing.file.get(), params); To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); - std::size_t memory_size = To::Size(params.counts, new_config); + uint64_t memory_size = To::Size(params.counts, new_config); uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); to.InitializeFromBinary(start, params, new_config, backing.file.get()); } else { diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index c4a01cb4..2b8c9d5b 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -11,6 +11,8 @@ #ifdef WIN32 #include "util/getopt.hh" +#else +#include <unistd.h> #endif namespace lm { @@ -25,7 +27,11 @@ void Usage(const char *name) { "-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" "-w mmap|after determines how writing is done.\n" " mmap maps the binary file and writes to it. Default for trie.\n" -" after allocates anonymous memory, builds, and writes. Default for probing.\n\n" +" after allocates anonymous memory, builds, and writes. Default for probing.\n" +"-r \"order1.arpa order2 order3 order4\" adds lower-order rest costs from these\n" +" model files. order1.arpa must be an ARPA file. All others may be ARPA or\n" +" the same data structure as being built. All files must have the same\n" +" vocabulary. For probing, the unigrams must be in the same order.\n\n" "type is either probing or trie. Default is probing.\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" @@ -81,16 +87,16 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector<uint64_t> counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t sizes[6]; + uint64_t sizes[6]; sizes[0] = ProbingModel::Size(counts, config); sizes[1] = RestProbingModel::Size(counts, config); sizes[2] = TrieModel::Size(counts, config); sizes[3] = QuantTrieModel::Size(counts, config); sizes[4] = ArrayTrieModel::Size(counts, config); sizes[5] = QuantArrayTrieModel::Size(counts, config); - std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); - std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); - std::size_t divide; + uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t divide; char prefix; if (min_length < (1 << 10) * 10) { prefix = ' '; @@ -111,7 +117,7 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { for (long int i = 0; i < length - 2; ++i) std::cout << ' '; std::cout << prefix << "B\n" "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" - "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r -p " << config.probing_multiplier << "\n" + "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n" "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n" "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc new file mode 100644 index 00000000..0267cd4e --- /dev/null +++ b/klm/lm/fragment.cc @@ -0,0 +1,37 @@ +#include "lm/binary_format.hh" +#include "lm/model.hh" +#include "lm/left.hh" +#include "util/tokenize_piece.hh" + +template <class Model> void Query(const char *name) { + Model model(name); + std::string line; + lm::ngram::ChartState ignored; + while (getline(std::cin, line)) { + lm::ngram::RuleScore<Model> scorer(model, ignored); + for (util::TokenIter<util::SingleCharacter, true> i(line, ' '); i; ++i) { + scorer.Terminal(model.GetVocabulary().Index(*i)); + } + std::cout << scorer.Finish() << '\n'; + } +} + +int main(int argc, char *argv[]) { + if (argc != 2) { + std::cerr << "Expected model file name." << std::endl; + return 1; + } + const char *name = argv[1]; + lm::ngram::ModelType model_type = lm::ngram::PROBING; + lm::ngram::RecognizeBinary(name, model_type); + switch (model_type) { + case lm::ngram::PROBING: + Query<lm::ngram::ProbingModel>(name); + break; + case lm::ngram::REST_PROBING: + Query<lm::ngram::RestProbingModel>(name); + break; + default: + std::cerr << "Model type not supported yet." << std::endl; + } +} diff --git a/klm/lm/left.hh b/klm/lm/left.hh index c00af88a..8c27232e 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -111,7 +111,7 @@ template <class M> class RuleScore { return; } - float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; + float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1]; float *back = backoffs, *back2 = backoffs2; unsigned char next_use = out_.right.length; diff --git a/klm/lm/max_order.cc b/klm/lm/max_order.cc new file mode 100644 index 00000000..94221201 --- /dev/null +++ b/klm/lm/max_order.cc @@ -0,0 +1,6 @@ +#include "lm/max_order.hh" +#include <iostream> + +int main(int argc, char *argv[]) { + std::cerr << "KenLM was compiled with a maximum supported n-gram order set to " << KENLM_MAX_ORDER << "." << std::endl; +} diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index aff9de27..989f8324 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -1,14 +1,12 @@ -#ifndef LM_MAX_ORDER__ -#define LM_MAX_ORDER__ -namespace lm { -namespace ngram { -// If you need higher order, change this and recompile. -// Having this limit means that State can be -// (kMaxOrder - 1) * sizeof(float) bytes instead of -// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead -const unsigned char kMaxOrder = 5; - -} // namespace ngram -} // namespace lm - -#endif // LM_MAX_ORDER__ +/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM. + * If not, this is the default maximum order. + * Having this limit means that State can be + * (kMaxOrder - 1) * sizeof(float) bytes instead of + * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead + */ +#ifndef KENLM_MAX_ORDER +#define KENLM_MAX_ORDER 6 +#endif +#ifndef KENLM_ORDER_MESSAGE +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh." +#endif diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a2d31ce0..2fd20481 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -5,12 +5,14 @@ #include "lm/search_hashed.hh" #include "lm/search_trie.hh" #include "lm/read_arpa.hh" +#include "util/have.hh" #include "util/murmur_hash.hh" #include <algorithm> #include <functional> #include <numeric> #include <cmath> +#include <limits> namespace lm { namespace ngram { @@ -18,17 +20,18 @@ namespace detail { template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType; -template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { +template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) { + size_t goal_size = util::CheckOverflow(Size(counts, config)); uint8_t *start = static_cast<uint8_t*>(base); size_t allocated = VocabularyT::Size(counts[0], config); vocab_.SetupMemory(start, allocated, counts[0], config); start += allocated; start = search_.SetupMemory(start, counts, config); - if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config)); + if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size); } template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { @@ -47,7 +50,19 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge P::Init(begin_sentence, null_context, vocab_, search_.Order()); } +namespace { +void CheckCounts(const std::vector<uint64_t> &counts) { + UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE); + if (sizeof(uint64_t) > sizeof(std::size_t)) { + for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) { + UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines."); + } + } +} +} // namespace + template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { + CheckCounts(params.counts); SetupMemory(start, params.counts, config); vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); search_.LoadedBinary(); @@ -60,12 +75,11 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT std::vector<uint64_t> counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. ReadARPACounts(f, counts); - - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); + CheckCounts(counts); if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - std::size_t vocab_size = VocabularyT::Size(counts[0], config); + std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); @@ -73,7 +87,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT WriteWordsWrapper wrap(config.enumerate_vocab); vocab_.ConfigureEnumerate(&wrap, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get()); + wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size()); } else { vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); diff --git a/klm/lm/model.hh b/klm/lm/model.hh index be872178..13ff864e 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -5,7 +5,6 @@ #include "lm/binary_format.hh" #include "lm/config.hh" #include "lm/facade.hh" -#include "lm/max_order.hh" #include "lm/quantize.hh" #include "lm/search_hashed.hh" #include "lm/search_trie.hh" @@ -42,7 +41,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod * does not include small non-mapped control structures, such as this class * itself. */ - static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config()); + static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config()); /* Load the model from a file. It may be an ARPA or binary file. Binary * files must have the format expected by this class or you'll get an diff --git a/klm/lm/partial.hh b/klm/lm/partial.hh new file mode 100644 index 00000000..1dede359 --- /dev/null +++ b/klm/lm/partial.hh @@ -0,0 +1,167 @@ +#ifndef LM_PARTIAL__ +#define LM_PARTIAL__ + +#include "lm/return.hh" +#include "lm/state.hh" + +#include <algorithm> + +#include <assert.h> + +namespace lm { +namespace ngram { + +struct ExtendReturn { + float adjust; + bool make_full; + unsigned char next_use; +}; + +template <class Model> ExtendReturn ExtendLoop( + const Model &model, + unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start, + const uint64_t *pointers, const uint64_t *pointers_end, + uint64_t *&pointers_write, + float *backoff_write) { + unsigned char add_length = add_rend - add_rbegin; + + float backoff_buf[2][KENLM_MAX_ORDER - 1]; + float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1]; + std::copy(backoff_start, backoff_start + add_length, backoff_in); + + ExtendReturn value; + value.make_full = false; + value.adjust = 0.0; + value.next_use = add_length; + + unsigned char i = 0; + unsigned char length = pointers_end - pointers; + // pointers_write is NULL means that the existing left state is full, so we should use completed probabilities. + if (pointers_write) { + // Using full context, writing to new left state. + for (; i < length; ++i) { + FullScoreReturn ret(model.ExtendLeft( + add_rbegin, add_rbegin + value.next_use, + backoff_in, + pointers[i], i + seen + 1, + backoff_out, + value.next_use)); + std::swap(backoff_in, backoff_out); + if (ret.independent_left) { + value.adjust += ret.prob; + value.make_full = true; + ++i; + break; + } + value.adjust += ret.rest; + *pointers_write++ = ret.extend_left; + if (value.next_use != add_length) { + value.make_full = true; + ++i; + break; + } + } + } + // Using some of the new context. + for (; i < length && value.next_use; ++i) { + FullScoreReturn ret(model.ExtendLeft( + add_rbegin, add_rbegin + value.next_use, + backoff_in, + pointers[i], i + seen + 1, + backoff_out, + value.next_use)); + std::swap(backoff_in, backoff_out); + value.adjust += ret.prob; + } + float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1); + // Using none of the new context. + value.adjust += unrest; + + std::copy(backoff_in, backoff_in + value.next_use, backoff_write); + return value; +} + +template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) { + assert(seen < reveal.length || reveal_full); + uint64_t *pointers_write = reveal_full ? NULL : left.pointers; + float backoff_buffer[KENLM_MAX_ORDER - 1]; + ExtendReturn value(ExtendLoop( + model, + seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen, + left.pointers, left.pointers + left.length, + pointers_write, + left.full ? backoff_buffer : (right.backoff + right.length))); + if (reveal_full) { + left.length = 0; + value.make_full = true; + } else { + left.length = pointers_write - left.pointers; + value.make_full |= (left.length == model.Order() - 1); + } + if (left.full) { + for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; + } else { + // If left wasn't full when it came in, put words into right state. + std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length); + right.length += value.next_use; + left.full = value.make_full || (right.length == model.Order() - 1); + } + return value.adjust; +} + +template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) { + assert(seen < reveal.length || reveal.full); + uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length); + ExtendReturn value(ExtendLoop( + model, + seen, right.words, right.words + right.length, right.backoff, + reveal.pointers + seen, reveal.pointers + reveal.length, + pointers_write, + right.backoff)); + if (reveal.full) { + for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i]; + right.length = 0; + value.make_full = true; + } else { + right.length = value.next_use; + value.make_full |= (right.length == model.Order() - 1); + } + if (!left.full) { + left.length = pointers_write - left.pointers; + left.full = value.make_full || (left.length == model.Order() - 1); + } + return value.adjust; +} + +template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) { + assert(first_right.length < KENLM_MAX_ORDER); + assert(second_left.length < KENLM_MAX_ORDER); + assert(between_length < KENLM_MAX_ORDER - 1); + uint64_t *pointers_write = first_left.full ? NULL : (first_left.pointers + first_left.length); + float backoff_buffer[KENLM_MAX_ORDER - 1]; + ExtendReturn value(ExtendLoop( + model, + between_length, first_right.words, first_right.words + first_right.length, first_right.backoff, + second_left.pointers, second_left.pointers + second_left.length, + pointers_write, + second_left.full ? backoff_buffer : (second_right.backoff + second_right.length))); + if (second_left.full) { + for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; + } else { + std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length); + second_right.length += value.next_use; + value.make_full |= (second_right.length == model.Order() - 1); + } + if (!first_left.full) { + first_left.length = pointers_write - first_left.pointers; + first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1); + } + assert(first_left.length < KENLM_MAX_ORDER); + assert(second_right.length < KENLM_MAX_ORDER); + return value.adjust; +} + +} // namespace ngram +} // namespace lm + +#endif // LM_PARTIAL__ diff --git a/klm/lm/partial_test.cc b/klm/lm/partial_test.cc new file mode 100644 index 00000000..8d309c85 --- /dev/null +++ b/klm/lm/partial_test.cc @@ -0,0 +1,199 @@ +#include "lm/partial.hh" + +#include "lm/left.hh" +#include "lm/model.hh" +#include "util/tokenize_piece.hh" + +#define BOOST_TEST_MODULE PartialTest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> + +namespace lm { +namespace ngram { +namespace { + +const char *TestLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 2) { + return "test.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[1]; +} + +Config SilentConfig() { + Config config; + config.arpa_complain = Config::NONE; + config.messages = NULL; + return config; +} + +struct ModelFixture { + ModelFixture() : m(TestLocation(), SilentConfig()) {} + + RestProbingModel m; +}; + +BOOST_FIXTURE_TEST_SUITE(suite, ModelFixture) + +BOOST_AUTO_TEST_CASE(SimpleBefore) { + Left left; + left.full = false; + left.length = 0; + Right right; + right.length = 0; + + Right reveal; + reveal.length = 1; + WordIndex period = m.GetVocabulary().Index("."); + reveal.words[0] = period; + reveal.backoff[0] = -0.845098; + + BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 0, false, left, right), 0.001); + BOOST_CHECK_EQUAL(0, left.length); + BOOST_CHECK(!left.full); + BOOST_CHECK_EQUAL(1, right.length); + BOOST_CHECK_EQUAL(period, right.words[0]); + BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001); + + WordIndex more = m.GetVocabulary().Index("more"); + reveal.words[1] = more; + reveal.backoff[1] = -0.4771212; + reveal.length = 2; + BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 1, false, left, right), 0.001); + BOOST_CHECK_EQUAL(0, left.length); + BOOST_CHECK(!left.full); + BOOST_CHECK_EQUAL(2, right.length); + BOOST_CHECK_EQUAL(period, right.words[0]); + BOOST_CHECK_EQUAL(more, right.words[1]); + BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001); + BOOST_CHECK_CLOSE(-0.4771212, right.backoff[1], 0.001); +} + +BOOST_AUTO_TEST_CASE(AlsoWouldConsider) { + WordIndex would = m.GetVocabulary().Index("would"); + WordIndex consider = m.GetVocabulary().Index("consider"); + + ChartState current; + current.left.length = 1; + current.left.pointers[0] = would; + current.left.full = false; + current.right.length = 1; + current.right.words[0] = would; + current.right.backoff[0] = -0.30103; + + Left after; + after.full = false; + after.length = 1; + after.pointers[0] = consider; + + // adjustment for would consider + BOOST_CHECK_CLOSE(-1.687872 - -0.2922095 - 0.30103, RevealAfter(m, current.left, current.right, after, 0), 0.001); + + BOOST_CHECK_EQUAL(2, current.left.length); + BOOST_CHECK_EQUAL(would, current.left.pointers[0]); + BOOST_CHECK_EQUAL(false, current.left.full); + + WordIndex also = m.GetVocabulary().Index("also"); + Right before; + before.length = 1; + before.words[0] = also; + before.backoff[0] = -0.30103; + // r(would) = -0.2922095 [i would], r(would -> consider) = -1.988902 [b(would) + p(consider)] + // p(also -> would) = -2, p(also would -> consider) = -3 + BOOST_CHECK_CLOSE(-2 + 0.2922095 -3 + 1.988902, RevealBefore(m, before, 0, false, current.left, current.right), 0.001); + BOOST_CHECK_EQUAL(0, current.left.length); + BOOST_CHECK(current.left.full); + BOOST_CHECK_EQUAL(2, current.right.length); + BOOST_CHECK_EQUAL(would, current.right.words[0]); + BOOST_CHECK_EQUAL(also, current.right.words[1]); +} + +BOOST_AUTO_TEST_CASE(EndSentence) { + WordIndex loin = m.GetVocabulary().Index("loin"); + WordIndex period = m.GetVocabulary().Index("."); + WordIndex eos = m.GetVocabulary().EndSentence(); + + ChartState between; + between.left.length = 1; + between.left.pointers[0] = eos; + between.left.full = true; + between.right.length = 0; + + Right before; + before.words[0] = period; + before.words[1] = loin; + before.backoff[0] = -0.845098; + before.backoff[1] = 0.0; + + before.length = 1; + BOOST_CHECK_CLOSE(-0.0410707, RevealBefore(m, before, 0, true, between.left, between.right), 0.001); + BOOST_CHECK_EQUAL(0, between.left.length); +} + +float ScoreFragment(const RestProbingModel &model, unsigned int *begin, unsigned int *end, ChartState &out) { + RuleScore<RestProbingModel> scorer(model, out); + for (unsigned int *i = begin; i < end; ++i) { + scorer.Terminal(*i); + } + return scorer.Finish(); +} + +void CheckAdjustment(const RestProbingModel &model, float expect, const Right &before_in, bool before_full, ChartState between, const Left &after_in) { + Right before(before_in); + Left after(after_in); + after.full = false; + float got = 0.0; + for (unsigned int i = 1; i < 5; ++i) { + if (before_in.length >= i) { + before.length = i; + got += RevealBefore(model, before, i - 1, false, between.left, between.right); + } + if (after_in.length >= i) { + after.length = i; + got += RevealAfter(model, between.left, between.right, after, i - 1); + } + } + if (after_in.full) { + after.full = true; + got += RevealAfter(model, between.left, between.right, after, after.length); + } + if (before_full) { + got += RevealBefore(model, before, before.length, true, between.left, between.right); + } + // Sometimes they're zero and BOOST_CHECK_CLOSE fails for this. + BOOST_CHECK(fabs(expect - got) < 0.001); +} + +void FullDivide(const RestProbingModel &model, StringPiece str) { + std::vector<WordIndex> indices; + for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) { + indices.push_back(model.GetVocabulary().Index(*i)); + } + ChartState full_state; + float full = ScoreFragment(model, &indices.front(), &indices.back() + 1, full_state); + + ChartState before_state; + before_state.left.full = false; + RuleScore<RestProbingModel> before_scorer(model, before_state); + float before_score = 0.0; + for (unsigned int before = 0; before < indices.size(); ++before) { + for (unsigned int after = before; after <= indices.size(); ++after) { + ChartState after_state, between_state; + float after_score = ScoreFragment(model, &indices.front() + after, &indices.front() + indices.size(), after_state); + float between_score = ScoreFragment(model, &indices.front() + before, &indices.front() + after, between_state); + CheckAdjustment(model, full - before_score - after_score - between_score, before_state.right, before_state.left.full, between_state, after_state.left); + } + before_scorer.Terminal(indices[before]); + before_score = before_scorer.Finish(); + } +} + +BOOST_AUTO_TEST_CASE(Strings) { + FullDivide(m, "also would consider"); + FullDivide(m, "looking on a little more loin . </s>"); + FullDivide(m, "in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>"); +} + +BOOST_AUTO_TEST_SUITE_END() +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 3e9153e3..8ce2378a 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -17,14 +17,14 @@ namespace lm { namespace ngram { -class Config; +struct Config; /* Store values directly and don't quantize. */ class DontQuantize { public: static const ModelType kModelTypeAdd = static_cast<ModelType>(0); static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} - static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } + static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -138,9 +138,9 @@ class SeparatelyQuantize { static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); - static std::size_t Size(uint8_t order, const Config &config) { - size_t longest_table = (static_cast<size_t>(1) << static_cast<size_t>(config.prob_bits)) * sizeof(float); - size_t middle_table = (static_cast<size_t>(1) << static_cast<size_t>(config.backoff_bits)) * sizeof(float) + longest_table; + static uint64_t Size(uint8_t order, const Config &config) { + uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float); + uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table; // unigrams are currently not quantized so no need for a table. return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; } @@ -217,7 +217,7 @@ class SeparatelyQuantize { const Bins &LongestTable() const { return longest_; } private: - Bins tables_[kMaxOrder - 1][2]; + Bins tables_[KENLM_MAX_ORDER - 1][2]; Bins longest_; diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 2d9a337d..b709fef9 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -2,14 +2,20 @@ #include "lm/blank.hh" +#include <cmath> #include <cstdlib> #include <iostream> +#include <sstream> #include <vector> #include <ctype.h> #include <string.h> #include <stdint.h> +#ifdef WIN32 +#include <float.h> +#endif + namespace lm { // 1 for '\t', '\n', and ' '. This is stricter than isspace. @@ -26,6 +32,15 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) { const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; +// strtoull isn't portable enough :-( +uint64_t ReadCount(const std::string &from) { + std::stringstream stream(from); + uint64_t ret; + stream >> ret; + UTIL_THROW_IF(!stream, FormatLoadException, "Bad count " << from); + return ret; +} + } // namespace void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { @@ -47,15 +62,11 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { // So strtol doesn't go off the end of line. std::string remaining(line.data() + 6, line.size() - 6); char *end_ptr; - unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); + unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10); if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line); if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line); ++end_ptr; - const char *start = end_ptr; - long int count = std::strtol(start, &end_ptr, 10); - if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); - if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); - number.push_back(count); + number.push_back(ReadCount(end_ptr)); } } @@ -93,7 +104,16 @@ void ReadBackoff(util::FilePiece &in, float &backoff) { case '\t': backoff = in.ReadFloat(); if (backoff == ngram::kExtensionBackoff) backoff = ngram::kNoExtensionBackoff; - if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); + { +#ifdef WIN32 + int float_class = _fpclass(backoff); + UTIL_THROW_IF(float_class == _FPCLASS_SNAN || float_class == _FPCLASS_QNAN || float_class == _FPCLASS_NINF || float_class == _FPCLASS_PINF, FormatLoadException, "Bad backoff " << backoff); +#else + int float_class = std::fpclassify(backoff); + UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff); +#endif + } + UTIL_THROW_IF(in.get() != '\n', FormatLoadException, "Expected newline after backoff"); break; case '\n': backoff = ngram::kNoExtensionBackoff; diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 13942309..a1623834 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -234,7 +234,7 @@ template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, c ApplyBuild(f, counts, config, vocab, warn, build); } -template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { +template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { switch (config.rest_function) { case Config::REST_MAX: { diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 7e8c1220..a52f107b 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -74,8 +74,8 @@ template <class Value> class HashedSearch { // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} - static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { - std::size_t ret = Unigram::Size(counts[0]); + static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { + uint64_t ret = Unigram::Size(counts[0]); for (unsigned char n = 1; n < counts.size() - 1; ++n) { ret += Middle::Size(counts[n], config.probing_multiplier); } @@ -160,8 +160,8 @@ template <class Value> class HashedSearch { #endif {} - static std::size_t Size(uint64_t count) { - return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk> + static uint64_t Size(uint64_t count) { + return (count + 1) * sizeof(typename Value::Weights); // +1 for hallucinate <unk> } const typename Value::Weights &Lookup(WordIndex index) const { diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 18e80d5a..debcfd07 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -89,7 +89,7 @@ class BackoffMessages { if (!HasExtension(weights.backoff)) { weights.backoff = kExtensionBackoff; UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed."); - WriteOrThrow(unigrams, &weights, sizeof(weights)); + util::WriteOrThrow(unigrams, &weights, sizeof(weights)); } const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex)); base[write_to.array][write_to.index] += weights.backoff; @@ -180,7 +180,7 @@ const float kBadProb = std::numeric_limits<float>::infinity(); class SRISucks { public: SRISucks() { - for (BackoffMessages *i = messages_; i != messages_ + kMaxOrder - 1; ++i) + for (BackoffMessages *i = messages_; i != messages_ + KENLM_MAX_ORDER - 1; ++i) i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1)); } @@ -196,7 +196,7 @@ class SRISucks { } void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) { - for (unsigned char i = 0; i < kMaxOrder - 1; ++i) { + for (unsigned char i = 0; i < KENLM_MAX_ORDER - 1; ++i) { it_[i] = values_[i].empty() ? NULL : &*values_[i].begin(); } messages_[0].Apply(it_, unigram_file); @@ -221,10 +221,10 @@ class SRISucks { private: // This used to be one array. Then I needed to separate it by order for quantization to work. - std::vector<float> values_[kMaxOrder - 1]; - BackoffMessages messages_[kMaxOrder - 1]; + std::vector<float> values_[KENLM_MAX_ORDER - 1]; + BackoffMessages messages_[KENLM_MAX_ORDER - 1]; - float *it_[kMaxOrder - 1]; + float *it_[KENLM_MAX_ORDER - 1]; }; class FindBlanks { @@ -337,7 +337,7 @@ struct Gram { template <class Doing> class BlankManager { public: BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) { - for (float *i = basis_; i != basis_ + kMaxOrder - 1; ++i) *i = kBadProb; + for (float *i = basis_; i != basis_ + KENLM_MAX_ORDER - 1; ++i) *i = kBadProb; } void Visit(const WordIndex *to, unsigned char length, float prob) { @@ -373,10 +373,10 @@ template <class Doing> class BlankManager { private: const unsigned char total_order_; - WordIndex been_[kMaxOrder]; + WordIndex been_[KENLM_MAX_ORDER]; unsigned char been_length_; - float basis_[kMaxOrder]; + float basis_[KENLM_MAX_ORDER]; Doing &doing_; }; @@ -470,8 +470,8 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c } // namespace template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { - RecordReader inputs[kMaxOrder - 1]; - RecordReader contexts[kMaxOrder - 1]; + RecordReader inputs[KENLM_MAX_ORDER - 1]; + RecordReader contexts[KENLM_MAX_ORDER - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 10b22ab1..1264baf5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -44,8 +44,8 @@ template <class Quant, class Bhiksha> class TrieSearch { Bhiksha::UpdateConfigFromBinary(fd, config); } - static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { - std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); + static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { + uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); for (unsigned char i = 1; i < counts.size() - 1; ++i) { ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); } diff --git a/klm/lm/sri_test.cc b/klm/lm/sri_test.cc deleted file mode 100644 index e697d722..00000000 --- a/klm/lm/sri_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include "lm/sri.hh" - -#include <stdlib.h> - -#define BOOST_TEST_MODULE SRITest -#include <boost/test/unit_test.hpp> - -namespace lm { -namespace sri { -namespace { - -#define StartTest(word, ngram, score) \ - ret = model.FullScore( \ - state, \ - model.GetVocabulary().Index(word), \ - out);\ - BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ - BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ - BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); - -#define AppendTest(word, ngram, score) \ - StartTest(word, ngram, score) \ - state = out; - -template <class M> void Starters(M &model) { - FullScoreReturn ret; - Model::State state(model.BeginSentenceState()); - Model::State out; - - StartTest("looking", 2, -0.4846522); - - // , probability plus <s> backoff - StartTest(",", 1, -1.383514 + -0.4149733); - // <unk> probability plus <s> backoff - StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); -} - -template <class M> void Continuation(M &model) { - FullScoreReturn ret; - Model::State state(model.BeginSentenceState()); - Model::State out; - - AppendTest("looking", 2, -0.484652); - AppendTest("on", 3, -0.348837); - AppendTest("a", 4, -0.0155266); - AppendTest("little", 5, -0.00306122); - State preserve = state; - AppendTest("the", 1, -4.04005); - AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 0, -2.29666); - AppendTest("more", 1, -1.20632); - AppendTest(".", 2, -0.51363); - AppendTest("</s>", 3, -0.0191651); - - state = preserve; - AppendTest("more", 5, -0.00181395); - AppendTest("loin", 5, -0.0432557); -} - -BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); } -BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); } - -} // namespace -} // namespace sri -} // namespace lm diff --git a/klm/lm/state.hh b/klm/lm/state.hh index c7438414..551510a8 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -32,7 +32,7 @@ class State { // Call this before using raw memcmp. void ZeroRemaining() { - for (unsigned char i = length; i < kMaxOrder - 1; ++i) { + for (unsigned char i = length; i < KENLM_MAX_ORDER - 1; ++i) { words[i] = 0; backoff[i] = 0.0; } @@ -42,11 +42,13 @@ class State { // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. - WordIndex words[kMaxOrder - 1]; - float backoff[kMaxOrder - 1]; + WordIndex words[KENLM_MAX_ORDER - 1]; + float backoff[KENLM_MAX_ORDER - 1]; unsigned char length; }; +typedef State Right; + inline uint64_t hash_value(const State &state, uint64_t seed = 0) { return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed); } @@ -72,11 +74,11 @@ struct Left { } void ZeroRemaining() { - for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) + for (uint64_t * i = pointers + length; i < pointers + KENLM_MAX_ORDER - 1; ++i) *i = 0; } - uint64_t pointers[kMaxOrder - 1]; + uint64_t pointers[KENLM_MAX_ORDER - 1]; unsigned char length; bool full; }; diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 0f1ca574..d9895f89 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -36,7 +36,7 @@ bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_ } } // namespace -std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { +uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits; // Extra entry for next pointer at the end. // +7 then / 8 to round up bits and convert to bytes @@ -57,7 +57,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index eff93292..9ea3c546 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -11,7 +11,7 @@ namespace lm { namespace ngram { -class Config; +struct Config; namespace trie { struct NodeRange { @@ -49,7 +49,7 @@ class Unigram { unigram_ = static_cast<UnigramValue*>(start); } - static std::size_t Size(uint64_t count) { + static uint64_t Size(uint64_t count) { // +1 in case unknown doesn't appear. +1 for the final next. return (count + 2) * sizeof(UnigramValue); } @@ -84,7 +84,7 @@ class BitPacked { } protected: - static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); + static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); @@ -99,7 +99,7 @@ class BitPacked { template <class Bhiksha> class BitPackedMiddle : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); + static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); @@ -128,7 +128,7 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked { class BitPackedLongest : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { + static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { return BaseSize(entries, max_vocab, quant_bits); } diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index b80fed02..8663e94e 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -22,12 +22,6 @@ namespace lm { namespace ngram { namespace trie { - -void WriteOrThrow(FILE *to, const void *data, size_t size) { - assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); -} - namespace { typedef util::SizedIterator NGramIter; @@ -95,12 +89,12 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. if (context_begin == context_end) return out.release(); PartialIter i(context_begin); - WriteOrThrow(out.get(), i->Data(), context_size); + util::WriteOrThrow(out.get(), i->Data(), context_size); const void *previous = i->Data(); ++i; for (; i != context_end; ++i) { if (memcmp(previous, i->Data(), context_size)) { - WriteOrThrow(out.get(), i->Data(), context_size); + util::WriteOrThrow(out.get(), i->Data(), context_size); previous = i->Data(); } } @@ -116,7 +110,7 @@ struct ThrowCombine { // Useful for context files that just contain records with no value. struct FirstCombine { void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { - WriteOrThrow(out, first, entry_size); + util::WriteOrThrow(out, first, entry_size); } }; @@ -129,10 +123,10 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { - WriteOrThrow(out_file.get(), first.Data(), entry_size); + util::WriteOrThrow(out_file.get(), first.Data(), entry_size); ++first; } else if (less(second.Data(), first.Data())) { - WriteOrThrow(out_file.get(), second.Data(), entry_size); + util::WriteOrThrow(out_file.get(), second.Data(), entry_size); ++second; } else { combine(entry_size, first.Data(), second.Data(), out_file.get()); @@ -140,7 +134,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f } } for (RecordReader &remains = (first ? first : second); remains; ++remains) { - WriteOrThrow(out_file.get(), remains.Data(), entry_size); + util::WriteOrThrow(out_file.get(), remains.Data(), entry_size); } return out_file.release(); } @@ -148,19 +142,23 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f } // namespace void RecordReader::Init(FILE *file, std::size_t entry_size) { - rewind(file); - file_ = file; + entry_size_ = entry_size; data_.reset(malloc(entry_size)); UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer"); - remains_ = true; - entry_size_ = entry_size; - ++*this; + file_ = file; + if (file) { + rewind(file); + remains_ = true; + ++*this; + } else { + remains_ = false; + } } void RecordReader::Overwrite(const void *start, std::size_t amount) { long internal = (uint8_t*)start - (uint8_t*)data_.get(); UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); - WriteOrThrow(file_, start, amount); + util::WriteOrThrow(file_, start, amount); long forward = entry_size_ - internal - amount; #if !defined(_WIN32) && !defined(_WIN64) if (forward) @@ -169,9 +167,13 @@ void RecordReader::Overwrite(const void *start, std::size_t amount) { } void RecordReader::Rewind() { - rewind(file_); - remains_ = true; - ++*this; + if (file_) { + rewind(file_); + remains_ = true; + ++*this; + } else { + remains_ = false; + } } SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index 3036319d..2197b80c 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -25,12 +25,10 @@ namespace lm { class PositiveProbWarn; namespace ngram { class SortedVocabulary; -class Config; +struct Config; namespace trie { -void WriteOrThrow(FILE *to, const void *data, size_t size); - class EntryCompare : public std::binary_function<const void*, const void*, bool> { public: explicit EntryCompare(unsigned char order) : order_(order) {} @@ -107,7 +105,7 @@ class SortedFiles { util::scoped_fd unigram_; - util::scoped_FILE full_[kMaxOrder - 1], context_[kMaxOrder - 1]; + util::scoped_FILE full_[KENLM_MAX_ORDER - 1], context_[KENLM_MAX_ORDER - 1]; }; } // namespace trie diff --git a/klm/lm/value.hh b/klm/lm/value.hh index 85e53f14..ba716713 100644 --- a/klm/lm/value.hh +++ b/klm/lm/value.hh @@ -6,7 +6,7 @@ #include "lm/weights.hh" #include "util/bit_packing.hh" -#include <inttypes.h> +#include <stdint.h> namespace lm { namespace ngram { diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh index 687a41a0..461e6a5c 100644 --- a/klm/lm/value_build.hh +++ b/klm/lm/value_build.hh @@ -10,9 +10,9 @@ namespace lm { namespace ngram { -class Config; -class BackoffValue; -class RestValue; +struct Config; +struct BackoffValue; +struct RestValue; class NoRestBuild { public: diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 5de68f16..11c27518 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -80,14 +80,14 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { buffer_.push_back(0); } -void WriteWordsWrapper::Write(int fd) { - util::SeekEnd(fd); +void WriteWordsWrapper::Write(int fd, uint64_t start) { + util::SeekOrThrow(fd, start); util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); } SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} -std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { +uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { // Lead with the number of entries. return sizeof(uint64_t) + sizeof(uint64_t) * entries; } @@ -165,7 +165,7 @@ struct ProbingVocabularyHeader { ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} -std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { +uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) { return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index c3efcb4a..de54eb06 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -13,11 +13,11 @@ #include <vector> namespace lm { -class ProbBackoff; +struct ProbBackoff; class EnumerateVocab; namespace ngram { -class Config; +struct Config; namespace detail { uint64_t HashForVocab(const char *str, std::size_t len); @@ -35,7 +35,7 @@ class WriteWordsWrapper : public EnumerateVocab { void Add(WordIndex index, const StringPiece &str); - void Write(int fd); + void Write(int fd, uint64_t start); private: EnumerateVocab *inner_; @@ -62,7 +62,7 @@ class SortedVocabulary : public base::Vocabulary { } // Size for purposes of file writing - static size_t Size(std::size_t entries, const Config &config); + static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. WordIndex Bound() const { return bound_; } @@ -129,7 +129,7 @@ class ProbingVocabulary : public base::Vocabulary { return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } - static size_t Size(std::size_t entries, const Config &config); + static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()). WordIndex Bound() const { return bound_; } diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh index 67841c30..e09557a7 100644 --- a/klm/lm/word_index.hh +++ b/klm/lm/word_index.hh @@ -2,8 +2,11 @@ #ifndef LM_WORD_INDEX__ #define LM_WORD_INDEX__ +#include <limits.h> + namespace lm { typedef unsigned int WordIndex; +const WordIndex kMaxWordIndex = UINT_MAX; } // namespace lm typedef lm::WordIndex LMWordIndex; diff --git a/klm/search/Jamfile b/klm/search/Jamfile new file mode 100644 index 00000000..bc95c53a --- /dev/null +++ b/klm/search/Jamfile @@ -0,0 +1,5 @@ +lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ; + +import testing ; + +unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ; diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am new file mode 100644 index 00000000..ccc5b7f6 --- /dev/null +++ b/klm/search/Makefile.am @@ -0,0 +1,11 @@ +noinst_LIBRARIES = libksearch.a + +libksearch_a_SOURCES = \ + edge_generator.cc \ + rule.cc \ + vertex.cc \ + vertex_generator.cc \ + weights.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. + diff --git a/klm/search/config.hh b/klm/search/config.hh new file mode 100644 index 00000000..ef8e2354 --- /dev/null +++ b/klm/search/config.hh @@ -0,0 +1,25 @@ +#ifndef SEARCH_CONFIG__ +#define SEARCH_CONFIG__ + +#include "search/weights.hh" +#include "util/string_piece.hh" + +namespace search { + +class Config { + public: + Config(const Weights &weights, unsigned int pop_limit) : + weights_(weights), pop_limit_(pop_limit) {} + + const Weights &GetWeights() const { return weights_; } + + unsigned int PopLimit() const { return pop_limit_; } + + private: + Weights weights_; + unsigned int pop_limit_; +}; + +} // namespace search + +#endif // SEARCH_CONFIG__ diff --git a/klm/search/context.hh b/klm/search/context.hh new file mode 100644 index 00000000..62163144 --- /dev/null +++ b/klm/search/context.hh @@ -0,0 +1,65 @@ +#ifndef SEARCH_CONTEXT__ +#define SEARCH_CONTEXT__ + +#include "lm/model.hh" +#include "search/config.hh" +#include "search/final.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/exception.hh" +#include "util/pool.hh" + +#include <boost/pool/object_pool.hpp> +#include <boost/ptr_container/ptr_vector.hpp> + +#include <vector> + +namespace search { + +class Weights; + +class ContextBase { + public: + explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} + + util::Pool &FinalPool() { + return final_pool_; + } + + VertexNode *NewVertexNode() { + VertexNode *ret = vertex_node_pool_.construct(); + assert(ret); + return ret; + } + + void DeleteVertexNode(VertexNode *node) { + vertex_node_pool_.destroy(node); + } + + unsigned int PopLimit() const { return pop_limit_; } + + const Weights &GetWeights() const { return weights_; } + + private: + util::Pool final_pool_; + + boost::object_pool<VertexNode> vertex_node_pool_; + + unsigned int pop_limit_; + + const Weights &weights_; +}; + +template <class Model> class Context : public ContextBase { + public: + Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {} + + const Model &LanguageModel() const { return model_; } + + private: + const Model &model_; +}; + +} // namespace search + +#endif // SEARCH_CONTEXT__ diff --git a/klm/search/edge.hh b/klm/search/edge.hh new file mode 100644 index 00000000..187904bf --- /dev/null +++ b/klm/search/edge.hh @@ -0,0 +1,54 @@ +#ifndef SEARCH_EDGE__ +#define SEARCH_EDGE__ + +#include "lm/state.hh" +#include "search/header.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/pool.hh" + +#include <functional> + +#include <stdint.h> + +namespace search { + +// Copyable, but the copy will be shallow. +class PartialEdge : public Header { + public: + // Allow default construction for STL. + PartialEdge() {} + + PartialEdge(util::Pool &pool, Arity arity) + : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} + + PartialEdge(util::Pool &pool, Arity arity, Arity chart_states) + : Header(pool.Allocate(Size(arity, chart_states)), arity) {} + + // Non-terminals + const PartialVertex *NT() const { + return reinterpret_cast<const PartialVertex*>(After()); + } + PartialVertex *NT() { + return reinterpret_cast<PartialVertex*>(After()); + } + + const lm::ngram::ChartState &CompletedState() const { + return *Between(); + } + const lm::ngram::ChartState *Between() const { + return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); + } + lm::ngram::ChartState *Between() { + return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); + } + + private: + static std::size_t Size(Arity arity, Arity chart_states) { + return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); + } +}; + + +} // namespace search +#endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc new file mode 100644 index 00000000..260159b1 --- /dev/null +++ b/klm/search/edge_generator.cc @@ -0,0 +1,110 @@ +#include "search/edge_generator.hh" + +#include "lm/left.hh" +#include "lm/partial.hh" +#include "search/context.hh" +#include "search/vertex.hh" + +#include <numeric> + +namespace search { + +namespace { + +template <class Model> void FastScore(const Context<Model> &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) { + lm::ngram::ChartState *between = update.Between(); + lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + + float adjustment = 0.0; + const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); + const PartialVertex &update_nt = update.NT()[victim]; + const lm::ngram::ChartState &update_reveal = update_nt.State(); + if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { + adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + } + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { + adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + } + if (update_nt.Complete()) { + if (update_reveal.left.full) { + before->left.full = true; + } else { + assert(update_reveal.left.length == update_reveal.right.length); + adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); + } + before->right = after->right; + // Shift the others shifted one down, covering after. + for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) { + *cover = *(cover + 1); + } + } + update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); +} + +} // namespace + +template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) { + assert(!generate_.empty()); + PartialEdge top = generate_.top(); + generate_.pop(); + PartialVertex *const top_nt = top.NT(); + const Arity arity = top.GetArity(); + + Arity victim = 0; + Arity victim_completed; + Arity incomplete; + // Select victim or return if complete. + { + Arity completed = 0; + unsigned char lowest_length = 255; + for (Arity i = 0; i != arity; ++i) { + if (top_nt[i].Complete()) { + ++completed; + } else if (top_nt[i].Length() < lowest_length) { + lowest_length = top_nt[i].Length(); + victim = i; + victim_completed = completed; + } + } + if (lowest_length == 255) { + return top; + } + incomplete = arity - completed; + } + + PartialVertex old_value(top_nt[victim]); + PartialVertex alternate_changed; + if (top_nt[victim].Split(alternate_changed)) { + PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); + alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); + + alternate.SetNote(top.GetNote()); + + PartialVertex *alternate_nt = alternate.NT(); + for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i]; + alternate_nt[victim] = alternate_changed; + for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i]; + + memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1)); + + // TODO: dedupe? + generate_.push(alternate); + } + + // top is now the continuation. + FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); + // TODO: dedupe? + generate_.push(top); + + // Invalid indicates no new hypothesis generated. + return PartialEdge(); +} + +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context); + +} // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh new file mode 100644 index 00000000..582c78b7 --- /dev/null +++ b/klm/search/edge_generator.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_EDGE_GENERATOR__ +#define SEARCH_EDGE_GENERATOR__ + +#include "search/edge.hh" +#include "search/note.hh" +#include "search/types.hh" + +#include <queue> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +template <class Model> class Context; + +class EdgeGenerator { + public: + EdgeGenerator() {} + + PartialEdge AllocateEdge(Arity arity) { + return PartialEdge(partial_edge_pool_, arity); + } + + void AddEdge(PartialEdge edge) { + generate_.push(edge); + } + + bool Empty() const { return generate_.empty(); } + + // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge. + template <class Model> PartialEdge Pop(Context<Model> &context); + + template <class Model, class Output> void Search(Context<Model> &context, Output &output) { + unsigned to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + PartialEdge got(Pop(context)); + if (got.Valid()) { + output.NewHypothesis(got); + --to_pop; + } + } + output.FinishedSearch(); + } + + private: + util::Pool partial_edge_pool_; + + typedef std::priority_queue<PartialEdge> Generate; + Generate generate_; +}; + +} // namespace search +#endif // SEARCH_EDGE_GENERATOR__ diff --git a/klm/search/final.hh b/klm/search/final.hh new file mode 100644 index 00000000..50e62cf2 --- /dev/null +++ b/klm/search/final.hh @@ -0,0 +1,36 @@ +#ifndef SEARCH_FINAL__ +#define SEARCH_FINAL__ + +#include "search/header.hh" +#include "util/pool.hh" + +namespace search { + +// A full hypothesis with pointers to children. +class Final : public Header { + public: + Final() {} + + Final(util::Pool &pool, Score score, Arity arity, Note note) + : Header(pool.Allocate(Size(arity)), arity) { + SetScore(score); + SetNote(note); + } + + // These are arrays of length GetArity(). + Final *Children() { + return reinterpret_cast<Final*>(After()); + } + const Final *Children() const { + return reinterpret_cast<const Final*>(After()); + } + + private: + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Final); + } +}; + +} // namespace search + +#endif // SEARCH_FINAL__ diff --git a/klm/search/header.hh b/klm/search/header.hh new file mode 100644 index 00000000..25550dbe --- /dev/null +++ b/klm/search/header.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/note.hh" +#include "search/types.hh" + +#include <stdint.h> + +namespace search { + +// Copying is shallow. +class Header { + public: + bool Valid() const { return base_; } + + Score GetScore() const { + return *reinterpret_cast<const float*>(base_); + } + void SetScore(Score to) { + *reinterpret_cast<float*>(base_) = to; + } + bool operator<(const Header &other) const { + return GetScore() < other.GetScore(); + } + + Arity GetArity() const { + return *reinterpret_cast<const Arity*>(base_ + sizeof(Score)); + } + + Note GetNote() const { + return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity)); + } + void SetNote(Note to) { + *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to; + } + + protected: + Header() : base_(NULL) {} + + Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) { + *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity; + } + + static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + + uint8_t *After() { return base_ + kHeaderSize; } + const uint8_t *After() const { return base_ + kHeaderSize; } + + private: + uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/klm/search/note.hh b/klm/search/note.hh new file mode 100644 index 00000000..50bed06e --- /dev/null +++ b/klm/search/note.hh @@ -0,0 +1,12 @@ +#ifndef SEARCH_NOTE__ +#define SEARCH_NOTE__ + +namespace search { + +union Note { + const void *vp; +}; + +} // namespace search + +#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc new file mode 100644 index 00000000..5b00207e --- /dev/null +++ b/klm/search/rule.cc @@ -0,0 +1,43 @@ +#include "search/rule.hh" + +#include "search/context.hh" +#include "search/final.hh" + +#include <ostream> + +#include <math.h> + +namespace search { + +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) { + unsigned int oov_count = 0; + float prob = 0.0; + const Model &model = context.LanguageModel(); + const lm::WordIndex oov = model.GetVocabulary().NotFound(); + for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { + lm::ngram::RuleScore<Model> scorer(model, *(writing++)); + // TODO: optimize + if (prepend_bos && (word == words.begin())) { + scorer.BeginSentence(); + } + for (; ; ++word) { + if (word == words.end()) { + prob += scorer.Finish(); + return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); + } + if (*word == kNonTerminal) break; + if (*word == oov) ++oov_count; + scorer.Terminal(*word); + } + prob += scorer.Finish(); + } +} + +template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); + +} // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh new file mode 100644 index 00000000..0ce2794d --- /dev/null +++ b/klm/search/rule.hh @@ -0,0 +1,20 @@ +#ifndef SEARCH_RULE__ +#define SEARCH_RULE__ + +#include "lm/left.hh" +#include "lm/word_index.hh" +#include "search/types.hh" + +#include <vector> + +namespace search { + +template <class Model> class Context; + +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; + +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out); + +} // namespace search + +#endif // SEARCH_RULE__ diff --git a/klm/search/types.hh b/klm/search/types.hh new file mode 100644 index 00000000..06eb5bfa --- /dev/null +++ b/klm/search/types.hh @@ -0,0 +1,14 @@ +#ifndef SEARCH_TYPES__ +#define SEARCH_TYPES__ + +#include <stdint.h> + +namespace search { + +typedef float Score; + +typedef uint32_t Arity; + +} // namespace search + +#endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc new file mode 100644 index 00000000..11f4631f --- /dev/null +++ b/klm/search/vertex.cc @@ -0,0 +1,42 @@ +#include "search/vertex.hh" + +#include "search/context.hh" + +#include <algorithm> +#include <functional> + +#include <assert.h> + +namespace search { + +namespace { + +struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { + bool operator()(const VertexNode *first, const VertexNode *second) const { + return first->Bound() > second->Bound(); + } +}; + +} // namespace + +void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { + if (Complete()) { + assert(end_.Valid()); + assert(extend_.empty()); + bound_ = end_.GetScore(); + return; + } + if (extend_.size() == 1 && parent_ptr) { + *parent_ptr = extend_[0]; + extend_[0]->SortAndSet(context, parent_ptr); + context.DeleteVertexNode(this); + return; + } + for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->SortAndSet(context, &*i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +} // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh new file mode 100644 index 00000000..52bc1dfe --- /dev/null +++ b/klm/search/vertex.hh @@ -0,0 +1,159 @@ +#ifndef SEARCH_VERTEX__ +#define SEARCH_VERTEX__ + +#include "lm/left.hh" +#include "search/final.hh" +#include "search/types.hh" + +#include <boost/unordered_set.hpp> + +#include <queue> +#include <vector> + +#include <stdint.h> + +namespace search { + +class ContextBase; + +class VertexNode { + public: + VertexNode() {} + + void InitRoot() { + extend_.clear(); + state_.left.full = false; + state_.left.length = 0; + state_.right.length = 0; + right_full_ = false; + end_ = Final(); + } + + lm::ngram::ChartState &MutableState() { return state_; } + bool &MutableRightFull() { return right_full_; } + + void AddExtend(VertexNode *next) { + extend_.push_back(next); + } + + void SetEnd(Final end) { + assert(!end_.Valid()); + end_ = end; + } + + void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + + // Should only happen to a root node when the entire vertex is empty. + bool Empty() const { + return !end_.Valid() && extend_.empty(); + } + + bool Complete() const { + return end_.Valid(); + } + + const lm::ngram::ChartState &State() const { return state_; } + bool RightFull() const { return right_full_; } + + Score Bound() const { + return bound_; + } + + unsigned char Length() const { + return state_.left.length + state_.right.length; + } + + // Will be invalid unless this is a leaf. + const Final End() const { return end_; } + + const VertexNode &operator[](size_t index) const { + return *extend_[index]; + } + + size_t Size() const { + return extend_.size(); + } + + private: + std::vector<VertexNode*> extend_; + + lm::ngram::ChartState state_; + bool right_full_; + + Score bound_; + Final end_; +}; + +class PartialVertex { + public: + PartialVertex() {} + + explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + + bool Empty() const { return back_->Empty(); } + + bool Complete() const { return back_->Complete(); } + + const lm::ngram::ChartState &State() const { return back_->State(); } + bool RightFull() const { return back_->RightFull(); } + + Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } + + unsigned char Length() const { return back_->Length(); } + + bool HasAlternative() const { + return index_ + 1 < back_->Size(); + } + + // Split into continuation and alternative, rendering this the continuation. + bool Split(PartialVertex &alternative) { + assert(!Complete()); + bool ret; + if (index_ + 1 < back_->Size()) { + alternative.index_ = index_ + 1; + alternative.back_ = back_; + ret = true; + } else { + ret = false; + } + back_ = &((*back_)[index_]); + index_ = 0; + return ret; + } + + const Final End() const { + return back_->End(); + } + + private: + const VertexNode *back_; + unsigned int index_; +}; + +class Vertex { + public: + Vertex() {} + + PartialVertex RootPartial() const { return PartialVertex(root_); } + + const Final BestChild() const { + PartialVertex top(RootPartial()); + if (top.Empty()) { + return Final(); + } else { + PartialVertex continuation; + while (!top.Complete()) { + top.Split(continuation); + } + return top.End(); + } + } + + private: + friend class VertexGenerator; + + VertexNode root_; +}; + +} // namespace search +#endif // SEARCH_VERTEX__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc new file mode 100644 index 00000000..0945fe55 --- /dev/null +++ b/klm/search/vertex_generator.cc @@ -0,0 +1,94 @@ +#include "search/vertex_generator.hh" + +#include "lm/left.hh" +#include "search/context.hh" +#include "search/edge.hh" + +#include <stdint.h> + +namespace search { + +VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { + gen.root_.InitRoot(); +} + +namespace { + +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); + +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map<uint64_t, Trie> extend; +}; + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); + } + return next; +} + +void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { + Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); + Final *child_out = final.Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = part->End(); + + starter.under->SetEnd(final); +} + +void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + + unsigned char left = 0, right = 0; + Trie *node = &root; + while (true) { + if (left == state.left.length) { + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); + for (; right < state.right.length; ++right) { + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); + } + break; + } + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); + left++; + if (right == state.right.length) { + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); + for (; left < state.left.length; ++left) { + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); + } + break; + } + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); + right++; + } + + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + CompleteTransition(context, *node, partial); +} + +} // namespace + +void VertexGenerator::FinishedSearch() { + Trie root; + root.under = &gen_.root_; + for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, i->second); + } + root.under->SortAndSet(context_, NULL); +} + +} // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh new file mode 100644 index 00000000..60e86112 --- /dev/null +++ b/klm/search/vertex_generator.hh @@ -0,0 +1,46 @@ +#ifndef SEARCH_VERTEX_GENERATOR__ +#define SEARCH_VERTEX_GENERATOR__ + +#include "search/edge.hh" +#include "search/vertex.hh" + +#include <boost/unordered_map.hpp> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +class ContextBase; +class Final; + +class VertexGenerator { + public: + VertexGenerator(ContextBase &context, Vertex &gen); + + void NewHypothesis(PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial))); + if (!ret.second && ret.first->second < partial) { + ret.first->second = partial; + } + } + + void FinishedSearch(); + + const Vertex &Generating() const { return gen_; } + + private: + ContextBase &context_; + + Vertex &gen_; + + typedef boost::unordered_map<uint64_t, PartialEdge> Existing; + Existing existing_; +}; + +} // namespace search +#endif // SEARCH_VERTEX_GENERATOR__ diff --git a/klm/search/weights.cc b/klm/search/weights.cc new file mode 100644 index 00000000..d65471ad --- /dev/null +++ b/klm/search/weights.cc @@ -0,0 +1,71 @@ +#include "search/weights.hh" +#include "util/tokenize_piece.hh" + +#include <cstdlib> + +namespace search { + +namespace { +struct Insert { + void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const { + std::string copy(name.data(), name.size()); + map[copy] = score; + } +}; + +struct DotProduct { + search::Score total; + DotProduct() : total(0.0) {} + + void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) { + boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name)); + if (i != map.end()) + total += score * i->second; + } +}; + +template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) { + for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) { + util::TokenIter<util::SingleCharacter> equals(*spaces, '='); + UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); + StringPiece name(*equals); + UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); + char *end; + // Assumes proper termination. + double value = std::strtod(equals->data(), &end); + UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); + UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); + op(map, name, value); + } +} + +} // namespace + +Weights::Weights(StringPiece text) { + Insert op; + Parse<Map, Insert>(text, map_, op); + lm_ = Steal("LanguageModel"); + oov_ = Steal("OOV"); + word_penalty_ = Steal("WordPenalty"); +} + +Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} + +search::Score Weights::DotNoLM(StringPiece text) const { + DotProduct dot; + Parse<const Map, DotProduct>(text, map_, dot); + return dot.total; +} + +float Weights::Steal(const std::string &str) { + Map::iterator i(map_.find(str)); + if (i == map_.end()) { + return 0.0; + } else { + float ret = i->second; + map_.erase(i); + return ret; + } +} + +} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh new file mode 100644 index 00000000..df1c419f --- /dev/null +++ b/klm/search/weights.hh @@ -0,0 +1,52 @@ +// For now, the individual features are not kept. +#ifndef SEARCH_WEIGHTS__ +#define SEARCH_WEIGHTS__ + +#include "search/types.hh" +#include "util/exception.hh" +#include "util/string_piece.hh" + +#include <boost/unordered_map.hpp> + +#include <string> + +namespace search { + +class WeightParseException : public util::Exception { + public: + WeightParseException() {} + ~WeightParseException() throw() {} +}; + +class Weights { + public: + // Parses weights, sets lm_weight_, removes it from map_. + explicit Weights(StringPiece text); + + // Just the three scores we care about adding. + Weights(Score lm, Score oov, Score word_penalty); + + Score DotNoLM(StringPiece text) const; + + Score LM() const { return lm_; } + + Score OOV() const { return oov_; } + + Score WordPenalty() const { return word_penalty_; } + + // Mostly for testing. + const boost::unordered_map<std::string, Score> &GetMap() const { return map_; } + + private: + float Steal(const std::string &str); + + typedef boost::unordered_map<std::string, Score> Map; + + Map map_; + + Score lm_, oov_, word_penalty_; +}; + +} // namespace search + +#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc new file mode 100644 index 00000000..4811ff06 --- /dev/null +++ b/klm/search/weights_test.cc @@ -0,0 +1,38 @@ +#include "search/weights.hh" + +#define BOOST_TEST_MODULE WeightTest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> + +namespace search { +namespace { + +#define CHECK_WEIGHT(value, string) \ + i = parsed.find(string); \ + BOOST_REQUIRE(i != parsed.end()); \ + BOOST_CHECK_CLOSE((value), i->second, 0.001); + +BOOST_AUTO_TEST_CASE(parse) { + // These are not real feature weights. + Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); + const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap(); + boost::unordered_map<std::string, search::Score>::const_iterator i; + CHECK_WEIGHT(0.0, "rarity"); + CHECK_WEIGHT(0.0, "phrase-SGT"); + CHECK_WEIGHT(9.45117, "phrase-TGS"); + CHECK_WEIGHT(2.33833, "lexical-SGT"); + BOOST_CHECK(parsed.end() == parsed.find("lm")); + BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); + CHECK_WEIGHT(-28.3317, "lexical-TGS"); + CHECK_WEIGHT(5.0, "glue?"); +} + +BOOST_AUTO_TEST_CASE(dot) { + Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); + BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); + BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); + BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); +} + +} // namespace +} // namespace search diff --git a/klm/util/Jamfile b/klm/util/Jamfile deleted file mode 100644 index 3ee2c2c2..00000000 --- a/klm/util/Jamfile +++ /dev/null @@ -1,10 +0,0 @@ -lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc usage.cc ../..//z : <include>.. : : <include>.. ; - -import testing ; - -unit-test bit_packing_test : bit_packing_test.cc kenutil ../..///boost_unit_test_framework ; -run file_piece_test.cc kenutil ../..///boost_unit_test_framework : : file_piece.cc ; -unit-test joint_sort_test : joint_sort_test.cc kenutil ../..///boost_unit_test_framework ; -unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil ../..///boost_unit_test_framework ; -unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil ../..///boost_unit_test_framework ; -unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil ../..///boost_unit_test_framework ; diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5ceccf2c..5306850f 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -26,6 +26,8 @@ libklm_util_a_SOURCES = \ file_piece.cc \ mmap.cc \ murmur_hash.cc \ + pool.cc \ + string_piece.cc \ usage.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc index 07b14e26..eb635ad8 100644 --- a/klm/util/ersatz_progress.cc +++ b/klm/util/ersatz_progress.cc @@ -9,16 +9,16 @@ namespace util { namespace { const unsigned char kWidth = 100; } -ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::size_t>::max()), complete_(next_), out_(NULL) {} +ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<uint64_t>::max()), complete_(next_), out_(NULL) {} ErsatzProgress::~ErsatzProgress() { if (out_) Finished(); } -ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std::string &message) +ErsatzProgress::ErsatzProgress(uint64_t complete, std::ostream *to, const std::string &message) : current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) { if (!out_) { - next_ = std::numeric_limits<std::size_t>::max(); + next_ = std::numeric_limits<uint64_t>::max(); return; } if (!message.empty()) *out_ << message << '\n'; @@ -28,14 +28,14 @@ ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std void ErsatzProgress::Milestone() { if (!out_) { current_ = 0; return; } if (!complete_) return; - unsigned char stone = std::min(static_cast<std::size_t>(kWidth), (current_ * kWidth) / complete_); + unsigned char stone = std::min(static_cast<uint64_t>(kWidth), (current_ * kWidth) / complete_); for (; stones_written_ < stone; ++stones_written_) { (*out_) << '*'; } if (stone == kWidth) { (*out_) << std::endl; - next_ = std::numeric_limits<std::size_t>::max(); + next_ = std::numeric_limits<uint64_t>::max(); out_ = NULL; } else { next_ = std::max(next_, (stone * complete_) / kWidth); diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index f709dc51..9909736d 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -4,6 +4,8 @@ #include <iostream> #include <string> +#include <stdint.h> + // Ersatz version of boost::progress so core language model doesn't depend on // boost. Also adds option to print nothing. @@ -14,7 +16,7 @@ class ErsatzProgress { ErsatzProgress(); // Null means no output. The null value is useful for passing along the ostream pointer from another caller. - explicit ErsatzProgress(std::size_t complete, std::ostream *to = &std::cerr, const std::string &message = ""); + explicit ErsatzProgress(uint64_t complete, std::ostream *to = &std::cerr, const std::string &message = ""); ~ErsatzProgress(); @@ -23,12 +25,12 @@ class ErsatzProgress { return *this; } - ErsatzProgress &operator+=(std::size_t amount) { + ErsatzProgress &operator+=(uint64_t amount) { if ((current_ += amount) >= next_) Milestone(); return *this; } - void Set(std::size_t to) { + void Set(uint64_t to) { if ((current_ = to) >= next_) Milestone(); Milestone(); } @@ -40,7 +42,7 @@ class ErsatzProgress { private: void Milestone(); - std::size_t current_, next_, complete_; + uint64_t current_, next_, complete_; unsigned char stones_written_; std::ostream *out_; diff --git a/klm/util/exception.cc b/klm/util/exception.cc index c4f8c04c..3806e6de 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -84,4 +84,7 @@ EndOfFileException::EndOfFileException() throw() { } EndOfFileException::~EndOfFileException() throw() {} +OverflowException::OverflowException() throw() {} +OverflowException::~OverflowException() throw() {} + } // namespace util diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 6d6a37cb..053a850b 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -2,9 +2,12 @@ #define UTIL_EXCEPTION__ #include <exception> +#include <limits> #include <sstream> #include <string> +#include <stdint.h> + namespace util { template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data); @@ -111,6 +114,25 @@ class EndOfFileException : public Exception { ~EndOfFileException() throw(); }; +class OverflowException : public Exception { + public: + OverflowException() throw(); + ~OverflowException() throw(); +}; + +template <unsigned len> inline std::size_t CheckOverflowInternal(uint64_t value) { + UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code."); + return value; +} + +template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) { + return value; +} + +inline std::size_t CheckOverflow(uint64_t value) { + return CheckOverflowInternal<sizeof(std::size_t)>(value); +} + } // namespace util #endif // UTIL_EXCEPTION__ diff --git a/klm/util/file.cc b/klm/util/file.cc index 6a3885a7..6bf879ac 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -6,6 +6,7 @@ #include <cstdio> #include <iostream> +#include <assert.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> @@ -44,6 +45,16 @@ int OpenReadOrThrow(const char *name) { return ret; } +int CreateOrThrow(const char *name) { + int ret; +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); +#else + UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name); +#endif + return ret; +} + uint64_t SizeFile(int fd) { #if defined(_WIN32) || defined(_WIN64) __int64 ret = _filelengthi64(fd); @@ -101,6 +112,11 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) { } } +void WriteOrThrow(FILE *to, const void *data, std::size_t size) { + assert(size); + if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +} + void FSyncOrThrow(int fd) { // Apparently windows doesn't have fsync? #if !defined(_WIN32) && !defined(_WIN64) @@ -109,8 +125,13 @@ void FSyncOrThrow(int fd) { } namespace { -void InternalSeek(int fd, off_t off, int whence) { +void InternalSeek(int fd, int64_t off, int whence) { +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW_IF((__int64)-1 == _lseeki64(fd, off, whence), ErrnoException, "Windows seek failed"); + +#else UTIL_THROW_IF((off_t)-1 == lseek(fd, off, whence), ErrnoException, "Seek failed"); +#endif } } // namespace @@ -133,6 +154,12 @@ std::FILE *FDOpenOrThrow(scoped_fd &file) { return ret; } +std::FILE *FOpenOrThrow(const char *path, const char *mode) { + std::FILE *ret; + UTIL_THROW_IF(!(ret = fopen(path, mode)), util::ErrnoException, "Could not fopen " << path << " for " << mode); + return ret; +} + TempMaker::TempMaker(const std::string &prefix) : base_(prefix) { base_ += "XXXXXX"; } @@ -232,7 +259,9 @@ mkstemp_and_unlink(char *tmpl) /* Modified for windows and to unlink */ // fd = open (tmpl, O_RDWR | O_CREAT | O_EXCL, _S_IREAD | _S_IWRITE); - fd = _open (tmpl, _O_RDWR | _O_CREAT | _O_TEMPORARY | _O_EXCL | _O_BINARY, _S_IREAD | _S_IWRITE); + int flags = _O_RDWR | _O_CREAT | _O_EXCL | _O_BINARY; + flags |= _O_TEMPORARY; + fd = _open (tmpl, flags, _S_IREAD | _S_IWRITE); if (fd >= 0) { errno = save_errno; @@ -250,17 +279,18 @@ mkstemp_and_unlink(char *tmpl) int mkstemp_and_unlink(char *tmpl) { int ret = mkstemp(tmpl); - if (ret == -1) return -1; - UTIL_THROW_IF(unlink(tmpl), util::ErrnoException, "Failed to delete " << tmpl); + if (ret != -1) { + UTIL_THROW_IF(unlink(tmpl), util::ErrnoException, "Failed to delete " << tmpl); + } return ret; } #endif int TempMaker::Make() const { - std::string copy(base_); - copy.push_back(0); + std::string name(base_); + name.push_back(0); int ret; - UTIL_THROW_IF(-1 == (ret = mkstemp_and_unlink(©[0])), util::ErrnoException, "Failed to make a temporary based on " << base_); + UTIL_THROW_IF(-1 == (ret = mkstemp_and_unlink(&name[0])), util::ErrnoException, "Failed to make a temporary based on " << base_); return ret; } diff --git a/klm/util/file.hh b/klm/util/file.hh index 5c57e2a9..185cb1f3 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -65,7 +65,10 @@ class scoped_FILE { std::FILE *file_; }; +// Open for read only. int OpenReadOrThrow(const char *name); +// Create file if it doesn't exist, truncate if it does. Opened for write. +int CreateOrThrow(const char *name); // Return value for SizeFile when it can't size properly. const uint64_t kBadSize = (uint64_t)-1; @@ -77,6 +80,7 @@ void ReadOrThrow(int fd, void *to, std::size_t size); std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount); void WriteOrThrow(int fd, const void *data_void, std::size_t size); +void WriteOrThrow(FILE *to, const void *data, std::size_t size); void FSyncOrThrow(int fd); @@ -87,12 +91,14 @@ void SeekEnd(int fd); std::FILE *FDOpenOrThrow(scoped_fd &file); +std::FILE *FOpenOrThrow(const char *path, const char *mode); + class TempMaker { public: explicit TempMaker(const std::string &prefix); + // These will already be unlinked for you. int Make() const; - std::FILE *MakeFile() const; private: diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index a205995a..280f438c 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -5,6 +5,8 @@ #include "util/mmap.hh" #ifdef WIN32 #include <io.h> +#else +#include <unistd.h> #endif // WIN32 #include <iostream> @@ -27,7 +29,7 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() { #ifdef HAVE_ZLIB GZException::GZException(gzFile file) { int num; - *this << gzerror( file, &num) << " from zlib"; + *this << gzerror(file, &num) << " from zlib"; } #endif // HAVE_ZLIB diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index 576fd4cc..bc9e3f81 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -19,8 +19,8 @@ #include <windows.h> #include <io.h> #else -#include <unistd.h> #include <sys/mman.h> +#include <unistd.h> #endif namespace util { @@ -171,20 +171,6 @@ void *MapZeroedWrite(int fd, std::size_t size) { return MapOrThrow(size, true, kFileFlags, false, fd, 0); } -namespace { - -int CreateOrThrow(const char *name) { - int ret; -#if defined(_WIN32) || defined(_WIN64) - UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); -#else - UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name); -#endif - return ret; -} - -} // namespace - void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) { file.reset(CreateOrThrow(name)); try { diff --git a/klm/util/pool.cc b/klm/util/pool.cc new file mode 100644 index 00000000..2dffd06f --- /dev/null +++ b/klm/util/pool.cc @@ -0,0 +1,35 @@ +#include "util/pool.hh" + +#include <stdlib.h> + +namespace util { + +Pool::Pool() { + current_ = NULL; + current_end_ = NULL; +} + +Pool::~Pool() { + FreeAll(); +} + +void Pool::FreeAll() { + for (std::vector<void *>::const_iterator i(free_list_.begin()); i != free_list_.end(); ++i) { + free(*i); + } + free_list_.clear(); + current_ = NULL; + current_end_ = NULL; +} + +void *Pool::More(std::size_t size) { + std::size_t amount = std::max(static_cast<size_t>(32) << free_list_.size(), size); + uint8_t *ret = static_cast<uint8_t*>(malloc(amount)); + if (!ret) throw std::bad_alloc(); + free_list_.push_back(ret); + current_ = ret + size; + current_end_ = ret + amount; + return ret; +} + +} // namespace util diff --git a/klm/util/pool.hh b/klm/util/pool.hh new file mode 100644 index 00000000..72f8a0c8 --- /dev/null +++ b/klm/util/pool.hh @@ -0,0 +1,45 @@ +// Very simple pool. It can only allocate memory. And all of the memory it +// allocates must be freed at the same time. + +#ifndef UTIL_POOL__ +#define UTIL_POOL__ + +#include <vector> + +#include <stdint.h> + +namespace util { + +class Pool { + public: + Pool(); + + ~Pool(); + + void *Allocate(std::size_t size) { + void *ret = current_; + current_ += size; + if (current_ < current_end_) { + return ret; + } else { + return More(size); + } + } + + void FreeAll(); + + private: + void *More(std::size_t size); + + std::vector<void *> free_list_; + + uint8_t *current_, *current_end_; + + // no copying + Pool(const Pool &); + Pool &operator=(const Pool &); +}; + +} // namespace util + +#endif // UTIL_POOL__ diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 3354b68e..4a8aff35 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -8,6 +8,7 @@ #include <functional> #include <assert.h> +#include <stdint.h> namespace util { @@ -42,8 +43,8 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry typedef EqualT Equal; public: - static std::size_t Size(std::size_t entries, float multiplier) { - std::size_t buckets = std::max(entries + 1, static_cast<std::size_t>(multiplier * static_cast<float>(entries))); + static uint64_t Size(uint64_t entries, float multiplier) { + uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries))); return buckets * sizeof(Entry); } diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc new file mode 100644 index 00000000..b422cefc --- /dev/null +++ b/klm/util/string_piece.cc @@ -0,0 +1,192 @@ +// Copyright 2004 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in string_piece.hh. + +#include "util/string_piece.hh" + +#include <algorithm> + +#include <limits.h> + +#ifndef HAVE_ICU + +typedef StringPiece::size_type size_type; + +void StringPiece::CopyToString(std::string* target) const { + target->assign(ptr_, length_); +} + +size_type StringPiece::find(const StringPiece& s, size_type pos) const { + if (length_ < 0 || pos > static_cast<size_type>(length_)) + return npos; + + const char* result = std::search(ptr_ + pos, ptr_ + length_, + s.ptr_, s.ptr_ + s.length_); + const size_type xpos = result - ptr_; + return xpos + s.length_ <= length_ ? xpos : npos; +} + +size_type StringPiece::find(char c, size_type pos) const { + if (length_ <= 0 || pos >= static_cast<size_type>(length_)) { + return npos; + } + const char* result = std::find(ptr_ + pos, ptr_ + length_, c); + return result != ptr_ + length_ ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(const StringPiece& s, size_type pos) const { + if (length_ < s.length_) return npos; + const size_t ulen = length_; + if (s.length_ == 0) return std::min(ulen, pos); + + const char* last = ptr_ + std::min(ulen - s.length_, pos) + s.length_; + const char* result = std::find_end(ptr_, last, s.ptr_, s.ptr_ + s.length_); + return result != last ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(char c, size_type pos) const { + if (length_ <= 0) return npos; + for (int i = std::min(pos, static_cast<size_type>(length_ - 1)); + i >= 0; --i) { + if (ptr_[i] == c) { + return i; + } + } + return npos; +} + +// For each character in characters_wanted, sets the index corresponding +// to the ASCII code of that character to 1 in table. This is used by +// the find_.*_of methods below to tell whether or not a character is in +// the lookup table in constant time. +// The argument `table' must be an array that is large enough to hold all +// the possible values of an unsigned char. Thus it should be be declared +// as follows: +// bool table[UCHAR_MAX + 1] +static inline void BuildLookupTable(const StringPiece& characters_wanted, + bool* table) { + const size_type length = characters_wanted.length(); + const char* const data = characters_wanted.data(); + for (size_type i = 0; i < length; ++i) { + table[static_cast<unsigned char>(data[i])] = true; + } +} + +size_type StringPiece::find_first_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (lookup[static_cast<unsigned char>(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + if (s.length_ == 0) + return 0; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (!lookup[static_cast<unsigned char>(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (; pos < length_; ++pos) { + if (ptr_[pos] != c) { + return pos; + } + } + return npos; +} + +size_type StringPiece::find_last_of(const StringPiece& s, size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (lookup[static_cast<unsigned char>(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + size_type i = std::min(pos, length_ - 1); + if (s.length_ == 0) + return i; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (; ; --i) { + if (!lookup[static_cast<unsigned char>(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (ptr_[i] != c) + return i; + if (i == 0) + break; + } + return npos; +} + +StringPiece StringPiece::substr(size_type pos, size_type n) const { + if (pos > length_) pos = length_; + if (n > length_ - pos) n = length_ - pos; + return StringPiece(ptr_ + pos, n); +} + +const size_type StringPiece::npos = size_type(-1); + +#endif // !HAVE_ICU diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index 5de053aa..be6a643d 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -85,6 +85,11 @@ U_NAMESPACE_BEGIN #include <string> #include <string.h> +#ifdef WIN32 +#undef max +#undef min +#endif + class StringPiece { public: typedef size_t size_type; diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index c7e1c863..4a7f5460 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -54,6 +54,18 @@ class AnyCharacter { StringPiece chars_; }; +class AnyCharacterLast { + public: + explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {} + + StringPiece Find(const StringPiece &in) const { + return StringPiece(std::find_end(in.data(), in.data() + in.size(), chars_.data(), chars_.data() + chars_.size()), 1); + } + + private: + StringPiece chars_; +}; + template <class Find, bool SkipEmpty = false> class TokenIter : public boost::iterator_facade<TokenIter<Find, SkipEmpty>, const StringPiece, boost::forward_traversal_tag> { public: TokenIter() {} |