From d04c0ca2d9df0e147239b18e90650ca8bd51d594 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 18 Jan 2011 15:55:40 -0500 Subject: new version of klm --- klm/lm/binary_format.cc | 73 ++++++---- klm/lm/binary_format.hh | 22 ++- klm/lm/blank.hh | 13 ++ klm/lm/build_binary.cc | 10 +- klm/lm/config.hh | 4 +- klm/lm/model.cc | 117 ++++++++------- klm/lm/model.hh | 4 +- klm/lm/model_test.cc | 74 +++++++--- klm/lm/ngram_query.cc | 45 +++++- klm/lm/search_hashed.cc | 64 +++++---- klm/lm/search_hashed.hh | 7 +- klm/lm/search_trie.cc | 369 +++++++++++++++++++++++++++++++++++------------- klm/lm/search_trie.hh | 11 +- klm/lm/trie.cc | 13 +- klm/lm/trie.hh | 7 +- klm/lm/vocab.cc | 14 +- klm/lm/vocab.hh | 7 +- 17 files changed, 582 insertions(+), 272 deletions(-) create mode 100644 klm/lm/blank.hh (limited to 'klm/lm') diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 69a06355..3d9700da 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -18,8 +18,8 @@ namespace lm { namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0"; -const long int kMagicVersion = 1; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 3\n\0"; +const long int kMagicVersion = 2; // Test values. struct Sanity { @@ -76,6 +76,45 @@ void WriteHeader(void *to, const Parameters ¶ms) { } } // namespace + +uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { + if (config.write_mmap) { + std::size_t total = TotalHeaderSize(order) + memory_size; + backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED); + return reinterpret_cast(backing.vocab.get()) + TotalHeaderSize(order); + } else { + backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + return reinterpret_cast(backing.vocab.get()); + } +} + +uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector &counts, std::size_t memory_size, Backing &backing) { + if (config.write_mmap) { + // header and vocab share the same mmap. The header is written here because we know the counts. + Parameters params; + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + WriteHeader(backing.vocab.get(), params); + + // Grow the file to accomodate the search, using zeros. + if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size)) + UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed"); + + // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. + off_t page_size = sysconf(_SC_PAGE_SIZE); + off_t alignment_cruft = backing.vocab.size() % page_size; + backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); + + return reinterpret_cast(backing.search.get()) + alignment_cruft; + } else { + backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + return reinterpret_cast(backing.search.get()); + } +} + namespace detail { bool IsBinaryFormat(int fd) { @@ -128,7 +167,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t if (file_size != util::kBadSize && static_cast(file_size) < total_map) UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); - util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory); + util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search); if (config.enumerate_vocab && !params.fixed.has_vocabulary) UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); @@ -137,33 +176,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words"); } - return reinterpret_cast(backing.memory.get()) + TotalHeaderSize(params.counts.size()); -} - -uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector &counts, std::size_t memory_size, Backing &backing) { - if (config.write_mmap) { - std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size; - // Write out an mmap file. - backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED); - - Parameters params; - params.counts = counts; - params.fixed.order = counts.size(); - params.fixed.probing_multiplier = config.probing_multiplier; - params.fixed.model_type = model_type; - params.fixed.has_vocabulary = config.include_vocab; - - WriteHeader(backing.memory.get(), params); - - if (params.fixed.has_vocabulary) { - if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) - UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words"); - } - return reinterpret_cast(backing.memory.get()) + TotalHeaderSize(counts.size()); - } else { - backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); - return reinterpret_cast(backing.memory.get()); - } + return reinterpret_cast(backing.search.get()) + TotalHeaderSize(params.counts.size()); } void ComplainAboutARPA(const Config &config, ModelType model_type) { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index a43c883c..2d66f813 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -35,10 +35,16 @@ struct Parameters { struct Backing { // File behind memory, if any. util::scoped_fd file; + // Vocabulary lookup table. Not to be confused with the vocab words themselves. + util::scoped_memory vocab; // Raw block of memory backing the language model data structures - util::scoped_memory memory; + util::scoped_memory search; }; +uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); +// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. +uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector &counts, std::size_t memory_size, Backing &backing); + namespace detail { bool IsBinaryFormat(int fd); @@ -49,8 +55,6 @@ void MatchCheck(ModelType model_type, const Parameters ¶ms); uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); -uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector &counts, std::size_t memory_size, Backing &backing); - void ComplainAboutARPA(const Config &config, ModelType model_type); } // namespace detail @@ -61,13 +65,12 @@ template void LoadLM(const char *file, const Config &config, To &to) Backing &backing = to.MutableBacking(); backing.file.reset(util::OpenReadOrThrow(file)); - Parameters params; - try { if (detail::IsBinaryFormat(backing.file.get())) { + Parameters params; detail::ReadHeader(backing.file.get(), params); detail::MatchCheck(To::kModelType, params); - // Replace the probing_multiplier. + // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; std::size_t memory_size = To::Size(params.counts, new_config); @@ -75,12 +78,7 @@ template void LoadLM(const char *file, const Config &config, To &to) to.InitializeFromBinary(start, params, new_config, backing.file.get()); } else { detail::ComplainAboutARPA(config, To::kModelType); - util::FilePiece f(backing.file.release(), file, config.messages); - ReadARPACounts(f, params.counts); - std::size_t memory_size = To::Size(params.counts, config); - uint8_t *start = detail::SetupZeroed(config, To::kModelType, params.counts, memory_size, backing); - - to.InitializeFromARPA(file, f, start, params, config); + to.InitializeFromARPA(file, config); } } catch (util::Exception &e) { e << " in file " << file; diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh new file mode 100644 index 00000000..639bc98b --- /dev/null +++ b/klm/lm/blank.hh @@ -0,0 +1,13 @@ +#ifndef LM_BLANK__ +#define LM_BLANK__ +#include + +namespace lm { +namespace ngram { + +const float kBlankProb = -std::numeric_limits::quiet_NaN(); +const float kBlankBackoff = std::numeric_limits::infinity(); + +} // namespace ngram +} // namespace lm +#endif // LM_BLANK__ diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index ec034640..b340797b 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -22,9 +22,9 @@ void Usage(const char *name) { "on-disk sort to save memory.\n" "-t is the temporary directory prefix. Default is the output file name.\n" "-m is the amount of memory to use, in MB. Default is 1024MB (1GB).\n\n" -"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" +/*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" "It uses more memory than trie and is also slower, so there's no real reason to\n" -"use it.\n\n" +"use it.\n\n"*/ "See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" "Passing only an input file will print memory usage of each data structure.\n" "If the ARPA file does not have , -u sets 's probability; default 0.0.\n"; @@ -52,13 +52,13 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::size_t probing_size = ProbingModel::Size(counts, config); // probing is always largest so use it to determine number of columns. long int length = std::max(5, lrint(ceil(log10(probing_size)))); - std::cout << "Memory usage:\ntype "; + std::cout << "Memory estimate:\ntype "; // right align bytes. for (long int i = 0; i < length - 5; ++i) std::cout << ' '; std::cout << "bytes\n" "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n" - "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n" - "sorted " << std::setw(length) << SortedModel::Size(counts, config) << "\n"; + "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"; +/* "sorted " << std::setw(length) << SortedModel::Size(counts, config) << "\n";*/ } } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 88240b5f..767fa5f9 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -39,7 +39,7 @@ struct Config { // this. Time is probing_multiplier / (probing_multiplier - 1). No effect // for sorted variant. // If you find yourself setting this to a low number, consider using the - // Sorted version instead which has lower memory consumption. + // TrieModel which has lower memory consumption. float probing_multiplier; // Amount of memory to use for building. The actual memory usage will be @@ -53,7 +53,7 @@ struct Config { // defaults to input file name. const char *temporary_directory_prefix; - // Level of complaining to do when an ARPA instead of a binary format. + // Level of complaining to do when loading from ARPA instead of binary format. typedef enum {ALL, EXPENSIVE, NONE} ARPALoadComplain; ARPALoadComplain arpa_complain; diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 421e72fa..c7ba4908 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -1,5 +1,6 @@ #include "lm/model.hh" +#include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/search_hashed.hh" #include "lm/search_trie.hh" @@ -21,9 +22,6 @@ size_t hash_value(const State &state) { namespace detail { template size_t GenericModel::Size(const std::vector &counts, const Config &config) { - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); - if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); - if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } @@ -59,17 +57,31 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters ¶ms, const Config &config) { - SetupMemory(start, params.counts, config); +template void GenericModel::InitializeFromARPA(const char *file, const Config &config) { + // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. + util::FilePiece f(backing_.file.release(), file, config.messages); + std::vector counts; + // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed with search_.VariableSizeLoad + ReadARPACounts(f, counts); + + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); + + std::size_t vocab_size = VocabularyT::Size(counts[0], config); + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); if (config.write_mmap) { - WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get()); - vocab_.ConfigureEnumerate(&wrap, params.counts[0]); - search_.InitializeFromARPA(file, f, params.counts, config, vocab_); + WriteWordsWrapper wrap(config.enumerate_vocab); + vocab_.ConfigureEnumerate(&wrap, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + wrap.Write(backing_.file.get()); } else { - vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]); - search_.InitializeFromARPA(file, f, params.counts, config, vocab_); + vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); } + // TODO: fail faster? if (!vocab_.SawUnk()) { switch(config.unknown_missing) { @@ -89,46 +101,49 @@ template void GenericModel 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff); } template FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { - unsigned char backoff_start; - FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state); - if (backoff_start - 1 < in_state.valid_length_) { - ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); + FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state); + if (ret.ngram_length - 1 < in_state.valid_length_) { + ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); } return ret; } template FullScoreReturn GenericModel::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const { - unsigned char backoff_start; context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); - FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state); - ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start); + FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state); + ret.prob += SlowBackoffLookup(context_rbegin, context_rend, ret.ngram_length); return ret; } template void GenericModel::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { + // Generate a state from context. context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); - if (context_rend == context_rbegin || *context_rbegin == 0) { + if (context_rend == context_rbegin) { out_state.valid_length_ = 0; return; } float ignored_prob; typename Search::Node node; search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node); + // Tricky part is that an entry might be blank, but out_state.valid_length_ always has the last non-blank n-gram length. + out_state.valid_length_ = 1; float *backoff_out = out_state.backoff_ + 1; - const WordIndex *i = context_rbegin + 1; - for (; i < context_rend; ++i, ++backoff_out) { - if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) { - out_state.valid_length_ = i - context_rbegin; - std::copy(context_rbegin, i, out_state.history_); + const typename Search::Middle *mid = &*search_.middle.begin(); + for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) { + if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) { + std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_); return; } + if (*backoff_out != kBlankBackoff) { + out_state.valid_length_ = i - context_rbegin + 1; + } else { + *backoff_out = 0.0; + } } - std::copy(context_rbegin, context_rend, out_state.history_); - out_state.valid_length_ = static_cast(context_rend - context_rbegin); + std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_); } template float GenericModel::SlowBackoffLookup( @@ -148,7 +163,7 @@ template float GenericModel FullScoreReturn GenericModel FullScoreReturn GenericModel::const_iterator mid_iter = search_.middle.begin(); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { if (hist_iter == context_rend) { - // Ran out of history. No backoff. - backoff_start = P::Order(); - std::copy(context_rbegin, context_rend, out_state.history_ + 1); - ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1; + // Ran out of history. Typically no backoff, but this could be a blank. + out_state.valid_length_ = ret.ngram_length; + std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1); // ret.prob was already set. return ret; } if (mid_iter == search_.middle.end()) break; + float revert = ret.prob; if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) { // Didn't find an ngram using hist_iter. - // The history used in the found n-gram is [context_rbegin, hist_iter). - std::copy(context_rbegin, hist_iter, out_state.history_ + 1); - // Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word. - ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1; - backoff_start = mid_iter - search_.middle.begin() + 1; + std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1); + out_state.valid_length_ = ret.ngram_length; // ret.prob was already set. return ret; } + if (*backoff_out == kBlankBackoff) { + *backoff_out = 0.0; + ret.prob = revert; + } else { + ret.ngram_length = hist_iter - context_rbegin + 2; + } } - // It passed every lookup in search_.middle. That means it's at least a (P::Order() - 1)-gram. - // All that's left is to check search_.longest. + // It passed every lookup in search_.middle. All that's left is to check search_.longest. if (!search_.LookupLongest(*hist_iter, ret.prob, node)) { - // It's an (P::Order()-1)-gram - std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1); - ret.ngram_length = out_state.valid_length_ = P::Order() - 1; - backoff_start = P::Order() - 1; + //assert(ret.ngram_length <= P::Order() - 1); + out_state.valid_length_ = ret.ngram_length; + std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1); // ret.prob was already set. return ret; } - // It's an P::Order()-gram + // It's an P::Order()-gram. There is no blank in longest_. // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much. std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1); out_state.valid_length_ = P::Order() - 1; ret.ngram_length = P::Order(); - backoff_start = P::Order(); return ret; } diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 53e5773d..8183bdf5 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -102,14 +102,14 @@ template class GenericModel : public base::Mod float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const; - FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, unsigned char &backoff_start, State &out_state) const; + FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; // Appears after Size in the cc file. void SetupMemory(void *start, const std::vector &counts, const Config &config); void InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd); - void InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters ¶ms, const Config &config); + void InitializeFromARPA(const char *file, const Config &config); Backing &MutableBacking() { return backing_; } diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index b5125a95..89bbf2e8 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -33,7 +33,7 @@ template void Starters(const M &model) { // , probability plus backoff StartTest(",", 1, -1.383514 + -0.4149733); // probability plus backoff - StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); + StartTest("this_is_not_found", 1, -1.995635 + -0.4149733); } template void Continuation(const M &model) { @@ -48,8 +48,8 @@ template void Continuation(const M &model) { State preserve = state; AppendTest("the", 1, -4.04005); AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 0, -2.29666); - AppendTest("more", 1, -1.20632); + AppendTest("not_found", 1, -2.29666); + AppendTest("more", 1, -1.20632 - 20.0); AppendTest(".", 2, -0.51363); AppendTest("", 3, -0.0191651); @@ -58,6 +58,42 @@ template void Continuation(const M &model) { AppendTest("loin", 5, -0.0432557); } +template void Blanks(const M &model) { + FullScoreReturn ret; + State state(model.NullContextState()); + State out; + AppendTest("also", 1, -1.687872); + AppendTest("would", 2, -2); + AppendTest("consider", 3, -3); + State preserve = state; + AppendTest("higher", 4, -4); + AppendTest("looking", 5, -5); + + state = preserve; + AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103); + + state = model.NullContextState(); + // higher looking is a blank. + AppendTest("higher", 1, -1.509559); + AppendTest("looking", 1, -1.285941 - 0.30103); + AppendTest("not_found", 1, -1.995635 - 0.4771212); +} + +template void Unknowns(const M &model) { + FullScoreReturn ret; + State state(model.NullContextState()); + State out; + + AppendTest("not_found", 1, -1.995635); + State preserve = state; + AppendTest("not_found2", 2, -15.0); + AppendTest("not_found3", 2, -15.0 - 2.0); + + state = preserve; + AppendTest("however", 2, -4); + AppendTest("not_found3", 3, -6); +} + #define StatelessTest(word, provide, ngram, score) \ ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ @@ -103,16 +139,23 @@ template void Stateless(const M &model) { // biarritz StatelessTest(6, 1, 1, -1.9889); // not found - StatelessTest(7, 1, 0, -2.29666); - StatelessTest(7, 0, 0, -1.995635); + StatelessTest(7, 1, 1, -2.29666); + StatelessTest(7, 0, 1, -1.995635); WordIndex unk[1]; unk[0] = 0; model.GetState(unk, unk + 1, state); - BOOST_CHECK_EQUAL(0, state.valid_length_); + BOOST_CHECK_EQUAL(1, state.valid_length_); + BOOST_CHECK_EQUAL(static_cast(0), state.history_[0]); } -//const char *kExpectedOrderProbing[] = {"", ",", ".", "", "", "a", "also", "beyond", "biarritz", "call", "concerns", "consider", "considering", "for", "higher", "however", "i", "immediate", "in", "is", "little", "loin", "look", "looking", "more", "on", "screening", "small", "the", "to", "watch", "watching", "what", "would"}; +template void Everything(const M &m) { + Starters(m); + Continuation(m); + Blanks(m); + Unknowns(m); + Stateless(m); +} class ExpectEnumerateVocab : public EnumerateVocab { public: @@ -148,18 +191,16 @@ template void LoadingTest() { config.probing_multiplier = 2.0; ModelT m("test.arpa", config); enumerate.Check(m.GetVocabulary()); - Starters(m); - Continuation(m); - Stateless(m); + Everything(m); } BOOST_AUTO_TEST_CASE(probing) { LoadingTest(); } -BOOST_AUTO_TEST_CASE(sorted) { +/*BOOST_AUTO_TEST_CASE(sorted) { LoadingTest(); -} +}*/ BOOST_AUTO_TEST_CASE(trie) { LoadingTest(); } @@ -175,24 +216,23 @@ template void BinaryTest() { ModelT copy_model("test.arpa", config); enumerate.Check(copy_model.GetVocabulary()); enumerate.Clear(); + Everything(copy_model); } config.write_mmap = NULL; ModelT binary("test.binary", config); enumerate.Check(binary.GetVocabulary()); - Starters(binary); - Continuation(binary); - Stateless(binary); + Everything(binary); unlink("test.binary"); } BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest(); } -BOOST_AUTO_TEST_CASE(write_and_read_sorted) { +/*BOOST_AUTO_TEST_CASE(write_and_read_sorted) { BinaryTest(); -} +}*/ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BinaryTest(); } diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 74457a74..3fa8cb03 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -1,3 +1,4 @@ +#include "lm/enumerate_vocab.hh" #include "lm/model.hh" #include @@ -44,29 +45,61 @@ template void Query(const Model &model) { bool got = false; while (std::cin >> word) { got = true; - ret = model.FullScore(state, model.GetVocabulary().Index(word), out); + lm::WordIndex vocab = model.GetVocabulary().Index(word); + ret = model.FullScore(state, vocab, out); total += ret.prob; - std::cout << word << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << ' '; + std::cout << word << '=' << vocab << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\n'; state = out; if (std::cin.get() == '\n') break; } if (!got && !std::cin) break; ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; - std::cout << " " << static_cast(ret.ngram_length) << ' ' << ret.prob << ' '; + std::cout << "=" << model.GetVocabulary().EndSentence() << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\n'; std::cout << "Total: " << total << '\n'; } PrintUsage("After queries:\n"); } +class PrintVocab : public lm::ngram::EnumerateVocab { + public: + void Add(lm::WordIndex index, const StringPiece &str) { + std::cerr << "vocab " << index << ' ' << str << '\n'; + } +}; + +template void Query(const char *name) { + lm::ngram::Config config; + PrintVocab printer; + config.enumerate_vocab = &printer; + Model model(name, config); + Query(model); +} + int main(int argc, char *argv[]) { if (argc < 2) { std::cerr << "Pass language model name." << std::endl; return 0; } - { - lm::ngram::Model ngram(argv[1]); - Query(ngram); + lm::ngram::ModelType model_type; + if (lm::ngram::RecognizeBinary(argv[1], model_type)) { + switch(model_type) { + case lm::ngram::HASH_PROBING: + Query(argv[1]); + break; + case lm::ngram::HASH_SORTED: + Query(argv[1]); + break; + case lm::ngram::TRIE_SORTED: + Query(argv[1]); + break; + default: + std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; + abort(); + } + } else { + Query(argv[1]); } + PrintUsage("Total time including destruction:\n"); } diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 9cb662a6..9200aeb6 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -1,5 +1,6 @@ #include "lm/search_hashed.hh" +#include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/read_arpa.hh" #include "lm/vocab.hh" @@ -13,34 +14,30 @@ namespace ngram { namespace { -/* All of the entropy is in low order bits and boost::hash does poorly with - * these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a - * stable point: 0. But 0 is which won't be queried here anyway. - */ -inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { - uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast(next) * 17894857484156487943ULL); - return ret; -} - -uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) { - if (word == word_end) return 0; - uint64_t current = static_cast(*word); - for (++word; word != word_end; ++word) { - current = CombineWordHash(current, *word); - } - return current; -} - -template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) { +template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Store &store) { + ReadNGramHeader(f, n); + ProbBackoff blank; + blank.prob = kBlankProb; + blank.backoff = kBlankBackoff; // vocab ids of words in reverse order WordIndex vocab_ids[n]; + uint64_t keys[n - 1]; typename Store::Packing::Value value; + typename Middle::ConstIterator found; for (size_t i = 0; i < count; ++i) { ReadNGram(f, n, vocab, vocab_ids, value); - uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n); - store.Insert(Store::Packing::Make(key, value)); + keys[0] = detail::CombineWordHash(static_cast(*vocab_ids), vocab_ids[1]); + for (unsigned int h = 1; h < n - 1; ++h) { + keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); + } + store.Insert(Store::Packing::Make(keys[n-2], value)); + // Go back and insert blanks. + for (int lower = n - 3; lower >= 0; --lower) { + if (middle[lower].Find(keys[lower], found)) break; + middle[lower].Insert(Middle::Packing::Make(keys[lower], blank)); + } } store.FinishedInserting(); @@ -49,17 +46,28 @@ template void ReadNGrams(util::FilePiece &f, const unsi } // namespace namespace detail { -template template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &/*config*/, Voc &vocab) { +template template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { + // TODO: fix sorted. + SetupMemory(GrowForSearch(config, HASH_PROBING, counts, Size(counts, config), backing), counts, config); + Read1Grams(f, counts[0], vocab, unigram.Raw()); - // Read the n-grams. - for (unsigned int n = 2; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle[n-2]); + + try { + for (unsigned int n = 2; n < counts.size(); ++n) { + ReadNGrams(f, n, counts[n-1], vocab, middle, middle[n-2]); + } + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, longest); + } catch (util::ProbingSizeException &e) { + UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces. "); } - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, longest); } -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab); -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, SortedVocabulary &vocab); +template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); +template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, SortedVocabulary &vocab, Backing &backing); + +SortedHashedSearch::SortedHashedSearch() { + UTIL_THROW(util::Exception, "Sorted is broken at the moment, sorry"); +} } // namespace detail } // namespace ngram diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 1ee2b9e9..6dc11fb3 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -17,10 +17,11 @@ namespace util { class FilePiece; } namespace lm { namespace ngram { +struct Backing; namespace detail { inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { - uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast(next) * 17894857484156487943ULL); + uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast(1 + next) * 17894857484156487943ULL); return ret; } @@ -91,7 +92,7 @@ template struct TemplateHashedSearch : public Ha return start; } - template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab); + template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing); bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { node = CombineWordHash(node, word); @@ -145,6 +146,8 @@ struct ProbingHashedSearch : public TemplateHashedSearch< struct SortedHashedSearch : public TemplateHashedSearch< util::SortedUniformMap >, util::SortedUniformMap > > { + + SortedHashedSearch(); static const ModelType kModelType = HASH_SORTED; }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 12294682..3aeeeca3 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,6 +1,7 @@ /* This is where the trie is built. It's on-disk. */ #include "lm/search_trie.hh" +#include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/read_arpa.hh" #include "lm/trie.hh" @@ -13,10 +14,10 @@ #include "util/scoped.hh" #include +#include #include #include #include -#include #include //#include #include @@ -152,11 +153,11 @@ void ReadOrThrow(FILE *from, void *data, size_t size) { if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size); } +const std::size_t kCopyBufSize = 512; void CopyOrThrow(FILE *from, FILE *to, size_t size) { - const size_t kBufSize = 512; - char buf[kBufSize]; - for (size_t i = 0; i < size; i += kBufSize) { - std::size_t amount = std::min(size - i, kBufSize); + char buf[std::min(size, kCopyBufSize)]; + for (size_t i = 0; i < size; i += kCopyBufSize) { + std::size_t amount = std::min(size - i, kCopyBufSize); ReadOrThrow(from, buf, amount); WriteOrThrow(to, buf, amount); } @@ -172,8 +173,10 @@ std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::str if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); // Compress entries that being with the same (order-1) words. for (const uint8_t *group_begin = static_cast(mem_begin); group_begin != static_cast(mem_end);) { - const uint8_t *group_end = group_begin; - for (group_end += entry_size; (group_end != static_cast(mem_end)) && !memcmp(group_begin, group_end, prefix_size); group_end += entry_size) {} + const uint8_t *group_end; + for (group_end = group_begin + entry_size; + (group_end != static_cast(mem_end)) && !memcmp(group_begin, group_end, prefix_size); + group_end += entry_size) {} WriteOrThrow(out.get(), group_begin, prefix_size); WordIndex group_size = (group_end - group_begin) / entry_size; WriteOrThrow(out.get(), &group_size, sizeof(group_size)); @@ -188,7 +191,7 @@ std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::str class SortedFileReader { public: - SortedFileReader() {} + SortedFileReader() : ended_(false) {} void Init(const std::string &name, unsigned char order) { file_.reset(fopen(name.c_str(), "r")); @@ -206,25 +209,39 @@ class SortedFileReader { std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); } void NextHeader() { - if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get()) && !Ended()) { - UTIL_THROW(util::ErrnoException, "Short read of counts"); + if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get())) { + if (feof(file_.get())) { + ended_ = true; + } else { + UTIL_THROW(util::ErrnoException, "Short read of counts"); + } } } - void ReadCount(WordIndex &to) { - ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); + WordIndex ReadCount() { + WordIndex ret; + ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); + return ret; } - void ReadWord(WordIndex &to) { - ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); + WordIndex ReadWord() { + WordIndex ret; + ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); + return ret; } - template void ReadWeights(Weights &to) { - ReadOrThrow(file_.get(), &to, sizeof(Weights)); + template void ReadWeights(Weights &weights) { + ReadOrThrow(file_.get(), &weights, sizeof(Weights)); } bool Ended() { - return feof(file_.get()); + return ended_; + } + + void Rewind() { + rewind(file_.get()); + ended_ = false; + NextHeader(); } FILE *File() { return file_.get(); } @@ -233,12 +250,13 @@ class SortedFileReader { util::scoped_FILE file_; std::vector header_; + + bool ended_; }; void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) { WriteOrThrow(to, from.Header(), from.HeaderBytes()); - WordIndex count; - from.ReadCount(count); + WordIndex count = from.ReadCount(); WriteOrThrow(to, &count, sizeof(WordIndex)); CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); @@ -263,25 +281,23 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha } // Merge at the entry level. WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes()); - WordIndex first_count, second_count; - first.ReadCount(first_count); second.ReadCount(second_count); + WordIndex first_count = first.ReadCount(), second_count = second.ReadCount(); WordIndex total_count = first_count + second_count; WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex)); - WordIndex first_word, second_word; - first.ReadWord(first_word); second.ReadWord(second_word); + WordIndex first_word = first.ReadWord(), second_word = second.ReadWord(); WordIndex first_index = 0, second_index = 0; while (true) { if (first_word < second_word) { WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); CopyOrThrow(first.File(), out_file.get(), weights_size); if (++first_index == first_count) break; - first.ReadWord(first_word); + first_word = first.ReadWord(); } else { WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); CopyOrThrow(second.File(), out_file.get(), weights_size); if (++second_index == second_count) break; - second.ReadWord(second_word); + second_word = second.ReadWord(); } } if (first_index == first_count) { @@ -358,75 +374,219 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector &counts, { std::string unigram_name = file_prefix + "unigrams"; util::scoped_fd unigram_file; - util::scoped_mmap unigram_mmap; - unigram_mmap.reset(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get())); } util::scoped_memory mem; - mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED); + mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size()); ReadEnd(f); } -struct RecursiveInsertParams { - WordIndex *words; - SortedFileReader *files; - unsigned char max_order; - // This is an array of size order - 2. - BitPackedMiddle *middle; - // This has exactly one entry. - BitPackedLongest *longest; -}; - -uint64_t RecursiveInsert(RecursiveInsertParams ¶ms, unsigned char order) { - SortedFileReader &file = params.files[order - 2]; - const uint64_t ret = (order == params.max_order) ? params.longest->InsertIndex() : params.middle[order - 2].InsertIndex(); - if (std::memcmp(params.words, file.Header(), sizeof(WordIndex) * (order - 1))) - return ret; - WordIndex count; - file.ReadCount(count); - WordIndex key; - if (order == params.max_order) { - Prob value; - for (WordIndex i = 0; i < count; ++i) { - file.ReadWord(key); - file.ReadWeights(value); - params.longest->Insert(key, value.prob); - } - file.NextHeader(); - return ret; - } - ProbBackoff value; - for (WordIndex i = 0; i < count; ++i) { - file.ReadWord(params.words[order - 1]); - file.ReadWeights(value); - params.middle[order - 2].Insert( - params.words[order - 1], - value.prob, - value.backoff, - RecursiveInsert(params, order + 1)); +bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) { + for (; words != words_end; ++words, ++header) { + if (*words != *header) { + assert(*words <= *header); + return false; + } } - file.NextHeader(); - return ret; + return true; } -void BuildTrie(const std::string &file_prefix, const std::vector &counts, std::ostream *messages, TrieSearch &out) { - UnigramValue *unigrams = out.unigram.Raw(); - // Load unigrams. Leave the next pointers uninitialized. - { - std::string name(file_prefix + "unigrams"); - util::scoped_FILE file(fopen(name.c_str(), "r")); - if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed"); - for (WordIndex i = 0; i < counts[0]; ++i) { - ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); +class JustCount { + public: + JustCount(UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) + : counts_(counts), longest_counts_(counts + order - 1) {} + + void Unigrams(WordIndex begin, WordIndex end) { + counts_[0] += end - begin; } - unlink(name.c_str()); + + void MiddleBlank(const unsigned char mid_idx, WordIndex /* idx */) { + ++counts_[mid_idx + 1]; + } + + void Middle(const unsigned char mid_idx, WordIndex /*key*/, const ProbBackoff &/*weights*/) { + ++counts_[mid_idx + 1]; + } + + void Longest(WordIndex /*key*/, Prob /*prob*/) { + ++*longest_counts_; + } + + // Unigrams wrote one past. + void Cleanup() { + --counts_[0]; + } + + private: + uint64_t *const counts_, *const longest_counts_; +}; + +class WriteEntries { + public: + WriteEntries(UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + unigrams_(unigrams), + middle_(middle), + longest_(longest), + bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)) {} + + void Unigrams(WordIndex begin, WordIndex end) { + uint64_t next = bigram_pack_.InsertIndex(); + for (UnigramValue *i = unigrams_ + begin; i < unigrams_ + end; ++i) { + i->next = next; + } + } + + void MiddleBlank(const unsigned char mid_idx, WordIndex key) { + middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff); + } + + void Middle(const unsigned char mid_idx, WordIndex key, const ProbBackoff &weights) { + middle_[mid_idx].Insert(key, weights.prob, weights.backoff); + } + + void Longest(WordIndex key, Prob prob) { + longest_.Insert(key, prob.prob); + } + + void Cleanup() {} + + private: + UnigramValue *const unigrams_; + BitPackedMiddle *const middle_; + BitPackedLongest &longest_; + BitPacked &bigram_pack_; +}; + +template class RecursiveInsert { + public: + RecursiveInsert(SortedFileReader *inputs, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : + doing_(unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), words_(new WordIndex[order]), order_minus_2_(order - 2) { + } + + // Outer unigram loop. + void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) { + util::ErsatzProgress progress(progress_out, message, unigram_count + 1); + for (words_[0] = 0; ; ++words_[0]) { + WordIndex min_continue = unigram_count; + for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) { + if (other->Ended()) continue; + min_continue = std::min(min_continue, other->Header()[0]); + } + // This will write at unigram_count. This is by design so that the next pointers will make sense. + doing_.Unigrams(words_[0], min_continue + 1); + if (min_continue == unigram_count) break; + progress += min_continue - words_[0]; + words_[0] = min_continue; + Middle(0); + } + doing_.Cleanup(); + } + + private: + void Middle(const unsigned char mid_idx) { + // (mid_idx + 2)-gram. + if (mid_idx == order_minus_2_) { + Longest(); + return; + } + // Orders [2, order) + + SortedFileReader &reader = inputs_[mid_idx]; + + if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + mid_idx + 1, reader.Header())) { + // This order doesn't have a header match, but longer ones might. + MiddleAllBlank(mid_idx); + return; + } + + // There is a header match. + WordIndex count = reader.ReadCount(); + WordIndex current = reader.ReadWord(); + while (count) { + WordIndex min_continue = std::numeric_limits::max(); + for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { + if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header())) + min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); + } + while (true) { + if (current > min_continue) { + doing_.MiddleBlank(mid_idx, min_continue); + words_[mid_idx + 1] = min_continue; + Middle(mid_idx + 1); + break; + } + ProbBackoff weights; + reader.ReadWeights(weights); + doing_.Middle(mid_idx, current, weights); + --count; + if (current == min_continue) { + words_[mid_idx + 1] = min_continue; + Middle(mid_idx + 1); + if (count) current = reader.ReadWord(); + break; + } + if (!count) break; + current = reader.ReadWord(); + } + } + // Count is now zero. Finish off remaining blanks. + MiddleAllBlank(mid_idx); + reader.NextHeader(); + } + + void MiddleAllBlank(const unsigned char mid_idx) { + while (true) { + WordIndex min_continue = std::numeric_limits::max(); + for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { + if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header())) + min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); + } + if (min_continue == std::numeric_limits::max()) return; + doing_.MiddleBlank(mid_idx, min_continue); + words_[mid_idx + 1] = min_continue; + Middle(mid_idx + 1); + } + } + + void Longest() { + SortedFileReader &reader = *(inputs_end_ - 1); + if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + order_minus_2_ + 1, reader.Header())) return; + WordIndex count = reader.ReadCount(); + for (WordIndex i = 0; i < count; ++i) { + WordIndex word = reader.ReadWord(); + Prob prob; + reader.ReadWeights(prob); + doing_.Longest(word, prob); + } + reader.NextHeader(); + return; + } + + Doing doing_; + + SortedFileReader *inputs_; + SortedFileReader *inputs_end_; + + util::scoped_array words_; + + const unsigned char order_minus_2_; +}; + +void SanityCheckCounts(const std::vector &initial, const std::vector &fixed) { + if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]); + if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant"); + for (unsigned char i = 0; i < initial.size(); ++i) { + if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen"); } +} - // inputs[0] is bigrams. +void BuildTrie(const std::string &file_prefix, const std::vector &counts, const Config &config, TrieSearch &out, Backing &backing) { SortedFileReader inputs[counts.size() - 1]; + for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; assembled << file_prefix << static_cast(i) << "_merged"; @@ -434,36 +594,49 @@ void BuildTrie(const std::string &file_prefix, const std::vector &coun unlink(assembled.str().c_str()); } - // words[0] is unigrams. - WordIndex words[counts.size()]; - RecursiveInsertParams params; - params.words = words; - params.files = inputs; - params.max_order = static_cast(counts.size()); - params.middle = &*out.middle.begin(); - params.longest = &out.longest; + std::vector fixed_counts(counts.size()); + { + RecursiveInsert counter(inputs, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); + } + SanityCheckCounts(counts, fixed_counts); + + out.SetupMemory(GrowForSearch(config, TrieSearch::kModelType, fixed_counts, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + + for (unsigned char i = 2; i <= counts.size(); ++i) { + inputs[i-2].Rewind(); + } + UnigramValue *unigrams = out.unigram.Raw(); + // Fill entries except unigram probabilities. { - util::ErsatzProgress progress(messages, "Building trie", counts[0]); - for (words[0] = 0; words[0] < counts[0]; ++words[0], ++progress) { - unigrams[words[0]].next = RecursiveInsert(params, 2); + RecursiveInsert inserter(inputs, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + inserter.Apply(config.messages, "Building trie", fixed_counts[0]); + } + + // Fill unigram probabilities. + { + std::string name(file_prefix + "unigrams"); + util::scoped_FILE file(fopen(name.c_str(), "r")); + if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed"); + for (WordIndex i = 0; i < counts[0]; ++i) { + ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); } + unlink(name.c_str()); } /* Set ending offsets so the last entry will be sized properly */ + // Last entry for unigrams was already set. if (!out.middle.empty()) { - unigrams[counts[0]].next = out.middle.front().InsertIndex(); for (size_t i = 0; i < out.middle.size() - 1; ++i) { out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); } out.middle.back().FinishedLoading(out.longest.InsertIndex()); - } else { - unigrams[counts[0]].next = out.longest.InsertIndex(); } } } // namespace -void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, SortedVocabulary &vocab) { +void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -473,7 +646,8 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const temporary_directory = file; } // Null on end is kludge to ensure null termination. - temporary_directory += "-tmp-XXXXXX\0"; + temporary_directory += "-tmp-XXXXXX"; + temporary_directory += '\0'; if (!mkdtemp(&temporary_directory[0])) { UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str()); } @@ -483,9 +657,10 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const temporary_directory += '/'; // At least 1MB sorting memory. ARPAToSortedFiles(f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory.c_str(), counts, config.messages, *this); - if (rmdir(temporary_directory.c_str())) { - std::cerr << "Failed to delete " << temporary_directory << std::endl; + + BuildTrie(temporary_directory.c_str(), counts, config, *this, backing); + if (rmdir(temporary_directory.c_str()) && config.messages) { + *config.messages << "Failed to delete " << temporary_directory << std::endl; } } diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 902f6ce6..0f720217 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -9,6 +9,7 @@ namespace lm { namespace ngram { +struct Backing; class SortedVocabulary; namespace trie { @@ -39,14 +40,18 @@ struct TrieSearch { start += Unigram::Size(counts[0]); middle.resize(counts.size() - 2); for (unsigned char i = 1; i < counts.size() - 1; ++i) { - middle[i-1].Init(start, counts[0], counts[i+1]); + middle[i-1].Init( + start, + counts[0], + counts[i+1], + (i == counts.size() - 2) ? static_cast(longest) : static_cast(middle[i])); start += Middle::Size(counts[i], counts[0], counts[i+1]); } longest.Init(start, counts[0]); return start + Longest::Size(counts.back(), counts[0]); } - void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, SortedVocabulary &vocab); + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { return unigram.Find(word, prob, backoff, node); @@ -65,7 +70,7 @@ struct TrieSearch { } bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - // TODO: don't decode prob. + // TODO: don't decode backoff. assert(begin != end); float ignored_prob, ignored_backoff; LookupUnigram(*begin, ignored_prob, ignored_backoff, node); diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 04bd2079..2c633613 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -82,7 +82,8 @@ std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr)); } -void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) { +void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) { + next_source_ = &next_source; backoff_bits_ = 32; next_bits_ = util::RequiredBits(max_next); if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); @@ -91,9 +92,8 @@ void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) { BaseInit(base, max_vocab, backoff_bits_ + next_bits_); } -void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) { +void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); - assert(next <= next_mask_); uint64_t at_pointer = insert_index_ * total_bits_; util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word); @@ -102,6 +102,8 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t at_pointer += prob_bits_; util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); at_pointer += backoff_bits_; + uint64_t next = next_source_->InsertIndex(); + assert(next <= next_mask_); util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next); ++insert_index_; @@ -109,7 +111,9 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) { + return false; + } at_pointer *= total_bits_; at_pointer += word_bits_; prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); @@ -144,7 +148,6 @@ void BitPackedMiddle::FinishedLoading(uint64_t next_end) { util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end); } - void BitPackedLongest::Insert(WordIndex index, float prob) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 35dc2c96..6aef050c 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -89,9 +89,10 @@ class BitPackedMiddle : public BitPacked { static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next); - void Init(void *base, uint64_t max_vocab, uint64_t max_next); + // next_source need not be initialized. + void Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); - void Insert(WordIndex word, float prob, float backoff, uint64_t next); + void Insert(WordIndex word, float prob, float backoff); bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; @@ -102,6 +103,8 @@ class BitPackedMiddle : public BitPacked { private: uint8_t backoff_bits_, next_bits_; uint64_t next_mask_; + + const BitPacked *next_source_; }; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index c30428b2..ae79c727 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -68,15 +68,19 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) { } // namespace -WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner, int fd) : inner_(inner), fd_(fd) {} +WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {} WriteWordsWrapper::~WriteWordsWrapper() {} void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { if (inner_) inner_->Add(index, str); - WriteOrThrow(fd_, str.data(), str.size()); - char null_byte = 0; - // Inefficient because it's unbuffered. Sue me. - WriteOrThrow(fd_, &null_byte, 1); + buffer_.append(str.data(), str.size()); + buffer_.push_back(0); +} + +void WriteWordsWrapper::Write(int fd) { + if ((off_t)-1 == lseek(fd, 0, SEEK_END)) + UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words"); + WriteOrThrow(fd, buffer_.data(), buffer_.size()); } SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index bb5d789b..8c99d797 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -27,15 +27,18 @@ inline uint64_t HashForVocab(const StringPiece &str) { class WriteWordsWrapper : public EnumerateVocab { public: - WriteWordsWrapper(EnumerateVocab *inner, int fd); + WriteWordsWrapper(EnumerateVocab *inner); ~WriteWordsWrapper(); void Add(WordIndex index, const StringPiece &str); + void Write(int fd); + private: EnumerateVocab *inner_; - int fd_; + + std::string buffer_; }; // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. -- cgit v1.2.3