From 931a036dc3cf9e1deafc10e78e94a0ebe3c8004f Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 25 Jan 2011 22:30:48 +0200 Subject: update kenlm --- klm/lm/binary_format.cc | 4 +- klm/lm/blank.hh | 44 ++++++- klm/lm/build_binary.cc | 2 +- klm/lm/model.cc | 97 ++++++++-------- klm/lm/model.hh | 9 +- klm/lm/model_test.cc | 49 +++++++- klm/lm/ngram_query.cc | 27 +++-- klm/lm/read_arpa.cc | 17 ++- klm/lm/read_arpa.hh | 6 +- klm/lm/search_hashed.cc | 52 ++++++++- klm/lm/search_trie.cc | 302 ++++++++++++++++++++++++++++++++++++++++-------- klm/lm/vocab.hh | 1 - 12 files changed, 481 insertions(+), 129 deletions(-) (limited to 'klm/lm') diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 3d9700da..2a6aff34 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -18,8 +18,8 @@ namespace lm { namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 3\n\0"; -const long int kMagicVersion = 2; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0"; +const long int kMagicVersion = 4; // Test values. struct Sanity { diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 639bc98b..4615a09e 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -1,12 +1,52 @@ #ifndef LM_BLANK__ #define LM_BLANK__ + #include +#include +#include + namespace lm { namespace ngram { -const float kBlankProb = -std::numeric_limits::quiet_NaN(); -const float kBlankBackoff = std::numeric_limits::infinity(); +/* Suppose "foo bar" appears with zero backoff but there is no trigram + * beginning with these words. Then, when scoring "foo bar", the model could + * return out_state containing "bar" or even null context if "bar" also has no + * backoff and is never followed by another word. Then the backoff is set to + * kNoExtensionBackoff. If the n-gram might be extended, then out_state must + * contain the full n-gram, in which case kExtensionBackoff is set. In any + * case, if an n-gram has non-zero backoff, the full state is returned so + * backoff can be properly charged. + * These differ only in sign bit because the backoff is in fact zero in either + * case. + */ +const float kNoExtensionBackoff = -0.0; +const float kExtensionBackoff = 0.0; + +inline void SetExtension(float &backoff) { + if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff; +} + +// This compiles down nicely. +inline bool HasExtension(const float &backoff) { + typedef union { float f; uint32_t i; } UnionValue; + UnionValue compare, interpret; + compare.f = kNoExtensionBackoff; + interpret.f = backoff; + return compare.i != interpret.i; +} + +/* Suppose "foo bar baz quux" appears in the ARPA but not "bar baz quux" or + * "baz quux" (because they were pruned). 1.2% of n-grams generated by SRI + * with default settings on the benchmark data set are like this. Since search + * proceeds by finding "quux", "baz quux", "bar baz quux", and finally + * "foo bar baz quux" and the trie needs pointer nodes anyway, blanks are + * inserted. The blanks have probability kBlankProb and backoff kBlankBackoff. + * A blank is recognized by kBlankProb in the probability field; kBlankBackoff + * must be 0 so that inference asseses zero backoff from these blanks. 
+ */ +const float kBlankProb = -std::numeric_limits::infinity(); +const float kBlankBackoff = kNoExtensionBackoff; } // namespace ngram } // namespace lm diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index b340797b..144c57e0 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -21,7 +21,7 @@ void Usage(const char *name) { "memory and is still faster than SRI or IRST. Building the trie format uses an\n" "on-disk sort to save memory.\n" "-t is the temporary directory prefix. Default is the output file name.\n" -"-m is the amount of memory to use, in MB. Default is 1024MB (1GB).\n\n" +"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n\n" /*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" "It uses more memory than trie and is also slower, so there's no real reason to\n" "use it.\n\n"*/ diff --git a/klm/lm/model.cc b/klm/lm/model.cc index c7ba4908..146fe07b 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -61,10 +61,10 @@ template void GenericModel counts; - // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed with search_.VariableSizeLoad + // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. ReadARPACounts(f, counts); - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); @@ -114,7 +114,24 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const { context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state); - ret.prob += SlowBackoffLookup(context_rbegin, context_rend, ret.ngram_length); + + // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin). + unsigned char start = ret.ngram_length; + if (context_rend - context_rbegin < static_cast(start)) return ret; + if (start <= 1) { + ret.prob += search_.unigram.Lookup(*context_rbegin).backoff; + start = 2; + } + typename Search::Node node; + if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { + return ret; + } + float backoff; + // i is the order of the backoff we're looking for. + for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) { + if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break; + ret.prob += backoff; + } return ret; } @@ -128,8 +145,7 @@ template void GenericModel void GenericModel float GenericModel::SlowBackoffLookup( - const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const { - // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin). 
- if (context_rend - context_rbegin < static_cast(start)) return 0.0; - float ret = 0.0; - if (start == 1) { - ret += search_.unigram.Lookup(*context_rbegin).backoff; - start = 2; - } - typename Search::Node node; - if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { - return 0.0; - } - float backoff; - // i is the order of the backoff we're looking for. - for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) { - if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break; - if (backoff != kBlankBackoff) ret += backoff; - } - return ret; +namespace { +// Do a paraonoid copy of history, assuming new_word has already been copied +// (hence the -1). out_state.valid_length_ could be zero so I avoided using +// std::copy. +void CopyRemainingHistory(const WordIndex *from, State &out_state) { + WordIndex *out = out_state.history_ + 1; + const WordIndex *in_end = from + static_cast(out_state.valid_length_) - 1; + for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in; } +} // namespace /* Ugly optimized function. Produce a score excluding backoff. * The search goes in increasing order of ngram length. @@ -179,28 +180,26 @@ template FullScoreReturn GenericModel::const_iterator mid_iter = search_.middle.begin(); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { if (hist_iter == context_rend) { // Ran out of history. Typically no backoff, but this could be a blank. - out_state.valid_length_ = ret.ngram_length; - std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1); + CopyRemainingHistory(context_rbegin, out_state); // ret.prob was already set. return ret; } @@ -210,32 +209,32 @@ template FullScoreReturn GenericModel(state.valid_length_) << ':'; + for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) { + o << ' ' << *i; + } + return o; +} + namespace { #define StartTest(word, ngram, score) \ @@ -17,7 +26,15 @@ namespace { out);\ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ BOOST_CHECK_EQUAL(static_cast(ngram), ret.ngram_length); \ - BOOST_CHECK_EQUAL(std::min(ngram, 5 - 1), out.valid_length_); + BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.valid_length_); \ + {\ + WordIndex context[state.valid_length_ + 1]; \ + context[0] = model.GetVocabulary().Index(word); \ + std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \ + State get_state; \ + model.GetState(context, context + state.valid_length_ + 1, get_state); \ + BOOST_CHECK_EQUAL(out, get_state); \ + } #define AppendTest(word, ngram, score) \ StartTest(word, ngram, score) \ @@ -52,10 +69,13 @@ template void Continuation(const M &model) { AppendTest("more", 1, -1.20632 - 20.0); AppendTest(".", 2, -0.51363); AppendTest("", 3, -0.0191651); + BOOST_CHECK_EQUAL(0, state.valid_length_); state = preserve; AppendTest("more", 5, -0.00181395); + BOOST_CHECK_EQUAL(4, state.valid_length_); AppendTest("loin", 5, -0.0432557); + BOOST_CHECK_EQUAL(1, state.valid_length_); } template void Blanks(const M &model) { @@ -68,6 +88,7 @@ template void Blanks(const M &model) { State preserve = state; AppendTest("higher", 4, -4); AppendTest("looking", 5, -5); + BOOST_CHECK_EQUAL(1, state.valid_length_); state = preserve; AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103); @@ -94,6 +115,29 @@ template void Unknowns(const M &model) { AppendTest("not_found3", 3, -6); } +template void MinimalState(const M &model) { + FullScoreReturn ret; + State state(model.NullContextState()); + 
State out; + + AppendTest("baz", 1, -6.535897); + BOOST_CHECK_EQUAL(0, state.valid_length_); + state = model.NullContextState(); + AppendTest("foo", 1, -3.141592); + BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("bar", 2, -6.0); + // Has to include the backoff weight. + BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("bar", 1, -2.718281 + 3.0); + BOOST_CHECK_EQUAL(1, state.valid_length_); + + state = model.NullContextState(); + AppendTest("to", 1, -1.687872); + AppendTest("look", 2, -0.2922095); + BOOST_CHECK_EQUAL(2, state.valid_length_); + AppendTest("good", 3, -7); +} + #define StatelessTest(word, provide, ngram, score) \ ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ @@ -154,6 +198,7 @@ template void Everything(const M &m) { Continuation(m); Blanks(m); Unknowns(m); + MinimalState(m); Stateless(m); } @@ -167,7 +212,7 @@ class ExpectEnumerateVocab : public EnumerateVocab { } void Check(const base::Vocabulary &vocab) { - BOOST_CHECK_EQUAL(34ULL, seen.size()); + BOOST_CHECK_EQUAL(37ULL, seen.size()); BOOST_REQUIRE(!seen.empty()); BOOST_CHECK_EQUAL("", seen[0]); for (WordIndex i = 0; i < seen.size(); ++i) { diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 3fa8cb03..d6da02e3 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -6,6 +6,8 @@ #include #include +#include + #include #include @@ -43,35 +45,38 @@ template void Query(const Model &model) { state = model.BeginSentenceState(); float total = 0.0; bool got = false; + unsigned int oov = 0; while (std::cin >> word) { got = true; lm::WordIndex vocab = model.GetVocabulary().Index(word); + if (vocab == 0) ++oov; ret = model.FullScore(state, vocab, out); total += ret.prob; std::cout << word << '=' << vocab << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\n'; state = out; - if (std::cin.get() == '\n') break; + char c; + while (true) { + c = std::cin.get(); + if (!std::cin) break; + if (c == '\n') break; + if (!isspace(c)) { + std::cin.unget(); + break; + } + } + if (c == '\n') break; } if (!got && !std::cin) break; ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; std::cout << "=" << model.GetVocabulary().EndSentence() << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\n'; - std::cout << "Total: " << total << '\n'; + std::cout << "Total: " << total << " OOV: " << oov << '\n'; } PrintUsage("After queries:\n"); } -class PrintVocab : public lm::ngram::EnumerateVocab { - public: - void Add(lm::WordIndex index, const StringPiece &str) { - std::cerr << "vocab " << index << ' ' << str << '\n'; - } -}; - template void Query(const char *name) { lm::ngram::Config config; - PrintVocab printer; - config.enumerate_vocab = &printer; Model model(name, config); Query(model); } diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 262a9c6a..d0fe67f0 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -1,5 +1,7 @@ #include "lm/read_arpa.hh" +#include "lm/blank.hh" + #include #include @@ -8,6 +10,9 @@ namespace lm { +// 1 for '\t', '\n', and ' '. This is stricter than isspace. 
+const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; + namespace { bool IsEntirelyWhiteSpace(const StringPiece &line) { @@ -116,21 +121,27 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) { case '\n': break; default: - UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); + UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff"); } } void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) { + // Always make zero negative. + // Negative zero means that no (n+1)-gram has this n-gram as context. + // Therefore the hypothesis state can be shorter. Of course, many n-grams + // are context for (n+1)-grams. An algorithm in the data structure will go + // back and set the backoff to positive zero in these cases. switch (in.get()) { case '\t': weights.backoff = in.ReadFloat(); + if (weights.backoff == ngram::kExtensionBackoff) weights.backoff = ngram::kNoExtensionBackoff; if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); break; case '\n': - weights.backoff = 0.0; + weights.backoff = ngram::kNoExtensionBackoff; break; default: - UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); + UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff"); } } diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index 571fcbc5..4efdd29d 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -23,12 +23,14 @@ void ReadBackoff(util::FilePiece &in, ProbBackoff &weights); void ReadEnd(util::FilePiece &in); void ReadEnd(std::istream &in); +extern const bool kARPASpaces[256]; + template void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams) { try { float prob = f.ReadFloat(); if (prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << prob); if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); - ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())]; + ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))]; value.prob = prob; ReadBackoff(f, value); } catch(util::Exception &e) { @@ -50,7 +52,7 @@ template void ReadNGram(util::FilePiece &f, const uns weights.prob = f.ReadFloat(); if (weights.prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << weights.prob); for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) { - *vocab_out = vocab.Index(f.ReadDelimited()); + *vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces)); } ReadBackoff(f, weights); } catch(util::Exception &e) { diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 9200aeb6..00d03f4e 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -14,7 +14,41 @@ namespace ngram { namespace { -template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Store &store) { +/* These are passed to ReadNGrams so that n-grams with zero backoff that appear as context will still be used in state. 
*/ +template class ActivateLowerMiddle { + public: + explicit ActivateLowerMiddle(Middle &middle) : modify_(middle) {} + + void operator()(const WordIndex *vocab_ids, const unsigned int n) { + uint64_t hash = static_cast(vocab_ids[1]); + for (const WordIndex *i = vocab_ids + 2; i < vocab_ids + n; ++i) { + hash = detail::CombineWordHash(hash, *i); + } + typename Middle::MutableIterator i; + // TODO: somehow get text of n-gram for this error message. + if (!modify_.UnsafeMutableFind(hash, i)) + UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram"); + SetExtension(i->MutableValue().backoff); + } + + private: + Middle &modify_; +}; + +class ActivateUnigram { + public: + explicit ActivateUnigram(ProbBackoff *unigram) : modify_(unigram) {} + + void operator()(const WordIndex *vocab_ids, const unsigned int /*n*/) { + // assert(n == 2); + SetExtension(modify_[vocab_ids[1]].backoff); + } + + private: + ProbBackoff *modify_; +}; + +template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Activate activate, Store &store) { ReadNGramHeader(f, n); ProbBackoff blank; @@ -38,6 +72,7 @@ template void ReadNGrams(util::FilePiece if (middle[lower].Find(keys[lower], found)) break; middle[lower].Insert(Middle::Packing::Make(keys[lower], blank)); } + activate(vocab_ids, n); } store.FinishedInserting(); @@ -53,12 +88,19 @@ template template void TemplateHashe Read1Grams(f, counts[0], vocab, unigram.Raw()); try { - for (unsigned int n = 2; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle, middle[n-2]); + if (counts.size() > 2) { + ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0]); + } + for (unsigned int n = 3; n < counts.size(); ++n) { + ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle(middle[n-3]), middle[n-2]); + } + if (counts.size() > 2) { + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest); + } else { + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle(middle.back()), longest); } - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, longest); } catch (util::ProbingSizeException &e) { - UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces. "); + UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. 
Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); } } diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 3aeeeca3..1060ddef 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -3,6 +3,7 @@ #include "lm/blank.hh" #include "lm/lm_exception.hh" +#include "lm/max_order.hh" #include "lm/read_arpa.hh" #include "lm/trie.hh" #include "lm/vocab.hh" @@ -27,6 +28,7 @@ #include #include #include +#include namespace lm { namespace ngram { @@ -98,7 +100,7 @@ class EntryProxy { } const WordIndex *Indices() const { - return static_cast(inner_.Data()); + return reinterpret_cast(inner_.Data()); } private: @@ -114,17 +116,57 @@ class EntryProxy { typedef util::ProxyIterator NGramIter; -class CompareRecords : public std::binary_function { +// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. +class PartialViewProxy { + public: + PartialViewProxy() : attention_size_(0), inner_() {} + + PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} + + operator std::string() const { + return std::string(reinterpret_cast(inner_.Data()), attention_size_); + } + + PartialViewProxy &operator=(const PartialViewProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), attention_size_); + return *this; + } + + PartialViewProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), attention_size_); + return *this; + } + + const WordIndex *Indices() const { + return reinterpret_cast(inner_.Data()); + } + + private: + friend class util::ProxyIterator; + + typedef std::string value_type; + + const std::size_t attention_size_; + + typedef EntryIterator InnerIterator; + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +typedef util::ProxyIterator PartialIter; + +template class CompareRecords : public std::binary_function { public: explicit CompareRecords(unsigned char order) : order_(order) {} - bool operator()(const EntryProxy &first, const EntryProxy &second) const { + bool operator()(const Proxy &first, const Proxy &second) const { return Compare(first.Indices(), second.Indices()); } - bool operator()(const EntryProxy &first, const std::string &second) const { + bool operator()(const Proxy &first, const std::string &second) const { return Compare(first.Indices(), reinterpret_cast(second.data())); } - bool operator()(const std::string &first, const EntryProxy &second) const { + bool operator()(const std::string &first, const Proxy &second) const { return Compare(reinterpret_cast(first.data()), second.Indices()); } bool operator()(const std::string &first, const std::string &second) const { @@ -144,6 +186,12 @@ class CompareRecords : public std::binary_function(order) << '_' << batch; std::string ret(assembled.str()); - util::scoped_FILE out(fopen(ret.c_str(), "w")); - if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); + util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w")); // Compress entries that being with the same (order-1) words. 
for (const uint8_t *group_begin = static_cast(mem_begin); group_begin != static_cast(mem_end);) { const uint8_t *group_end; @@ -194,8 +254,7 @@ class SortedFileReader { SortedFileReader() : ended_(false) {} void Init(const std::string &name, unsigned char order) { - file_.reset(fopen(name.c_str(), "r")); - if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read"); + file_.reset(OpenOrThrow(name.c_str(), "r")); header_.resize(order - 1); NextHeader(); } @@ -262,12 +321,13 @@ void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); } -void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) { +void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) { SortedFileReader first, second; - first.Init(first_name, order); - second.Init(second_name, order); - util::scoped_FILE out_file(fopen(out, "w")); - if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write"); + first.Init(first_name.c_str(), order); + RemoveOrThrow(first_name.c_str()); + second.Init(second_name.c_str(), order); + RemoveOrThrow(second_name.c_str()); + util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); while (!first.Ended() && !second.Ended()) { if (first.HeaderVector() < second.HeaderVector()) { CopyFullRecord(first, out_file.get(), weights_size); @@ -316,10 +376,109 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha } } -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { - if (order == 1) return; - ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1); +const char *kContextSuffix = "_contexts"; + +void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { + const size_t context_size = sizeof(WordIndex) * (order - 1); + // Sort just the contexts using the same memory. + PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); + PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); + + // TODO: __gnu_parallel::sort here. + std::sort(context_begin, context_end, CompareRecords(order - 1)); + + std::string name(ngram_file_name + kContextSuffix); + util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); + + // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. 
+ if (context_begin == context_end) return; + PartialIter i(context_begin); + WriteOrThrow(out.get(), i->Indices(), context_size); + const WordIndex *previous = i->Indices(); + ++i; + for (; i != context_end; ++i) { + if (memcmp(previous, i->Indices(), context_size)) { + WriteOrThrow(out.get(), i->Indices(), context_size); + previous = i->Indices(); + } + } +} +class ContextReader { + public: + ContextReader() : length_(0) {} + + ContextReader(const char *name, size_t length) : file_(OpenOrThrow(name, "r")), length_(length), words_(length), valid_(true) { + ++*this; + } + + void Reset(const char *name, size_t length) { + file_.reset(OpenOrThrow(name, "r")); + length_ = length; + words_.resize(length); + valid_ = true; + ++*this; + } + + ContextReader &operator++() { + if (1 != fread(&*words_.begin(), length_, 1, file_.get())) { + if (!feof(file_.get())) + UTIL_THROW(util::ErrnoException, "Short read"); + valid_ = false; + } + return *this; + } + + const WordIndex *operator*() const { return &*words_.begin(); } + + operator bool() const { return valid_; } + + FILE *GetFile() { return file_.get(); } + + private: + util::scoped_FILE file_; + + size_t length_; + + std::vector words_; + + bool valid_; +}; + +void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) { + const size_t context_size = sizeof(WordIndex) * (order - 1); + std::string first_name(first_base + kContextSuffix); + std::string second_name(second_base + kContextSuffix); + ContextReader first(first_name.c_str(), context_size), second(second_name.c_str(), context_size); + RemoveOrThrow(first_name.c_str()); + RemoveOrThrow(second_name.c_str()); + std::string out_name(out_base + kContextSuffix); + util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w")); + while (first && second) { + for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) { + if (f == *first + order) { + // Equal. + WriteOrThrow(out.get(), *first, context_size); + ++first; + ++second; + break; + } + if (*f < *s) { + // First lower + WriteOrThrow(out.get(), *first, context_size); + ++first; + break; + } else if (*f > *s) { + WriteOrThrow(out.get(), *second, context_size); + ++second; + break; + } + } + } + CopyRestOrThrow((first ? first : second).GetFile(), out.get()); +} + +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { ReadNGramHeader(f, order); const size_t count = counts[order - 1]; // Size of weights. Does it include backoff? @@ -341,11 +500,13 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size)); } } - // TODO: __gnu_parallel::sort here. + // Sort full records by full n-gram. EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); - + // TODO: __gnu_parallel::sort here. 
+ std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); + WriteContextFile(begin, out_end, files.back(), entry_size, order); + done += (out_end - begin) / entry_size; } @@ -356,10 +517,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st std::stringstream assembled; assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); - MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order); - if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); + MergeSortedFiles(files[0], files[1], files.back(), weights_size, order); + MergeContextFiles(files[0], files[1], files.back(), order); files.pop_front(); - if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); files.pop_front(); } if (!files.empty()) { @@ -367,6 +527,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st assembled << file_prefix << static_cast(order) << "_merged"; std::string merged_name(assembled.str()); if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); + std::string context_name = files[0] + kContextSuffix; + merged_name += kContextSuffix; + if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); } } @@ -378,26 +541,38 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector &counts, Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get())); } + // Only use as much buffer as we need. 
+ size_t buffer_use = 0; + for (unsigned int order = 2; order < counts.size(); ++order) { + buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); + } + buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); + buffer = std::min(buffer, buffer_use); + util::scoped_memory mem; mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); - ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size()); + + for (unsigned char order = 2; order <= counts.size(); ++order) { + ConvertToSorted(f, vocab, counts, mem, file_prefix, order); + } ReadEnd(f); } bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) { for (; words != words_end; ++words, ++header) { if (*words != *header) { - assert(*words <= *header); + //assert(*words <= *header); return false; } } return true; } +// Counting phrase class JustCount { public: - JustCount(UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) + JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) : counts_(counts), longest_counts_(counts + order - 1) {} void Unigrams(WordIndex begin, WordIndex end) { @@ -408,7 +583,7 @@ class JustCount { ++counts_[mid_idx + 1]; } - void Middle(const unsigned char mid_idx, WordIndex /*key*/, const ProbBackoff &/*weights*/) { + void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) { ++counts_[mid_idx + 1]; } @@ -427,7 +602,8 @@ class JustCount { class WriteEntries { public: - WriteEntries(UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + contexts_(contexts), unigrams_(unigrams), middle_(middle), longest_(longest), @@ -444,7 +620,13 @@ class WriteEntries { middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff); } - void Middle(const unsigned char mid_idx, WordIndex key, const ProbBackoff &weights) { + void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) { + // Order (mid_idx+2). 
+ ContextReader &context = contexts_[mid_idx + 1]; + if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) { + SetExtension(weights.backoff); + ++context; + } middle_[mid_idx].Insert(key, weights.prob, weights.backoff); } @@ -455,6 +637,7 @@ class WriteEntries { void Cleanup() {} private: + ContextReader *contexts_; UnigramValue *const unigrams_; BitPackedMiddle *const middle_; BitPackedLongest &longest_; @@ -463,14 +646,15 @@ class WriteEntries { template class RecursiveInsert { public: - RecursiveInsert(SortedFileReader *inputs, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : - doing_(unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), words_(new WordIndex[order]), order_minus_2_(order - 2) { + RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : + doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { } // Outer unigram loop. void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) { util::ErsatzProgress progress(progress_out, message, unigram_count + 1); for (words_[0] = 0; ; ++words_[0]) { + progress.Set(words_[0]); WordIndex min_continue = unigram_count; for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) { if (other->Ended()) continue; @@ -479,7 +663,6 @@ template class RecursiveInsert { // This will write at unigram_count. This is by design so that the next pointers will make sense. doing_.Unigrams(words_[0], min_continue + 1); if (min_continue == unigram_count) break; - progress += min_continue - words_[0]; words_[0] = min_continue; Middle(0); } @@ -497,7 +680,7 @@ template class RecursiveInsert { SortedFileReader &reader = inputs_[mid_idx]; - if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + mid_idx + 1, reader.Header())) { + if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) { // This order doesn't have a header match, but longer ones might. 
MiddleAllBlank(mid_idx); return; @@ -509,7 +692,7 @@ template class RecursiveInsert { while (count) { WordIndex min_continue = std::numeric_limits::max(); for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header())) + if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); } while (true) { @@ -521,7 +704,7 @@ template class RecursiveInsert { } ProbBackoff weights; reader.ReadWeights(weights); - doing_.Middle(mid_idx, current, weights); + doing_.Middle(mid_idx, words_, current, weights); --count; if (current == min_continue) { words_[mid_idx + 1] = min_continue; @@ -542,7 +725,7 @@ template class RecursiveInsert { while (true) { WordIndex min_continue = std::numeric_limits::max(); for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header())) + if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); } if (min_continue == std::numeric_limits::max()) return; @@ -554,7 +737,7 @@ template class RecursiveInsert { void Longest() { SortedFileReader &reader = *(inputs_end_ - 1); - if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + order_minus_2_ + 1, reader.Header())) return; + if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return; WordIndex count = reader.ReadCount(); for (WordIndex i = 0; i < count; ++i) { WordIndex word = reader.ReadWord(); @@ -571,7 +754,7 @@ template class RecursiveInsert { SortedFileReader *inputs_; SortedFileReader *inputs_end_; - util::scoped_array words_; + WordIndex words_[kMaxOrder]; const unsigned char order_minus_2_; }; @@ -586,17 +769,21 @@ void SanityCheckCounts(const std::vector &initial, const std::vector &counts, const Config &config, TrieSearch &out, Backing &backing) { SortedFileReader inputs[counts.size() - 1]; + ContextReader contexts[counts.size() - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; assembled << file_prefix << static_cast(i) << "_merged"; inputs[i-2].Init(assembled.str(), i); - unlink(assembled.str().c_str()); + RemoveOrThrow(assembled.str().c_str()); + assembled << kContextSuffix; + contexts[i-2].Reset(assembled.str().c_str(), (i-1) * sizeof(WordIndex)); + RemoveOrThrow(assembled.str().c_str()); } std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(inputs, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } SanityCheckCounts(counts, fixed_counts); @@ -609,21 +796,38 @@ void BuildTrie(const std::string &file_prefix, const std::vector &coun UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert inserter(inputs, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert inserter(inputs, contexts, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } // Fill unigram probabilities. 
{ std::string name(file_prefix + "unigrams"); - util::scoped_FILE file(fopen(name.c_str(), "r")); - if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed"); + util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); for (WordIndex i = 0; i < counts[0]; ++i) { ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); + if (contexts[0] && **contexts[0] == i) { + SetExtension(unigrams[i].weights.backoff); + ++contexts[0]; + } } unlink(name.c_str()); } + // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. + for (unsigned char order = 2; order <= counts.size(); ++order) { + const ContextReader &context = contexts[order - 2]; + if (context) { + FormatLoadException e; + e << "An " << static_cast(order) << "-gram has the context (i.e. all but the last word):"; + for (const WordIndex *i = *context; i != *context + order - 1; ++i) { + e << ' ' << *i; + } + e << " so this context must appear in the model as a " << static_cast(order - 1) << "-gram but it does not."; + throw e; + } + } + /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. if (!out.middle.empty()) { @@ -634,19 +838,27 @@ void BuildTrie(const std::string &file_prefix, const std::vector &coun } } +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} + } // namespace void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; + if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str())) + temporary_directory += '/'; } else if (config.write_mmap) { temporary_directory = config.write_mmap; } else { temporary_directory = file; } // Null on end is kludge to ensure null termination. - temporary_directory += "-tmp-XXXXXX"; + temporary_directory += "_trie_tmp_XXXXXX"; temporary_directory += '\0'; if (!mkdtemp(&temporary_directory[0])) { UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str()); @@ -658,7 +870,7 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v // At least 1MB sorting memory. ARPAToSortedFiles(f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory.c_str(), counts, config, *this, backing); + BuildTrie(temporary_directory, counts, config, *this, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 8c99d797..b584c82f 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -65,7 +65,6 @@ class SortedVocabulary : public base::Vocabulary { } } - // Ignores second argument for consistency with probing hash which has a float here. static size_t Size(std::size_t entries, const Config &config); // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. -- cgit v1.2.3
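Editor's illustrative note (not part of the patch): the core trick this update introduces in blank.hh is encoding "does any longer n-gram extend this context?" in the sign bit of an otherwise-zero backoff, so -0.0 means "no extension" and +0.0 means "extension exists". The standalone sketch below re-declares the two constants and helpers locally rather than including the kenlm headers, and uses memcpy instead of the patch's union for the bit comparison (a portability choice, not what the patch does); it only demonstrates why the two zeros compare equal arithmetically yet remain distinguishable.

// illustrate_sign_bit_backoff.cc -- standalone sketch, assumes nothing from kenlm.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>

const float kNoExtensionBackoff = -0.0f;  // no (n+1)-gram uses this n-gram as context
const float kExtensionBackoff = 0.0f;     // at least one (n+1)-gram extends it

// Flip -0.0 to +0.0 when an extension is discovered (mirrors SetExtension in blank.hh).
inline void SetExtension(float &backoff) {
  if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff;
}

// Compare bit patterns, because -0.0f == 0.0f under ordinary float comparison.
inline bool HasExtension(const float &backoff) {
  uint32_t compare, interpret;
  std::memcpy(&compare, &kNoExtensionBackoff, sizeof(compare));
  std::memcpy(&interpret, &backoff, sizeof(interpret));
  return compare != interpret;
}

int main() {
  float backoff = kNoExtensionBackoff;
  assert(backoff == 0.0f);         // arithmetic value is still zero, so scores are unchanged
  assert(!HasExtension(backoff));  // but the sign bit says "no extension": state can be shortened
  SetExtension(backoff);
  assert(HasExtension(backoff));
  std::puts("sign-bit backoff convention behaves as described in blank.hh");
  return 0;
}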