From c4ade3091b812ca135ae6520fa7173e1bbf28754 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 25 Jan 2011 22:30:48 +0200 Subject: update kenlm --- klm/lm/binary_format.cc | 4 +- klm/lm/blank.hh | 44 +++++- klm/lm/build_binary.cc | 2 +- klm/lm/model.cc | 97 +++++++------ klm/lm/model.hh | 9 +- klm/lm/model_test.cc | 49 ++++++- klm/lm/ngram_query.cc | 27 ++-- klm/lm/read_arpa.cc | 17 ++- klm/lm/read_arpa.hh | 6 +- klm/lm/search_hashed.cc | 52 ++++++- klm/lm/search_trie.cc | 302 +++++++++++++++++++++++++++++++++++------ klm/lm/vocab.hh | 1 - klm/util/bit_packing.cc | 2 +- klm/util/bit_packing.hh | 17 ++- klm/util/ersatz_progress.cc | 1 + klm/util/file_piece.cc | 24 +--- klm/util/file_piece.hh | 35 ++++- klm/util/key_value_packing.hh | 4 + klm/util/probing_hash_table.hh | 14 +- klm/util/sorted_uniform.hh | 10 ++ 20 files changed, 550 insertions(+), 167 deletions(-) diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 3d9700da..2a6aff34 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -18,8 +18,8 @@ namespace lm { namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 3\n\0"; -const long int kMagicVersion = 2; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0"; +const long int kMagicVersion = 4; // Test values. struct Sanity { diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 639bc98b..4615a09e 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -1,12 +1,52 @@ #ifndef LM_BLANK__ #define LM_BLANK__ + #include +#include +#include + namespace lm { namespace ngram { -const float kBlankProb = -std::numeric_limits::quiet_NaN(); -const float kBlankBackoff = std::numeric_limits::infinity(); +/* Suppose "foo bar" appears with zero backoff but there is no trigram + * beginning with these words. Then, when scoring "foo bar", the model could + * return out_state containing "bar" or even null context if "bar" also has no + * backoff and is never followed by another word. Then the backoff is set to + * kNoExtensionBackoff. If the n-gram might be extended, then out_state must + * contain the full n-gram, in which case kExtensionBackoff is set. In any + * case, if an n-gram has non-zero backoff, the full state is returned so + * backoff can be properly charged. + * These differ only in sign bit because the backoff is in fact zero in either + * case. + */ +const float kNoExtensionBackoff = -0.0; +const float kExtensionBackoff = 0.0; + +inline void SetExtension(float &backoff) { + if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff; +} + +// This compiles down nicely. +inline bool HasExtension(const float &backoff) { + typedef union { float f; uint32_t i; } UnionValue; + UnionValue compare, interpret; + compare.f = kNoExtensionBackoff; + interpret.f = backoff; + return compare.i != interpret.i; +} + +/* Suppose "foo bar baz quux" appears in the ARPA but not "bar baz quux" or + * "baz quux" (because they were pruned). 1.2% of n-grams generated by SRI + * with default settings on the benchmark data set are like this. Since search + * proceeds by finding "quux", "baz quux", "bar baz quux", and finally + * "foo bar baz quux" and the trie needs pointer nodes anyway, blanks are + * inserted. The blanks have probability kBlankProb and backoff kBlankBackoff. + * A blank is recognized by kBlankProb in the probability field; kBlankBackoff + * must be 0 so that inference asseses zero backoff from these blanks. + */ +const float kBlankProb = -std::numeric_limits::infinity(); +const float kBlankBackoff = kNoExtensionBackoff; } // namespace ngram } // namespace lm diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index b340797b..144c57e0 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -21,7 +21,7 @@ void Usage(const char *name) { "memory and is still faster than SRI or IRST. Building the trie format uses an\n" "on-disk sort to save memory.\n" "-t is the temporary directory prefix. Default is the output file name.\n" -"-m is the amount of memory to use, in MB. Default is 1024MB (1GB).\n\n" +"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n\n" /*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" "It uses more memory than trie and is also slower, so there's no real reason to\n" "use it.\n\n"*/ diff --git a/klm/lm/model.cc b/klm/lm/model.cc index c7ba4908..146fe07b 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -61,10 +61,10 @@ template void GenericModel counts; - // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed with search_.VariableSizeLoad + // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. ReadARPACounts(f, counts); - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); @@ -114,7 +114,24 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const { context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state); - ret.prob += SlowBackoffLookup(context_rbegin, context_rend, ret.ngram_length); + + // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin). + unsigned char start = ret.ngram_length; + if (context_rend - context_rbegin < static_cast(start)) return ret; + if (start <= 1) { + ret.prob += search_.unigram.Lookup(*context_rbegin).backoff; + start = 2; + } + typename Search::Node node; + if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { + return ret; + } + float backoff; + // i is the order of the backoff we're looking for. + for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) { + if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break; + ret.prob += backoff; + } return ret; } @@ -128,8 +145,7 @@ template void GenericModel void GenericModel float GenericModel::SlowBackoffLookup( - const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const { - // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin). - if (context_rend - context_rbegin < static_cast(start)) return 0.0; - float ret = 0.0; - if (start == 1) { - ret += search_.unigram.Lookup(*context_rbegin).backoff; - start = 2; - } - typename Search::Node node; - if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { - return 0.0; - } - float backoff; - // i is the order of the backoff we're looking for. - for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) { - if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break; - if (backoff != kBlankBackoff) ret += backoff; - } - return ret; +namespace { +// Do a paraonoid copy of history, assuming new_word has already been copied +// (hence the -1). out_state.valid_length_ could be zero so I avoided using +// std::copy. +void CopyRemainingHistory(const WordIndex *from, State &out_state) { + WordIndex *out = out_state.history_ + 1; + const WordIndex *in_end = from + static_cast(out_state.valid_length_) - 1; + for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in; } +} // namespace /* Ugly optimized function. Produce a score excluding backoff. * The search goes in increasing order of ngram length. @@ -179,28 +180,26 @@ template FullScoreReturn GenericModel::const_iterator mid_iter = search_.middle.begin(); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { if (hist_iter == context_rend) { // Ran out of history. Typically no backoff, but this could be a blank. - out_state.valid_length_ = ret.ngram_length; - std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1); + CopyRemainingHistory(context_rbegin, out_state); // ret.prob was already set. return ret; } @@ -210,32 +209,32 @@ template FullScoreReturn GenericModel(state.valid_length_) << ':'; + for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) { + o << ' ' << *i; + } + return o; +} + namespace { #define StartTest(word, ngram, score) \ @@ -17,7 +26,15 @@ namespace { out);\ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ BOOST_CHECK_EQUAL(static_cast(ngram), ret.ngram_length); \ - BOOST_CHECK_EQUAL(std::min(ngram, 5 - 1), out.valid_length_); + BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.valid_length_); \ + {\ + WordIndex context[state.valid_length_ + 1]; \ + context[0] = model.GetVocabulary().Index(word); \ + std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \ + State get_state; \ + model.GetState(context, context + state.valid_length_ + 1, get_state); \ + BOOST_CHECK_EQUAL(out, get_state); \ + } #define AppendTest(word, ngram, score) \ StartTest(word, ngram, score) \ @@ -52,10 +69,13 @@ template void Continuation(const M &model) { AppendTest("more", 1, -1.20632 - 20.0); AppendTest(".", 2, -0.51363); AppendTest("", 3, -0.0191651); + BOOST_CHECK_EQUAL(0, state.valid_length_); state = preserve; AppendTest("more", 5, -0.00181395); + BOOST_CHECK_EQUAL(4, state.valid_length_); AppendTest("loin", 5, -0.0432557); + BOOST_CHECK_EQUAL(1, state.valid_length_); } template void Blanks(const M &model) { @@ -68,6 +88,7 @@ template void Blanks(const M &model) { State preserve = state; AppendTest("higher", 4, -4); AppendTest("looking", 5, -5); + BOOST_CHECK_EQUAL(1, state.valid_length_); state = preserve; AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103); @@ -94,6 +115,29 @@ template void Unknowns(const M &model) { AppendTest("not_found3", 3, -6); } +template void MinimalState(const M &model) { + FullScoreReturn ret; + State state(model.NullContextState()); + State out; + + AppendTest("baz", 1, -6.535897); + BOOST_CHECK_EQUAL(0, state.valid_length_); + state = model.NullContextState(); + AppendTest("foo", 1, -3.141592); + BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("bar", 2, -6.0); + // Has to include the backoff weight. + BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("bar", 1, -2.718281 + 3.0); + BOOST_CHECK_EQUAL(1, state.valid_length_); + + state = model.NullContextState(); + AppendTest("to", 1, -1.687872); + AppendTest("look", 2, -0.2922095); + BOOST_CHECK_EQUAL(2, state.valid_length_); + AppendTest("good", 3, -7); +} + #define StatelessTest(word, provide, ngram, score) \ ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ @@ -154,6 +198,7 @@ template void Everything(const M &m) { Continuation(m); Blanks(m); Unknowns(m); + MinimalState(m); Stateless(m); } @@ -167,7 +212,7 @@ class ExpectEnumerateVocab : public EnumerateVocab { } void Check(const base::Vocabulary &vocab) { - BOOST_CHECK_EQUAL(34ULL, seen.size()); + BOOST_CHECK_EQUAL(37ULL, seen.size()); BOOST_REQUIRE(!seen.empty()); BOOST_CHECK_EQUAL("", seen[0]); for (WordIndex i = 0; i < seen.size(); ++i) { diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 3fa8cb03..d6da02e3 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -6,6 +6,8 @@ #include #include +#include + #include #include @@ -43,35 +45,38 @@ template void Query(const Model &model) { state = model.BeginSentenceState(); float total = 0.0; bool got = false; + unsigned int oov = 0; while (std::cin >> word) { got = true; lm::WordIndex vocab = model.GetVocabulary().Index(word); + if (vocab == 0) ++oov; ret = model.FullScore(state, vocab, out); total += ret.prob; std::cout << word << '=' << vocab << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\n'; state = out; - if (std::cin.get() == '\n') break; + char c; + while (true) { + c = std::cin.get(); + if (!std::cin) break; + if (c == '\n') break; + if (!isspace(c)) { + std::cin.unget(); + break; + } + } + if (c == '\n') break; } if (!got && !std::cin) break; ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; std::cout << "=" << model.GetVocabulary().EndSentence() << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\n'; - std::cout << "Total: " << total << '\n'; + std::cout << "Total: " << total << " OOV: " << oov << '\n'; } PrintUsage("After queries:\n"); } -class PrintVocab : public lm::ngram::EnumerateVocab { - public: - void Add(lm::WordIndex index, const StringPiece &str) { - std::cerr << "vocab " << index << ' ' << str << '\n'; - } -}; - template void Query(const char *name) { lm::ngram::Config config; - PrintVocab printer; - config.enumerate_vocab = &printer; Model model(name, config); Query(model); } diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 262a9c6a..d0fe67f0 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -1,5 +1,7 @@ #include "lm/read_arpa.hh" +#include "lm/blank.hh" + #include #include @@ -8,6 +10,9 @@ namespace lm { +// 1 for '\t', '\n', and ' '. This is stricter than isspace. +const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; + namespace { bool IsEntirelyWhiteSpace(const StringPiece &line) { @@ -116,21 +121,27 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) { case '\n': break; default: - UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); + UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff"); } } void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) { + // Always make zero negative. + // Negative zero means that no (n+1)-gram has this n-gram as context. + // Therefore the hypothesis state can be shorter. Of course, many n-grams + // are context for (n+1)-grams. An algorithm in the data structure will go + // back and set the backoff to positive zero in these cases. switch (in.get()) { case '\t': weights.backoff = in.ReadFloat(); + if (weights.backoff == ngram::kExtensionBackoff) weights.backoff = ngram::kNoExtensionBackoff; if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); break; case '\n': - weights.backoff = 0.0; + weights.backoff = ngram::kNoExtensionBackoff; break; default: - UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); + UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff"); } } diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index 571fcbc5..4efdd29d 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -23,12 +23,14 @@ void ReadBackoff(util::FilePiece &in, ProbBackoff &weights); void ReadEnd(util::FilePiece &in); void ReadEnd(std::istream &in); +extern const bool kARPASpaces[256]; + template void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams) { try { float prob = f.ReadFloat(); if (prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << prob); if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); - ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())]; + ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))]; value.prob = prob; ReadBackoff(f, value); } catch(util::Exception &e) { @@ -50,7 +52,7 @@ template void ReadNGram(util::FilePiece &f, const uns weights.prob = f.ReadFloat(); if (weights.prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << weights.prob); for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) { - *vocab_out = vocab.Index(f.ReadDelimited()); + *vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces)); } ReadBackoff(f, weights); } catch(util::Exception &e) { diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 9200aeb6..00d03f4e 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -14,7 +14,41 @@ namespace ngram { namespace { -template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Store &store) { +/* These are passed to ReadNGrams so that n-grams with zero backoff that appear as context will still be used in state. */ +template class ActivateLowerMiddle { + public: + explicit ActivateLowerMiddle(Middle &middle) : modify_(middle) {} + + void operator()(const WordIndex *vocab_ids, const unsigned int n) { + uint64_t hash = static_cast(vocab_ids[1]); + for (const WordIndex *i = vocab_ids + 2; i < vocab_ids + n; ++i) { + hash = detail::CombineWordHash(hash, *i); + } + typename Middle::MutableIterator i; + // TODO: somehow get text of n-gram for this error message. + if (!modify_.UnsafeMutableFind(hash, i)) + UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram"); + SetExtension(i->MutableValue().backoff); + } + + private: + Middle &modify_; +}; + +class ActivateUnigram { + public: + explicit ActivateUnigram(ProbBackoff *unigram) : modify_(unigram) {} + + void operator()(const WordIndex *vocab_ids, const unsigned int /*n*/) { + // assert(n == 2); + SetExtension(modify_[vocab_ids[1]].backoff); + } + + private: + ProbBackoff *modify_; +}; + +template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Activate activate, Store &store) { ReadNGramHeader(f, n); ProbBackoff blank; @@ -38,6 +72,7 @@ template void ReadNGrams(util::FilePiece if (middle[lower].Find(keys[lower], found)) break; middle[lower].Insert(Middle::Packing::Make(keys[lower], blank)); } + activate(vocab_ids, n); } store.FinishedInserting(); @@ -53,12 +88,19 @@ template template void TemplateHashe Read1Grams(f, counts[0], vocab, unigram.Raw()); try { - for (unsigned int n = 2; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle, middle[n-2]); + if (counts.size() > 2) { + ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0]); + } + for (unsigned int n = 3; n < counts.size(); ++n) { + ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle(middle[n-3]), middle[n-2]); + } + if (counts.size() > 2) { + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest); + } else { + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle(middle.back()), longest); } - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, longest); } catch (util::ProbingSizeException &e) { - UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces. "); + UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); } } diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 3aeeeca3..1060ddef 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -3,6 +3,7 @@ #include "lm/blank.hh" #include "lm/lm_exception.hh" +#include "lm/max_order.hh" #include "lm/read_arpa.hh" #include "lm/trie.hh" #include "lm/vocab.hh" @@ -27,6 +28,7 @@ #include #include #include +#include namespace lm { namespace ngram { @@ -98,7 +100,7 @@ class EntryProxy { } const WordIndex *Indices() const { - return static_cast(inner_.Data()); + return reinterpret_cast(inner_.Data()); } private: @@ -114,17 +116,57 @@ class EntryProxy { typedef util::ProxyIterator NGramIter; -class CompareRecords : public std::binary_function { +// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. +class PartialViewProxy { + public: + PartialViewProxy() : attention_size_(0), inner_() {} + + PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} + + operator std::string() const { + return std::string(reinterpret_cast(inner_.Data()), attention_size_); + } + + PartialViewProxy &operator=(const PartialViewProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), attention_size_); + return *this; + } + + PartialViewProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), attention_size_); + return *this; + } + + const WordIndex *Indices() const { + return reinterpret_cast(inner_.Data()); + } + + private: + friend class util::ProxyIterator; + + typedef std::string value_type; + + const std::size_t attention_size_; + + typedef EntryIterator InnerIterator; + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +typedef util::ProxyIterator PartialIter; + +template class CompareRecords : public std::binary_function { public: explicit CompareRecords(unsigned char order) : order_(order) {} - bool operator()(const EntryProxy &first, const EntryProxy &second) const { + bool operator()(const Proxy &first, const Proxy &second) const { return Compare(first.Indices(), second.Indices()); } - bool operator()(const EntryProxy &first, const std::string &second) const { + bool operator()(const Proxy &first, const std::string &second) const { return Compare(first.Indices(), reinterpret_cast(second.data())); } - bool operator()(const std::string &first, const EntryProxy &second) const { + bool operator()(const std::string &first, const Proxy &second) const { return Compare(reinterpret_cast(first.data()), second.Indices()); } bool operator()(const std::string &first, const std::string &second) const { @@ -144,6 +186,12 @@ class CompareRecords : public std::binary_function(order) << '_' << batch; std::string ret(assembled.str()); - util::scoped_FILE out(fopen(ret.c_str(), "w")); - if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); + util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w")); // Compress entries that being with the same (order-1) words. for (const uint8_t *group_begin = static_cast(mem_begin); group_begin != static_cast(mem_end);) { const uint8_t *group_end; @@ -194,8 +254,7 @@ class SortedFileReader { SortedFileReader() : ended_(false) {} void Init(const std::string &name, unsigned char order) { - file_.reset(fopen(name.c_str(), "r")); - if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read"); + file_.reset(OpenOrThrow(name.c_str(), "r")); header_.resize(order - 1); NextHeader(); } @@ -262,12 +321,13 @@ void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); } -void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) { +void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) { SortedFileReader first, second; - first.Init(first_name, order); - second.Init(second_name, order); - util::scoped_FILE out_file(fopen(out, "w")); - if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write"); + first.Init(first_name.c_str(), order); + RemoveOrThrow(first_name.c_str()); + second.Init(second_name.c_str(), order); + RemoveOrThrow(second_name.c_str()); + util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); while (!first.Ended() && !second.Ended()) { if (first.HeaderVector() < second.HeaderVector()) { CopyFullRecord(first, out_file.get(), weights_size); @@ -316,10 +376,109 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha } } -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { - if (order == 1) return; - ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1); +const char *kContextSuffix = "_contexts"; + +void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { + const size_t context_size = sizeof(WordIndex) * (order - 1); + // Sort just the contexts using the same memory. + PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); + PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); + + // TODO: __gnu_parallel::sort here. + std::sort(context_begin, context_end, CompareRecords(order - 1)); + + std::string name(ngram_file_name + kContextSuffix); + util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); + + // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. + if (context_begin == context_end) return; + PartialIter i(context_begin); + WriteOrThrow(out.get(), i->Indices(), context_size); + const WordIndex *previous = i->Indices(); + ++i; + for (; i != context_end; ++i) { + if (memcmp(previous, i->Indices(), context_size)) { + WriteOrThrow(out.get(), i->Indices(), context_size); + previous = i->Indices(); + } + } +} +class ContextReader { + public: + ContextReader() : length_(0) {} + + ContextReader(const char *name, size_t length) : file_(OpenOrThrow(name, "r")), length_(length), words_(length), valid_(true) { + ++*this; + } + + void Reset(const char *name, size_t length) { + file_.reset(OpenOrThrow(name, "r")); + length_ = length; + words_.resize(length); + valid_ = true; + ++*this; + } + + ContextReader &operator++() { + if (1 != fread(&*words_.begin(), length_, 1, file_.get())) { + if (!feof(file_.get())) + UTIL_THROW(util::ErrnoException, "Short read"); + valid_ = false; + } + return *this; + } + + const WordIndex *operator*() const { return &*words_.begin(); } + + operator bool() const { return valid_; } + + FILE *GetFile() { return file_.get(); } + + private: + util::scoped_FILE file_; + + size_t length_; + + std::vector words_; + + bool valid_; +}; + +void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) { + const size_t context_size = sizeof(WordIndex) * (order - 1); + std::string first_name(first_base + kContextSuffix); + std::string second_name(second_base + kContextSuffix); + ContextReader first(first_name.c_str(), context_size), second(second_name.c_str(), context_size); + RemoveOrThrow(first_name.c_str()); + RemoveOrThrow(second_name.c_str()); + std::string out_name(out_base + kContextSuffix); + util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w")); + while (first && second) { + for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) { + if (f == *first + order) { + // Equal. + WriteOrThrow(out.get(), *first, context_size); + ++first; + ++second; + break; + } + if (*f < *s) { + // First lower + WriteOrThrow(out.get(), *first, context_size); + ++first; + break; + } else if (*f > *s) { + WriteOrThrow(out.get(), *second, context_size); + ++second; + break; + } + } + } + CopyRestOrThrow((first ? first : second).GetFile(), out.get()); +} + +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { ReadNGramHeader(f, order); const size_t count = counts[order - 1]; // Size of weights. Does it include backoff? @@ -341,11 +500,13 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size)); } } - // TODO: __gnu_parallel::sort here. + // Sort full records by full n-gram. EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); - + // TODO: __gnu_parallel::sort here. + std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); + WriteContextFile(begin, out_end, files.back(), entry_size, order); + done += (out_end - begin) / entry_size; } @@ -356,10 +517,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st std::stringstream assembled; assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); - MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order); - if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); + MergeSortedFiles(files[0], files[1], files.back(), weights_size, order); + MergeContextFiles(files[0], files[1], files.back(), order); files.pop_front(); - if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); files.pop_front(); } if (!files.empty()) { @@ -367,6 +527,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st assembled << file_prefix << static_cast(order) << "_merged"; std::string merged_name(assembled.str()); if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); + std::string context_name = files[0] + kContextSuffix; + merged_name += kContextSuffix; + if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); } } @@ -378,26 +541,38 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector &counts, Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get())); } + // Only use as much buffer as we need. + size_t buffer_use = 0; + for (unsigned int order = 2; order < counts.size(); ++order) { + buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); + } + buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); + buffer = std::min(buffer, buffer_use); + util::scoped_memory mem; mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); - ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size()); + + for (unsigned char order = 2; order <= counts.size(); ++order) { + ConvertToSorted(f, vocab, counts, mem, file_prefix, order); + } ReadEnd(f); } bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) { for (; words != words_end; ++words, ++header) { if (*words != *header) { - assert(*words <= *header); + //assert(*words <= *header); return false; } } return true; } +// Counting phrase class JustCount { public: - JustCount(UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) + JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) : counts_(counts), longest_counts_(counts + order - 1) {} void Unigrams(WordIndex begin, WordIndex end) { @@ -408,7 +583,7 @@ class JustCount { ++counts_[mid_idx + 1]; } - void Middle(const unsigned char mid_idx, WordIndex /*key*/, const ProbBackoff &/*weights*/) { + void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) { ++counts_[mid_idx + 1]; } @@ -427,7 +602,8 @@ class JustCount { class WriteEntries { public: - WriteEntries(UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + contexts_(contexts), unigrams_(unigrams), middle_(middle), longest_(longest), @@ -444,7 +620,13 @@ class WriteEntries { middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff); } - void Middle(const unsigned char mid_idx, WordIndex key, const ProbBackoff &weights) { + void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) { + // Order (mid_idx+2). + ContextReader &context = contexts_[mid_idx + 1]; + if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) { + SetExtension(weights.backoff); + ++context; + } middle_[mid_idx].Insert(key, weights.prob, weights.backoff); } @@ -455,6 +637,7 @@ class WriteEntries { void Cleanup() {} private: + ContextReader *contexts_; UnigramValue *const unigrams_; BitPackedMiddle *const middle_; BitPackedLongest &longest_; @@ -463,14 +646,15 @@ class WriteEntries { template class RecursiveInsert { public: - RecursiveInsert(SortedFileReader *inputs, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : - doing_(unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), words_(new WordIndex[order]), order_minus_2_(order - 2) { + RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : + doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { } // Outer unigram loop. void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) { util::ErsatzProgress progress(progress_out, message, unigram_count + 1); for (words_[0] = 0; ; ++words_[0]) { + progress.Set(words_[0]); WordIndex min_continue = unigram_count; for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) { if (other->Ended()) continue; @@ -479,7 +663,6 @@ template class RecursiveInsert { // This will write at unigram_count. This is by design so that the next pointers will make sense. doing_.Unigrams(words_[0], min_continue + 1); if (min_continue == unigram_count) break; - progress += min_continue - words_[0]; words_[0] = min_continue; Middle(0); } @@ -497,7 +680,7 @@ template class RecursiveInsert { SortedFileReader &reader = inputs_[mid_idx]; - if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + mid_idx + 1, reader.Header())) { + if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) { // This order doesn't have a header match, but longer ones might. MiddleAllBlank(mid_idx); return; @@ -509,7 +692,7 @@ template class RecursiveInsert { while (count) { WordIndex min_continue = std::numeric_limits::max(); for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header())) + if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); } while (true) { @@ -521,7 +704,7 @@ template class RecursiveInsert { } ProbBackoff weights; reader.ReadWeights(weights); - doing_.Middle(mid_idx, current, weights); + doing_.Middle(mid_idx, words_, current, weights); --count; if (current == min_continue) { words_[mid_idx + 1] = min_continue; @@ -542,7 +725,7 @@ template class RecursiveInsert { while (true) { WordIndex min_continue = std::numeric_limits::max(); for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header())) + if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); } if (min_continue == std::numeric_limits::max()) return; @@ -554,7 +737,7 @@ template class RecursiveInsert { void Longest() { SortedFileReader &reader = *(inputs_end_ - 1); - if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + order_minus_2_ + 1, reader.Header())) return; + if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return; WordIndex count = reader.ReadCount(); for (WordIndex i = 0; i < count; ++i) { WordIndex word = reader.ReadWord(); @@ -571,7 +754,7 @@ template class RecursiveInsert { SortedFileReader *inputs_; SortedFileReader *inputs_end_; - util::scoped_array words_; + WordIndex words_[kMaxOrder]; const unsigned char order_minus_2_; }; @@ -586,17 +769,21 @@ void SanityCheckCounts(const std::vector &initial, const std::vector &counts, const Config &config, TrieSearch &out, Backing &backing) { SortedFileReader inputs[counts.size() - 1]; + ContextReader contexts[counts.size() - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; assembled << file_prefix << static_cast(i) << "_merged"; inputs[i-2].Init(assembled.str(), i); - unlink(assembled.str().c_str()); + RemoveOrThrow(assembled.str().c_str()); + assembled << kContextSuffix; + contexts[i-2].Reset(assembled.str().c_str(), (i-1) * sizeof(WordIndex)); + RemoveOrThrow(assembled.str().c_str()); } std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(inputs, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } SanityCheckCounts(counts, fixed_counts); @@ -609,21 +796,38 @@ void BuildTrie(const std::string &file_prefix, const std::vector &coun UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert inserter(inputs, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert inserter(inputs, contexts, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } // Fill unigram probabilities. { std::string name(file_prefix + "unigrams"); - util::scoped_FILE file(fopen(name.c_str(), "r")); - if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed"); + util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); for (WordIndex i = 0; i < counts[0]; ++i) { ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); + if (contexts[0] && **contexts[0] == i) { + SetExtension(unigrams[i].weights.backoff); + ++contexts[0]; + } } unlink(name.c_str()); } + // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. + for (unsigned char order = 2; order <= counts.size(); ++order) { + const ContextReader &context = contexts[order - 2]; + if (context) { + FormatLoadException e; + e << "An " << static_cast(order) << "-gram has the context (i.e. all but the last word):"; + for (const WordIndex *i = *context; i != *context + order - 1; ++i) { + e << ' ' << *i; + } + e << " so this context must appear in the model as a " << static_cast(order - 1) << "-gram but it does not."; + throw e; + } + } + /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. if (!out.middle.empty()) { @@ -634,19 +838,27 @@ void BuildTrie(const std::string &file_prefix, const std::vector &coun } } +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} + } // namespace void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; + if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str())) + temporary_directory += '/'; } else if (config.write_mmap) { temporary_directory = config.write_mmap; } else { temporary_directory = file; } // Null on end is kludge to ensure null termination. - temporary_directory += "-tmp-XXXXXX"; + temporary_directory += "_trie_tmp_XXXXXX"; temporary_directory += '\0'; if (!mkdtemp(&temporary_directory[0])) { UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str()); @@ -658,7 +870,7 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v // At least 1MB sorting memory. ARPAToSortedFiles(f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory.c_str(), counts, config, *this, backing); + BuildTrie(temporary_directory, counts, config, *this, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 8c99d797..b584c82f 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -65,7 +65,6 @@ class SortedVocabulary : public base::Vocabulary { } } - // Ignores second argument for consistency with probing hash which has a float here. static size_t Size(std::size_t entries, const Config &config); // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc index 9d4fdf27..681da5f2 100644 --- a/klm/util/bit_packing.cc +++ b/klm/util/bit_packing.cc @@ -22,7 +22,7 @@ uint8_t RequiredBits(uint64_t max_value) { } void BitPackingSanity() { - const detail::FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 }; + const FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 }; if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000"); char mem[57+8]; memset(mem, 0, sizeof(mem)); diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 636547b1..70cfc2d2 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -53,29 +53,32 @@ inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) *reinterpret_cast(base) |= (value << BitPackShift(bit, length)); } -namespace detail { typedef union { float f; uint32_t i; } FloatEnc; } +typedef union { float f; uint32_t i; } FloatEnc; + inline float ReadFloat32(const void *base, uint8_t bit) { - detail::FloatEnc encoded; + FloatEnc encoded; encoded.i = *reinterpret_cast(base) >> BitPackShift(bit, 32); return encoded.f; } inline void WriteFloat32(void *base, uint8_t bit, float value) { - detail::FloatEnc encoded; + FloatEnc encoded; encoded.f = value; WriteInt57(base, bit, 32, encoded.i); } +const uint32_t kSignBit = 0x80000000; + inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { - detail::FloatEnc encoded; + FloatEnc encoded; encoded.i = *reinterpret_cast(base) >> BitPackShift(bit, 31); // Sign bit set means negative. - encoded.i |= 0x80000000; + encoded.i |= kSignBit; return encoded.f; } inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { - detail::FloatEnc encoded; + FloatEnc encoded; encoded.f = value; - encoded.i &= ~0x80000000; + encoded.i &= ~kSignBit; WriteInt57(base, bit, 31, encoded.i); } diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc index 55c182bd..a82ce672 100644 --- a/klm/util/ersatz_progress.cc +++ b/klm/util/ersatz_progress.cc @@ -36,6 +36,7 @@ void ErsatzProgress::Milestone() { if (stone == kWidth) { (*out_) << std::endl; next_ = std::numeric_limits::max(); + out_ = NULL; } else { next_ = std::max(next_, (stone * complete_) / kWidth); } diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 5a667ebb..81eb9bb9 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -37,6 +37,9 @@ GZException::GZException(void *file) { #endif // HAVE_ZLIB } +// Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). +const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; + int OpenReadOrThrow(const char *name) { int ret = open(name, O_RDONLY); if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading"); @@ -107,13 +110,6 @@ unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException, return ReadNumber(); } -void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) { - for (; ; ++position_) { - if (position_ == position_end_) Shift(); - if (!isspace(*position_)) return; - } -} - void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) { #ifdef HAVE_ZLIB gz_file_ = NULL; @@ -190,20 +186,6 @@ template T FilePiece::ReadNumber() throw(GZException, EndOfFileExcepti return ret; } -const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileException) { - for (const char *i = position_; i <= last_space_; ++i) { - if (isspace(*i)) return i; - } - while (!at_end_) { - size_t skip = position_end_ - position_; - Shift(); - for (const char *i = position_ + skip; i <= last_space_; ++i) { - if (isspace(*i)) return i; - } - } - return position_end_; -} - void FilePiece::Shift() throw(GZException, EndOfFileException) { if (at_end_) { progress_.Finished(); diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index b7697e71..f5249fcf 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -36,10 +36,13 @@ class GZException : public Exception { int OpenReadOrThrow(const char *name); +extern const bool kSpaces[256]; + // Return value for SizeFile when it can't size properly. const off_t kBadSize = -1; off_t SizeFile(int fd); +// Memory backing the returned StringPiece may vanish on the next call. class FilePiece { public: // 32 MB default. @@ -57,12 +60,12 @@ class FilePiece { return *(position_++); } - // Memory backing the returned StringPiece may vanish on the next call. - // Leaves the delimiter, if any, to be returned by get(). - StringPiece ReadDelimited() throw(GZException, EndOfFileException) { - SkipSpaces(); - return Consume(FindDelimiterOrEOF()); + // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace(). + StringPiece ReadDelimited(const bool *delim = kSpaces) throw(GZException, EndOfFileException) { + SkipSpaces(delim); + return Consume(FindDelimiterOrEOF(delim)); } + // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. // It is similar to getline in that way. StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException); @@ -72,7 +75,13 @@ class FilePiece { long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException); unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException); - void SkipSpaces() throw (GZException, EndOfFileException); + // Skip spaces defined by isspace. + void SkipSpaces(const bool *delim = kSpaces) throw (GZException, EndOfFileException) { + for (; ; ++position_) { + if (position_ == position_end_) Shift(); + if (!delim[static_cast(*position_)]) return; + } + } off_t Offset() const { return position_ - data_.begin() + mapped_offset_; @@ -91,7 +100,19 @@ class FilePiece { return ret; } - const char *FindDelimiterOrEOF() throw(EndOfFileException, GZException); + const char *FindDelimiterOrEOF(const bool *delim = kSpaces) throw (GZException, EndOfFileException) { + for (const char *i = position_; i < position_end_; ++i) { + if (delim[static_cast(*i)]) return i; + } + while (!at_end_) { + size_t skip = position_end_ - position_; + Shift(); + for (const char *i = position_ + skip; i < position_end_; ++i) { + if (delim[static_cast(*i)]) return i; + } + } + return position_end_; + } void Shift() throw (EndOfFileException, GZException); // Backends to Shift(). diff --git a/klm/util/key_value_packing.hh b/klm/util/key_value_packing.hh index 450512ac..b84a5aad 100644 --- a/klm/util/key_value_packing.hh +++ b/klm/util/key_value_packing.hh @@ -18,6 +18,8 @@ template struct Entry { const Key &GetKey() const { return key; } const Value &GetValue() const { return value; } + Value &MutableValue() { return value; } + void Set(const Key &key_in, const Value &value_in) { SetKey(key_in); SetValue(value_in); @@ -77,6 +79,8 @@ template class ByteAlignedPacking { const Key &GetKey() const { return key; } const Value &GetValue() const { return value; } + Value &MutableValue() { return value; } + void Set(const Key &key_in, const Value &value_in) { SetKey(key_in); SetValue(value_in); diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 7b5cdc22..00be0ed7 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -77,6 +77,16 @@ template bool UnsafeMutableFind(const Key key, MutableIterator &out) { + for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) { + Key got(i->GetKey()); + if (equal_(got, key)) { out = i; return true; } + if (equal_(got, invalid_)) return false; + if (++i == end_) i = begin_; + } + } + template bool Find(const Key key, ConstIterator &out) const { #ifdef DEBUG assert(initialized_); @@ -84,8 +94,8 @@ template GetKey()); if (equal_(got, key)) { out = i; return true; } - if (equal_(got, invalid_)) { return false; } - if (++i == end_) { i = begin_; } + if (equal_(got, invalid_)) return false; + if (++i == end_) i = begin_; } } diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index a8e208fb..05826b51 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -62,6 +62,7 @@ template class SortedUniformMap { public: typedef PackingT Packing; typedef typename Packing::ConstIterator ConstIterator; + typedef typename Packing::MutableIterator MutableIterator; public: // Offer consistent API with probing hash. @@ -113,6 +114,15 @@ template class SortedUniformMap { *size_ptr_ = (end_ - begin_); } + // Don't use this to change the key. + template bool UnsafeMutableFind(const Key key, MutableIterator &out) { +#ifdef DEBUG + assert(initialized_); + assert(loaded_); +#endif + return SortedUniformFind(begin_, end_, key, out); + } + // Do not call before FinishedInserting. template bool Find(const Key key, ConstIterator &out) const { #ifdef DEBUG -- cgit v1.2.3