From f111672dd611f78656fceb3df3729a290453ef56 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 21 Sep 2011 18:23:50 -0400 Subject: Updated kenlm. Includes left state support but not the cdec-side use of it. Updated binary format. --- klm/lm/bhiksha.hh | 2 +- klm/lm/binary_format.cc | 16 +- klm/lm/binary_format.hh | 20 +- klm/lm/blank.hh | 14 - klm/lm/left.hh | 181 ++++++++ klm/lm/left_test.cc | 360 +++++++++++++++ klm/lm/model.cc | 132 ++++-- klm/lm/model.hh | 47 +- klm/lm/model_test.cc | 184 +++++--- klm/lm/model_type.hh | 16 + klm/lm/quantize.cc | 4 +- klm/lm/quantize.hh | 13 +- klm/lm/return.hh | 39 ++ klm/lm/search_hashed.cc | 79 +++- klm/lm/search_hashed.hh | 43 +- klm/lm/search_trie.cc | 1052 ++++++++++++++----------------------------- klm/lm/search_trie.hh | 37 +- klm/lm/trie.cc | 5 +- klm/lm/trie.hh | 10 +- klm/lm/trie_sort.cc | 261 +++++++++++ klm/lm/trie_sort.hh | 94 ++++ klm/lm/virtual_interface.hh | 26 +- klm/lm/vocab.cc | 44 +- klm/lm/vocab.hh | 10 +- 24 files changed, 1750 insertions(+), 939 deletions(-) create mode 100644 klm/lm/left.hh create mode 100644 klm/lm/left_test.cc create mode 100644 klm/lm/model_type.hh create mode 100644 klm/lm/return.hh create mode 100644 klm/lm/trie_sort.cc create mode 100644 klm/lm/trie_sort.hh (limited to 'klm/lm') diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index cfb2b053..ff7fe452 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -12,7 +12,7 @@ #include -#include "lm/binary_format.hh" +#include "lm/model_type.hh" #include "lm/trie.hh" #include "util/bit_packing.hh" #include "util/sorted_uniform.hh" diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index e02e621a..27cada13 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -19,10 +19,10 @@ namespace lm { namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0"; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; // This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; -const long int kMagicVersion = 4; +const long int kMagicVersion = 5; // Test values. struct Sanity { @@ -42,12 +42,6 @@ struct Sanity { const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; -std::size_t Align8(std::size_t in) { - std::size_t off = in % 8; - if (!off) return in; - return in + 8 - off; -} - std::size_t TotalHeaderSize(unsigned char order) { return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } @@ -119,7 +113,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t } } -void FinishFile(const Config &config, ModelType model_type, const std::vector &counts, Backing &backing) { +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, Backing &backing) { if (config.write_mmap) { if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); @@ -130,6 +124,7 @@ void FinishFile(const Config &config, ModelType model_type, const std::vector(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast(params.fixed.model_type) << " but this is not implemented for in this inference code."); UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]); } + UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version); } void SeekPastHeader(int fd, const Parameters ¶ms) { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index d28cb6c5..e9df0892 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -2,6 +2,7 @@ #define LM_BINARY_FORMAT__ #include "lm/config.hh" +#include "lm/model_type.hh" #include "lm/read_arpa.hh" #include "util/file_piece.hh" @@ -16,13 +17,6 @@ namespace lm { namespace ngram { -/* Not the best numbering system, but it grew this way for historical reasons - * and I want to preserve existing binary files. */ -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; - -const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); -const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); - /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in * this header designed for use by decoder authors. @@ -36,8 +30,14 @@ struct FixedWidthParameters { ModelType model_type; // Does the end of the file have the actual strings in the vocabulary? bool has_vocabulary; + unsigned int search_version; }; +inline std::size_t Align8(std::size_t in) { + std::size_t off = in % 8; + return off ? (in + 8 - off) : in; +} + // Parameters stored in the header of a binary file. struct Parameters { FixedWidthParameters fixed; @@ -64,7 +64,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t // Write header to binary file. This is done last to prevent incomplete files // from loading. -void FinishFile(const Config &config, ModelType model_type, const std::vector &counts, Backing &backing); +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, Backing &backing); namespace detail { @@ -72,7 +72,7 @@ bool IsBinaryFormat(int fd); void ReadHeader(int fd, Parameters ¶ms); -void MatchCheck(ModelType model_type, const Parameters ¶ms); +void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms); void SeekPastHeader(int fd, const Parameters ¶ms); @@ -90,7 +90,7 @@ template void LoadLM(const char *file, const Config &config, To &to) if (detail::IsBinaryFormat(backing.file.get())) { Parameters params; detail::ReadHeader(backing.file.get(), params); - detail::MatchCheck(To::kModelType, params); + detail::MatchCheck(To::kModelType, To::kVersion, params); // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 162411a9..2fb64cd0 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -38,20 +38,6 @@ inline bool HasExtension(const float &backoff) { return compare.i != interpret.i; } -/* Suppose "foo bar baz quux" appears in the ARPA but not "bar baz quux" or - * "baz quux" (because they were pruned). 1.2% of n-grams generated by SRI - * with default settings on the benchmark data set are like this. Since search - * proceeds by finding "quux", "baz quux", "bar baz quux", and finally - * "foo bar baz quux" and the trie needs pointer nodes anyway, blanks are - * inserted. The blanks have probability kBlankProb and backoff kBlankBackoff. - * A blank is recognized by kBlankProb in the probability field; kBlankBackoff - * must be 0 so that inference asseses zero backoff from these blanks. - */ -const float kBlankProb = -std::numeric_limits::infinity(); -const float kBlankBackoff = kNoExtensionBackoff; -const uint32_t kBlankProbQuant = 0; -const uint32_t kBlankBackoffQuant = 0; - } // namespace ngram } // namespace lm #endif // LM_BLANK__ diff --git a/klm/lm/left.hh b/klm/lm/left.hh new file mode 100644 index 00000000..df69e97a --- /dev/null +++ b/klm/lm/left.hh @@ -0,0 +1,181 @@ +#ifndef LM_LEFT__ +#define LM_LEFT__ + +#include "lm/max_order.hh" +#include "lm/model.hh" +#include "lm/return.hh" + +#include + +namespace lm { +namespace ngram { + +struct Left { + bool operator==(const Left &other) const { + return + (length == other.length) && + pointers[length - 1] == other.pointers[length - 1]; + } + + int Compare(const Left &other) const { + if (length != other.length) { + return (int)length - (int)other.length; + } + if (pointers[length - 1] > other.pointers[length - 1]) return 1; + if (pointers[length - 1] < other.pointers[length - 1]) return -1; + return 0; + } + + uint64_t pointers[kMaxOrder - 1]; + unsigned char length; +}; + +struct ChartState { + bool operator==(const ChartState &other) { + return (left == other.left) && (right == other.right) && (full == other.full); + } + + int Compare(const ChartState &other) const { + int lres = left.Compare(other.left); + if (lres) return lres; + int rres = right.Compare(other.right); + if (rres) return rres; + return (int)full - (int)other.full; + } + + Left left; + State right; + bool full; +}; + +template class RuleScore { + public: + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + out.left.length = 0; + out.right.length = 0; + } + + void BeginSentence() { + out_.right = model_.BeginSentenceState(); + // out_.left is empty. + left_done_ = true; + } + + void Terminal(WordIndex word) { + State copy(out_.right); + FullScoreReturn ret = model_.FullScore(copy, word, out_.right); + ProcessRet(ret); + if (out_.right.length != copy.length + 1) left_done_ = true; + } + + // Faster version of NonTerminal for the case where the rule begins with a non-terminal. + void BeginNonTerminal(const ChartState &in, float prob) { + prob_ = prob; + out_ = in; + left_write_ = out_.left.pointers + out_.left.length; + left_done_ = in.full; + } + + void NonTerminal(const ChartState &in, float prob) { + prob_ += prob; + + if (!in.left.length) { + if (in.full) { + for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + } + return; + } + + if (!out_.right.length) { + out_.right = in.right; + if (left_done_) return; + if (left_write_ != out_.left.pointers) { + left_done_ = true; + } else { + out_.left = in.left; + left_write_ = out_.left.pointers + in.left.length; + left_done_ = in.full; + } + return; + } + + float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; + float *back = backoffs, *back2 = backoffs2; + unsigned char next_use; + FullScoreReturn ret; + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + unsigned char extend_length = 2; + for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + std::swap(back, back2); + } + + if (in.full) { + for (const float *i = back; i != back + next_use; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + return; + } + + // Right state was minimized, so it's already independent of the new words to the left. + if (in.right.length < in.left.length) { + out_.right = in.right; + return; + } + + // Shift exisiting words down. + for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + *(i + in.right.length) = *i; + } + // Add words from in.right. + std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); + std::copy(back, back + next_use, out_.right.backoff + in.right.length); + out_.right.length = in.right.length + next_use; + } + + float Finish() { + out_.left.length = left_write_ - out_.left.pointers; + out_.full = left_done_; + return prob_; + } + + private: + void ProcessRet(const FullScoreReturn &ret) { + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + *(left_write_++) = ret.extend_left; + } + + const M &model_; + + ChartState &out_; + + bool left_done_; + + uint64_t *left_write_; + + float prob_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_LEFT__ diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc new file mode 100644 index 00000000..8bb91cb3 --- /dev/null +++ b/klm/lm/left_test.cc @@ -0,0 +1,360 @@ +#include "lm/left.hh" +#include "lm/model.hh" + +#include "util/tokenize_piece.hh" + +#include + +#define BOOST_TEST_MODULE LeftTest +#include +#include + +namespace lm { +namespace ngram { +namespace { + +#define Term(word) score.Terminal(m.GetVocabulary().Index(word)); +#define VCheck(word, value) BOOST_CHECK_EQUAL(m.GetVocabulary().Index(word), value); + +template void Short(const M &m) { + ChartState base; + { + RuleScore score(m, base); + Term("more"); + Term("loin"); + BOOST_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001); + } + BOOST_CHECK(base.full); + BOOST_CHECK_EQUAL(2, base.left.length); + BOOST_CHECK_EQUAL(1, base.right.length); + VCheck("loin", base.right.words[0]); + + ChartState more_left; + { + RuleScore score(m, more_left); + Term("little"); + score.NonTerminal(base, -1.206319 - 0.3561665); + // p(little more loin | null context) + BOOST_CHECK_CLOSE(-1.56538, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(3, more_left.left.length); + BOOST_CHECK_EQUAL(1, more_left.right.length); + VCheck("loin", more_left.right.words[0]); + BOOST_CHECK(more_left.full); + + ChartState shorter; + { + RuleScore score(m, shorter); + Term("to"); + score.NonTerminal(base, -1.206319 - 0.3561665); + BOOST_CHECK_CLOSE(-0.30103 - 1.687872 - 1.206319 - 0.3561665, score.Finish(), 0.01); + } + BOOST_CHECK_EQUAL(1, shorter.left.length); + BOOST_CHECK_EQUAL(1, shorter.right.length); + VCheck("loin", shorter.right.words[0]); + BOOST_CHECK(shorter.full); +} + +template void Charge(const M &m) { + ChartState base; + { + RuleScore score(m, base); + Term("on"); + Term("more"); + BOOST_CHECK_CLOSE(-1.509559 -0.4771212 -1.206319, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(1, base.left.length); + BOOST_CHECK_EQUAL(1, base.right.length); + VCheck("more", base.right.words[0]); + BOOST_CHECK(base.full); + + ChartState extend; + { + RuleScore score(m, extend); + Term("looking"); + score.NonTerminal(base, -1.509559 -0.4771212 -1.206319); + BOOST_CHECK_CLOSE(-3.91039, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, extend.left.length); + BOOST_CHECK_EQUAL(1, extend.right.length); + VCheck("more", extend.right.words[0]); + BOOST_CHECK(extend.full); + + ChartState tobos; + { + RuleScore score(m, tobos); + score.BeginSentence(); + score.NonTerminal(extend, -3.91039); + BOOST_CHECK_CLOSE(-3.471169, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(0, tobos.left.length); + BOOST_CHECK_EQUAL(1, tobos.right.length); +} + +template float LeftToRight(const M &m, const std::vector &words) { + float ret = 0.0; + State right = m.NullContextState(); + for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { + State copy(right); + ret += m.Score(copy, *i, right); + } + return ret; +} + +template float RightToLeft(const M &m, const std::vector &words) { + float ret = 0.0; + ChartState state; + state.left.length = 0; + state.right.length = 0; + state.full = false; + for (std::vector::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) { + ChartState copy(state); + RuleScore score(m, state); + score.Terminal(*i); + score.NonTerminal(copy, ret); + ret = score.Finish(); + } + return ret; +} + +template float TreeMiddle(const M &m, const std::vector &words) { + std::vector > states(words.size()); + for (unsigned int i = 0; i < words.size(); ++i) { + RuleScore score(m, states[i].first); + score.Terminal(words[i]); + states[i].second = score.Finish(); + } + while (states.size() > 1) { + std::vector > upper((states.size() + 1) / 2); + for (unsigned int i = 0; i < states.size() / 2; ++i) { + RuleScore score(m, upper[i].first); + score.NonTerminal(states[i*2].first, states[i*2].second); + score.NonTerminal(states[i*2+1].first, states[i*2+1].second); + upper[i].second = score.Finish(); + } + if (states.size() % 2) { + upper.back() = states.back(); + } + std::swap(states, upper); + } + return states.empty() ? 0 : states.back().second; +} + +template void LookupVocab(const M &m, const StringPiece &str, std::vector &out) { + out.clear(); + for (util::PieceIterator<' '> i(str); i; ++i) { + out.push_back(m.GetVocabulary().Index(*i)); + } +} + +#define TEXT_TEST(str) \ +{ \ + std::vector words; \ + LookupVocab(m, str, words); \ + float expect = LeftToRight(m, words); \ + BOOST_CHECK_CLOSE(expect, RightToLeft(m, words), 0.001); \ + BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words), 0.001); \ +} + +// Build sentences, or parts thereof, from right to left. +template void GrowBig(const M &m) { + TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown "); + TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown "); + TEXT_TEST("on a little more loin also would consider higher to look good"); + TEXT_TEST("more loin also would consider higher to look good"); + TEXT_TEST("more loin also would consider higher to look"); + TEXT_TEST("also would consider higher to look"); + TEXT_TEST("also would consider higher"); + TEXT_TEST("would consider higher to look"); + TEXT_TEST("consider higher to look"); + TEXT_TEST("consider higher to"); + TEXT_TEST("consider higher"); +} + +template void AlsoWouldConsiderHigher(const M &m) { + ChartState also; + { + RuleScore score(m, also); + score.Terminal(m.GetVocabulary().Index("also")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + ChartState would; + { + RuleScore score(m, would); + score.Terminal(m.GetVocabulary().Index("would")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + ChartState combine_also_would; + { + RuleScore score(m, combine_also_would); + score.NonTerminal(also, -1.687872); + score.NonTerminal(would, -1.687872); + BOOST_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, combine_also_would.right.length); + + ChartState also_would; + { + RuleScore score(m, also_would); + score.Terminal(m.GetVocabulary().Index("also")); + score.Terminal(m.GetVocabulary().Index("would")); + BOOST_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, also_would.right.length); + + ChartState consider; + { + RuleScore score(m, consider); + score.Terminal(m.GetVocabulary().Index("consider")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(1, consider.left.length); + BOOST_CHECK_EQUAL(1, consider.right.length); + BOOST_CHECK(!consider.full); + + ChartState higher; + float higher_score; + { + RuleScore score(m, higher); + score.Terminal(m.GetVocabulary().Index("higher")); + higher_score = score.Finish(); + } + BOOST_CHECK_CLOSE(-1.509559, higher_score, 0.001); + BOOST_CHECK_EQUAL(1, higher.left.length); + BOOST_CHECK_EQUAL(1, higher.right.length); + BOOST_CHECK(!higher.full); + VCheck("higher", higher.right.words[0]); + BOOST_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001); + + ChartState consider_higher; + { + RuleScore score(m, consider_higher); + score.NonTerminal(consider, -1.687872); + score.NonTerminal(higher, higher_score); + BOOST_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, consider_higher.left.length); + BOOST_CHECK(!consider_higher.full); + + ChartState full; + { + RuleScore score(m, full); + score.NonTerminal(combine_also_would, -1.687872 - 2.0); + score.NonTerminal(consider_higher, -1.509559 - 1.687872 - 0.30103); + BOOST_CHECK_CLOSE(-10.6879, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(4, full.right.length); +} + +template void GrowSmall(const M &m) { + TEXT_TEST("in biarritz watching considering looking . "); + TEXT_TEST("in biarritz watching considering looking ."); + TEXT_TEST("in biarritz"); +} + +#define CHECK_SCORE(str, val) \ +{ \ + float got = val; \ + std::vector indices; \ + LookupVocab(m, str, indices); \ + BOOST_CHECK_CLOSE(LeftToRight(m, indices), got, 0.001); \ +} + +template void FullGrow(const M &m) { + std::vector words; + LookupVocab(m, "in biarritz watching considering looking . ", words); + + ChartState lexical[7]; + float lexical_scores[7]; + for (unsigned int i = 0; i < 7; ++i) { + RuleScore score(m, lexical[i]); + score.Terminal(words[i]); + lexical_scores[i] = score.Finish(); + } + CHECK_SCORE("in", lexical_scores[0]); + CHECK_SCORE("biarritz", lexical_scores[1]); + CHECK_SCORE("watching", lexical_scores[2]); + CHECK_SCORE("", lexical_scores[6]); + + ChartState l1[4]; + float l1_scores[4]; + { + RuleScore score(m, l1[0]); + score.NonTerminal(lexical[0], lexical_scores[0]); + score.NonTerminal(lexical[1], lexical_scores[1]); + CHECK_SCORE("in biarritz", l1_scores[0] = score.Finish()); + } + { + RuleScore score(m, l1[1]); + score.NonTerminal(lexical[2], lexical_scores[2]); + score.NonTerminal(lexical[3], lexical_scores[3]); + CHECK_SCORE("watching considering", l1_scores[1] = score.Finish()); + } + { + RuleScore score(m, l1[2]); + score.NonTerminal(lexical[4], lexical_scores[4]); + score.NonTerminal(lexical[5], lexical_scores[5]); + CHECK_SCORE("looking .", l1_scores[2] = score.Finish()); + } + BOOST_CHECK_EQUAL(l1[2].left.length, 1); + l1[3] = lexical[6]; + l1_scores[3] = lexical_scores[6]; + + ChartState l2[2]; + float l2_scores[2]; + { + RuleScore score(m, l2[0]); + score.NonTerminal(l1[0], l1_scores[0]); + score.NonTerminal(l1[1], l1_scores[1]); + CHECK_SCORE("in biarritz watching considering", l2_scores[0] = score.Finish()); + } + { + RuleScore score(m, l2[1]); + score.NonTerminal(l1[2], l1_scores[2]); + score.NonTerminal(l1[3], l1_scores[3]); + CHECK_SCORE("looking . ", l2_scores[1] = score.Finish()); + } + BOOST_CHECK_EQUAL(l2[1].left.length, 1); + BOOST_CHECK(l2[1].full); + + ChartState top; + { + RuleScore score(m, top); + score.NonTerminal(l2[0], l2_scores[0]); + score.NonTerminal(l2[1], l2_scores[1]); + CHECK_SCORE("in biarritz watching considering looking . ", score.Finish()); + } +} + +template void Everything() { + Config config; + config.messages = NULL; + M m("test.arpa", config); + + Short(m); + Charge(m); + GrowBig(m); + AlsoWouldConsiderHigher(m); + GrowSmall(m); + FullGrow(m); +} + +BOOST_AUTO_TEST_CASE(ProbingAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(TrieAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(QuantTrieAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(ArrayQuantTrieAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(ArrayTrieAll) { + Everything(); +} + +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 27e24b1c..ca581d8a 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -16,7 +16,7 @@ namespace lm { namespace ngram { size_t hash_value(const State &state) { - return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); } namespace detail { @@ -41,11 +41,11 @@ template GenericModel::Ge // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); - begin_sentence.valid_length_ = 1; - begin_sentence.history_[0] = vocab_.BeginSentence(); - begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff; State null_context = State(); - null_context.valid_length_ = 0; + null_context.length = 0; P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); } @@ -87,7 +87,7 @@ template void GenericModel void GenericModel FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { - FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state); - if (ret.ngram_length - 1 < in_state.valid_length_) { - ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); + FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); + if (ret.ngram_length - 1 < in_state.length) { + ret.prob = std::accumulate(in_state.backoff + ret.ngram_length - 1, in_state.backoff + in_state.length, ret.prob); } return ret; } @@ -131,32 +131,80 @@ template void GenericModel FullScoreReturn GenericModel::ExtendLeft( + const WordIndex *add_rbegin, const WordIndex *add_rend, + const float *backoff_in, + uint64_t extend_pointer, + unsigned char extend_length, + float *backoff_out, + unsigned char &next_use) const { + FullScoreReturn ret; + float subtract_me; + typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me)); + ret.prob = subtract_me; + ret.ngram_length = extend_length; + next_use = 0; + // If this function is called, then it does depend on left words. + ret.independent_left = false; + ret.extend_left = extend_pointer; + const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1; + const WordIndex *i = add_rbegin; + for (; ; ++i, ++backoff_out, ++mid_iter) { + if (i == add_rend) { + // Ran out of words. + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + if (mid_iter == search_.MiddleEnd()) break; + if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) { + // Didn't match a word. + ret.independent_left = true; + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + ret.ngram_length = mid_iter - search_.MiddleBegin() + 2; + if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1; + } + + if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) { + // The last backoff weight, for Order() - 1. + ret.prob += backoff_in[i - add_rbegin]; + } else { + ret.ngram_length = P::Order(); + } + ret.independent_left = true; + ret.prob -= subtract_me; + return ret; } namespace { // Do a paraonoid copy of history, assuming new_word has already been copied -// (hence the -1). out_state.valid_length_ could be zero so I avoided using +// (hence the -1). out_state.length could be zero so I avoided using // std::copy. void CopyRemainingHistory(const WordIndex *from, State &out_state) { - WordIndex *out = out_state.history_ + 1; - const WordIndex *in_end = from + static_cast(out_state.valid_length_) - 1; + WordIndex *out = out_state.words + 1; + const WordIndex *in_end = from + static_cast(out_state.length) - 1; for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in; } } // namespace @@ -175,17 +223,17 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel class GenericModel : public base::Mod // This is the model type returned by RecognizeBinary. static const ModelType kModelType; + static const unsigned int kVersion = Search::kVersion; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -114,6 +116,25 @@ template class GenericModel : public base::Mod */ void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const; + /* More efficient version of FullScore where a partial n-gram has already + * been scored. + * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if + * the n-gram does not end up extending further left, then 0 is returned. + */ + FullScoreReturn ExtendLeft( + // Additional context in reverse order. This will update add_rend to + const WordIndex *add_rbegin, const WordIndex *add_rend, + // Backoff weights to use. + const float *backoff_in, + // extend_left returned by a previous query. + uint64_t extend_pointer, + // Length of n-gram that the pointer corresponds to. + unsigned char extend_length, + // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)] + float *backoff_out, + // Amount of additional content that should be considered by the next call. + unsigned char &next_use) const; + private: friend void LoadLM<>(const char *file, const Config &config, GenericModel &to); diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 57c7291c..2654071f 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -10,8 +10,8 @@ namespace lm { namespace ngram { std::ostream &operator<<(std::ostream &o, const State &state) { - o << "State length " << static_cast(state.valid_length_) << ':'; - for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) { + o << "State length " << static_cast(state.length) << ':'; + for (const WordIndex *i = state.words; i < state.words + state.length; ++i) { o << ' ' << *i; } return o; @@ -19,25 +19,26 @@ std::ostream &operator<<(std::ostream &o, const State &state) { namespace { -#define StartTest(word, ngram, score) \ +#define StartTest(word, ngram, score, indep_left) \ ret = model.FullScore( \ state, \ model.GetVocabulary().Index(word), \ out);\ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ BOOST_CHECK_EQUAL(static_cast(ngram), ret.ngram_length); \ - BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.valid_length_); \ + BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.length); \ + BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \ {\ - WordIndex context[state.valid_length_ + 1]; \ + WordIndex context[state.length + 1]; \ context[0] = model.GetVocabulary().Index(word); \ - std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \ + std::copy(state.words, state.words + state.length, context + 1); \ State get_state; \ - model.GetState(context, context + state.valid_length_ + 1, get_state); \ + model.GetState(context, context + state.length + 1, get_state); \ BOOST_CHECK_EQUAL(out, get_state); \ } -#define AppendTest(word, ngram, score) \ - StartTest(word, ngram, score) \ +#define AppendTest(word, ngram, score, indep_left) \ + StartTest(word, ngram, score, indep_left) \ state = out; template void Starters(const M &model) { @@ -45,12 +46,12 @@ template void Starters(const M &model) { Model::State state(model.BeginSentenceState()); Model::State out; - StartTest("looking", 2, -0.4846522); + StartTest("looking", 2, -0.4846522, true); // , probability plus backoff - StartTest(",", 1, -1.383514 + -0.4149733); + StartTest(",", 1, -1.383514 + -0.4149733, true); // probability plus backoff - StartTest("this_is_not_found", 1, -1.995635 + -0.4149733); + StartTest("this_is_not_found", 1, -1.995635 + -0.4149733, true); } template void Continuation(const M &model) { @@ -58,46 +59,64 @@ template void Continuation(const M &model) { Model::State state(model.BeginSentenceState()); Model::State out; - AppendTest("looking", 2, -0.484652); - AppendTest("on", 3, -0.348837); - AppendTest("a", 4, -0.0155266); - AppendTest("little", 5, -0.00306122); + AppendTest("looking", 2, -0.484652, true); + AppendTest("on", 3, -0.348837, true); + AppendTest("a", 4, -0.0155266, true); + AppendTest("little", 5, -0.00306122, true); State preserve = state; - AppendTest("the", 1, -4.04005); - AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 1, -2.29666); - AppendTest("more", 1, -1.20632 - 20.0); - AppendTest(".", 2, -0.51363); - AppendTest("", 3, -0.0191651); - BOOST_CHECK_EQUAL(0, state.valid_length_); + AppendTest("the", 1, -4.04005, true); + AppendTest("biarritz", 1, -1.9889, true); + AppendTest("not_found", 1, -2.29666, true); + AppendTest("more", 1, -1.20632 - 20.0, true); + AppendTest(".", 2, -0.51363, true); + AppendTest("", 3, -0.0191651, true); + BOOST_CHECK_EQUAL(0, state.length); state = preserve; - AppendTest("more", 5, -0.00181395); - BOOST_CHECK_EQUAL(4, state.valid_length_); - AppendTest("loin", 5, -0.0432557); - BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("more", 5, -0.00181395, true); + BOOST_CHECK_EQUAL(4, state.length); + AppendTest("loin", 5, -0.0432557, true); + BOOST_CHECK_EQUAL(1, state.length); } template void Blanks(const M &model) { FullScoreReturn ret; State state(model.NullContextState()); State out; - AppendTest("also", 1, -1.687872); - AppendTest("would", 2, -2); - AppendTest("consider", 3, -3); + AppendTest("also", 1, -1.687872, false); + AppendTest("would", 2, -2, true); + AppendTest("consider", 3, -3, true); State preserve = state; - AppendTest("higher", 4, -4); - AppendTest("looking", 5, -5); - BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("higher", 4, -4, true); + AppendTest("looking", 5, -5, true); + BOOST_CHECK_EQUAL(1, state.length); state = preserve; - AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103); + // also would consider not_found + AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103, true); state = model.NullContextState(); // higher looking is a blank. - AppendTest("higher", 1, -1.509559); - AppendTest("looking", 1, -1.285941 - 0.30103); - AppendTest("not_found", 1, -1.995635 - 0.4771212); + AppendTest("higher", 1, -1.509559, false); + AppendTest("looking", 2, -1.285941 - 0.30103, false); + + State higher_looking = state; + + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("not_found", 1, -1.995635 - 0.4771212, true); + + state = higher_looking; + // higher looking consider + AppendTest("consider", 1, -1.687872 - 0.4771212, true); + + state = model.NullContextState(); + AppendTest("would", 1, -1.687872, false); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("consider", 2, -1.687872 -0.30103, false); + BOOST_CHECK_EQUAL(2, state.length); + AppendTest("higher", 3, -1.509559 - 0.30103, false); + BOOST_CHECK_EQUAL(3, state.length); + AppendTest("looking", 4, -1.285941 - 0.30103, false); } template void Unknowns(const M &model) { @@ -105,14 +124,14 @@ template void Unknowns(const M &model) { State state(model.NullContextState()); State out; - AppendTest("not_found", 1, -1.995635); + AppendTest("not_found", 1, -1.995635, false); State preserve = state; - AppendTest("not_found2", 2, -15.0); - AppendTest("not_found3", 2, -15.0 - 2.0); + AppendTest("not_found2", 2, -15.0, true); + AppendTest("not_found3", 2, -15.0 - 2.0, true); state = preserve; - AppendTest("however", 2, -4); - AppendTest("not_found3", 3, -6); + AppendTest("however", 2, -4, true); + AppendTest("not_found3", 3, -6, true); } template void MinimalState(const M &model) { @@ -120,22 +139,66 @@ template void MinimalState(const M &model) { State state(model.NullContextState()); State out; - AppendTest("baz", 1, -6.535897); - BOOST_CHECK_EQUAL(0, state.valid_length_); + AppendTest("baz", 1, -6.535897, true); + BOOST_CHECK_EQUAL(0, state.length); state = model.NullContextState(); - AppendTest("foo", 1, -3.141592); - BOOST_CHECK_EQUAL(1, state.valid_length_); - AppendTest("bar", 2, -6.0); + AppendTest("foo", 1, -3.141592, true); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("bar", 2, -6.0, true); // Has to include the backoff weight. - BOOST_CHECK_EQUAL(1, state.valid_length_); - AppendTest("bar", 1, -2.718281 + 3.0); - BOOST_CHECK_EQUAL(1, state.valid_length_); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("bar", 1, -2.718281 + 3.0, true); + BOOST_CHECK_EQUAL(1, state.length); state = model.NullContextState(); - AppendTest("to", 1, -1.687872); - AppendTest("look", 2, -0.2922095); - BOOST_CHECK_EQUAL(2, state.valid_length_); - AppendTest("good", 3, -7); + AppendTest("to", 1, -1.687872, false); + AppendTest("look", 2, -0.2922095, true); + BOOST_CHECK_EQUAL(2, state.length); + AppendTest("good", 3, -7, true); +} + +template void ExtendLeftTest(const M &model) { + State right; + FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right)); + const float kLittleProb = -1.285941; + BOOST_CHECK_CLOSE(kLittleProb, little.prob, 0.001); + unsigned char next_use; + float backoff_out[4]; + + FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use)); + BOOST_CHECK_EQUAL(0, next_use); + BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left); + BOOST_CHECK_CLOSE(0.0, extend_none.prob, 0.001); + BOOST_CHECK_EQUAL(1, extend_none.ngram_length); + + const WordIndex a = model.GetVocabulary().Index("a"); + float backoff_in = 3.14; + // a little + FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use)); + BOOST_CHECK_EQUAL(1, next_use); + BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.09132547 - kLittleProb, extend_a.prob, 0.001); + BOOST_CHECK_EQUAL(2, extend_a.ngram_length); + BOOST_CHECK(!extend_a.independent_left); + + const WordIndex on = model.GetVocabulary().Index("on"); + FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use)); + BOOST_CHECK_EQUAL(1, next_use); + BOOST_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - -0.09132547, extend_on.prob, 0.001); + BOOST_CHECK_EQUAL(3, extend_on.ngram_length); + BOOST_CHECK(!extend_on.independent_left); + + const WordIndex both[2] = {a, on}; + float backoff_in_arr[4]; + FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use)); + BOOST_CHECK_EQUAL(2, next_use); + BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - kLittleProb, extend_both.prob, 0.001); + BOOST_CHECK_EQUAL(3, extend_both.ngram_length); + BOOST_CHECK(!extend_both.independent_left); + BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left); } #define StatelessTest(word, provide, ngram, score) \ @@ -166,17 +229,17 @@ template void Stateless(const M &model) { // looking StatelessTest(1, 2, 2, -0.484652); // on - AppendTest("on", 3, -0.348837); + AppendTest("on", 3, -0.348837, true); StatelessTest(2, 3, 3, -0.348837); StatelessTest(2, 2, 3, -0.348837); StatelessTest(2, 1, 2, -0.4638903); // a StatelessTest(3, 4, 4, -0.0155266); // little - AppendTest("little", 5, -0.00306122); + AppendTest("little", 5, -0.00306122, true); StatelessTest(4, 5, 5, -0.00306122); // the - AppendTest("the", 1, -4.04005); + AppendTest("the", 1, -4.04005, true); StatelessTest(5, 5, 1, -4.04005); // No context of the. StatelessTest(5, 0, 1, -1.687872); @@ -189,8 +252,8 @@ template void Stateless(const M &model) { WordIndex unk[1]; unk[0] = 0; model.GetState(unk, unk + 1, state); - BOOST_CHECK_EQUAL(1, state.valid_length_); - BOOST_CHECK_EQUAL(static_cast(0), state.history_[0]); + BOOST_CHECK_EQUAL(1, state.length); + BOOST_CHECK_EQUAL(static_cast(0), state.words[0]); } template void NoUnkCheck(const M &model) { @@ -207,6 +270,7 @@ template void Everything(const M &m) { Blanks(m); Unknowns(m); MinimalState(m); + ExtendLeftTest(m); Stateless(m); } @@ -245,6 +309,7 @@ template void LoadingTest() { config.enumerate_vocab = &enumerate; ModelT m("test.arpa", config); enumerate.Check(m.GetVocabulary()); + BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); Everything(m); } { @@ -252,6 +317,7 @@ template void LoadingTest() { config.enumerate_vocab = &enumerate; ModelT m("test_nounk.arpa", config); enumerate.Check(m.GetVocabulary()); + BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); NoUnkCheck(m); } } diff --git a/klm/lm/model_type.hh b/klm/lm/model_type.hh new file mode 100644 index 00000000..5057ed25 --- /dev/null +++ b/klm/lm/model_type.hh @@ -0,0 +1,16 @@ +#ifndef LM_MODEL_TYPE__ +#define LM_MODEL_TYPE__ + +namespace lm { +namespace ngram { + +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); + +} // namespace ngram +} // namespace lm +#endif // LM_MODEL_TYPE__ diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index fd371cc8..98a5d048 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -1,5 +1,6 @@ #include "lm/quantize.hh" +#include "lm/binary_format.hh" #include "lm/lm_exception.hh" #include @@ -70,8 +71,7 @@ void SeparatelyQuantize::Train(uint8_t order, std::vector &prob, std::vec void SeparatelyQuantize::TrainProb(uint8_t order, std::vector &prob) { float *centers = start_ + TableStart(order); - *(centers++) = kBlankProb; - MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_) - 1); + MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_)); } void SeparatelyQuantize::FinishedLoading(const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 0b71d14a..4cf4236e 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -1,9 +1,9 @@ #ifndef LM_QUANTIZE_H__ #define LM_QUANTIZE_H__ -#include "lm/binary_format.hh" // for ModelType #include "lm/blank.hh" #include "lm/config.hh" +#include "lm/model_type.hh" #include "util/bit_packing.hh" #include @@ -36,6 +36,9 @@ class DontQuantize { prob = util::ReadNonPositiveFloat31(base, bit_offset); backoff = util::ReadFloat32(base, bit_offset + 31); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { backoff = util::ReadFloat32(base, bit_offset + 31); } @@ -77,7 +80,7 @@ class SeparatelyQuantize { Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} uint64_t EncodeProb(float value) const { - return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + return Encode(value, 0); } uint64_t EncodeBackoff(float value) const { @@ -132,6 +135,10 @@ class SeparatelyQuantize { (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask())); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); prob = prob_.Decode(both >> backoff_.Bits()); @@ -179,7 +186,7 @@ class SeparatelyQuantize { void SetupMemory(void *start, const Config &config); static const bool kTrain = true; - // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + // Assumes 0.0 is removed from backoff. void Train(uint8_t order, std::vector &prob, std::vector &backoff); // Train just probabilities (for longest order). void TrainProb(uint8_t order, std::vector &prob); diff --git a/klm/lm/return.hh b/klm/lm/return.hh new file mode 100644 index 00000000..15571960 --- /dev/null +++ b/klm/lm/return.hh @@ -0,0 +1,39 @@ +#ifndef LM_RETURN__ +#define LM_RETURN__ + +#include + +namespace lm { +/* Structure returned by scoring routines. */ +struct FullScoreReturn { + // log10 probability + float prob; + + /* The length of n-gram matched. Do not use this for recombination. + * Consider a model containing only the following n-grams: + * -1 foo + * -3.14 bar + * -2.718 baz -5 + * -6 foo bar + * + * If you score ``bar'' then ngram_length is 1 and recombination state is the + * empty string because bar has zero backoff and does not extend to the + * right. + * If you score ``foo'' then ngram_length is 1 and recombination state is + * ``foo''. + * + * Ideally, keep output states around and compare them. Failing that, + * get out_state.ValidLength() and use that length for recombination. + */ + unsigned char ngram_length; + + /* Left extension information. If independent_left is set, then prob is + * independent of words to the left (up to additional backoff). Otherwise, + * extend_left indicates how to efficiently extend further to the left. + */ + bool independent_left; + uint64_t extend_left; // Defined only if independent_left +}; + +} // namespace lm +#endif // LM_RETURN__ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 82c53ec8..334adf12 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -1,10 +1,12 @@ #include "lm/search_hashed.hh" +#include "lm/binary_format.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/read_arpa.hh" #include "lm/vocab.hh" +#include "util/bit_packing.hh" #include "util/file_piece.hh" #include @@ -48,30 +50,77 @@ class ActivateUnigram { ProbBackoff *modify_; }; -template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Activate activate, Store &store, PositiveProbWarn &warn) { - - ReadNGramHeader(f, n); +template void FixSRI(int lower, float negative_lower_prob, unsigned int n, const uint64_t *keys, const WordIndex *vocab_ids, ProbBackoff *unigrams, std::vector &middle) { ProbBackoff blank; - blank.prob = kBlankProb; - blank.backoff = kBlankBackoff; + blank.backoff = kNoExtensionBackoff; + // Fix SRI's stupidity. + // Note that negative_lower_prob is the negative of the probability (so it's currently >= 0). We still want the sign bit off to indicate left extension, so I just do -= on the backoffs. + blank.prob = negative_lower_prob; + // An entry was found at lower (order lower + 2). + // We need to insert blanks starting at lower + 1 (order lower + 3). + unsigned int fix = static_cast(lower + 1); + uint64_t backoff_hash = detail::CombineWordHash(static_cast(vocab_ids[1]), vocab_ids[2]); + if (fix == 0) { + // Insert a missing bigram. + blank.prob -= unigrams[vocab_ids[1]].backoff; + SetExtension(unigrams[vocab_ids[1]].backoff); + // Bigram including a unigram's backoff + middle[0].Insert(Middle::Packing::Make(keys[0], blank)); + fix = 1; + } else { + for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); + } + // fix >= 1. Insert trigrams and above. + for (; fix <= n - 3; ++fix) { + typename Middle::MutableIterator gotit; + if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) { + float &backoff = gotit->MutableValue().backoff; + SetExtension(backoff); + blank.prob -= backoff; + } + middle[fix].Insert(Middle::Packing::Make(keys[fix], blank)); + backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]); + } +} + +template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector &middle, Activate activate, Store &store, PositiveProbWarn &warn) { + ReadNGramHeader(f, n); // vocab ids of words in reverse order WordIndex vocab_ids[n]; uint64_t keys[n - 1]; typename Store::Packing::Value value; - typename Middle::ConstIterator found; + typename Middle::MutableIterator found; for (size_t i = 0; i < count; ++i) { ReadNGram(f, n, vocab, vocab_ids, value, warn); + keys[0] = detail::CombineWordHash(static_cast(*vocab_ids), vocab_ids[1]); for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } + // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. + util::SetSign(value.prob); store.Insert(Store::Packing::Make(keys[n-2], value)); - // Go back and insert blanks. - for (int lower = n - 3; lower >= 0; --lower) { - if (middle[lower].Find(keys[lower], found)) break; - middle[lower].Insert(Middle::Packing::Make(keys[lower], blank)); + // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. + int lower; + util::FloatEnc fix_prob; + for (lower = n - 3; ; --lower) { + if (lower == -1) { + fix_prob.f = unigrams[vocab_ids[0]].prob; + fix_prob.i &= ~util::kSignBit; + unigrams[vocab_ids[0]].prob = fix_prob.f; + break; + } + if (middle[lower].UnsafeMutableFind(keys[lower], found)) { + // Turn off sign bit to indicate that it extends left. + fix_prob.f = found->MutableValue().prob; + fix_prob.i &= ~util::kSignBit; + found->MutableValue().prob = fix_prob.f; + // We don't need to recurse further down because this entry already set the bits for lower entries. + break; + } } + if (lower != static_cast(n) - 3) FixSRI(lower, fix_prob.f, n, keys, vocab_ids, unigrams, middle); activate(vocab_ids, n); } @@ -107,15 +156,15 @@ template template void TemplateHashe try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); + ReadNGrams(f, 2, counts[1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); + ReadNGrams(f, n, counts[n-1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateLowerMiddle(middle_.back()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle(middle_.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateUnigram(unigram.Raw()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); @@ -133,7 +182,7 @@ template void TemplateHashedSearch; -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); +template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); } // namespace detail } // namespace ngram diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index c62985e4..e289fd11 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -1,15 +1,18 @@ #ifndef LM_SEARCH_HASHED__ #define LM_SEARCH_HASHED__ -#include "lm/binary_format.hh" +#include "lm/model_type.hh" #include "lm/config.hh" #include "lm/read_arpa.hh" +#include "lm/return.hh" #include "lm/weights.hh" +#include "util/bit_packing.hh" #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include +#include #include namespace util { class FilePiece; } @@ -52,9 +55,14 @@ struct HashedSearch { Unigram unigram; - void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const { const ProbBackoff &entry = unigram.Lookup(word); - prob = entry.prob; + util::FloatEnc val; + val.f = entry.prob; + ret.independent_left = (val.i & util::kSignBit); + ret.extend_left = static_cast(word); + val.i |= util::kSignBit; + ret.prob = val.f; backoff = entry.backoff; next = static_cast(word); } @@ -67,6 +75,8 @@ template class TemplateHashedSearch : public Has typedef LongestT Longest; Longest longest; + static const unsigned int kVersion = 0; + // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} @@ -85,11 +95,33 @@ template class TemplateHashedSearch : public Has const Middle *MiddleBegin() const { return &*middle_.begin(); } const Middle *MiddleEnd() const { return &*middle_.end(); } - bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { + util::FloatEnc val; + if (extend_length == 1) { + val.f = unigram.Lookup(static_cast(extend_pointer)).prob; + } else { + typename Middle::ConstIterator found; + if (!middle_[extend_length - 2].Find(extend_pointer, found)) { + std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; + abort(); + } + val.f = found->GetValue().prob; + } + val.i |= util::kSignBit; + prob = val.f; + return extend_pointer; + } + + bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; - prob = found->GetValue().prob; + util::FloatEnc enc; + enc.f = found->GetValue().prob; + ret.independent_left = (enc.i & util::kSignBit); + ret.extend_left = node; + enc.i |= util::kSignBit; + ret.prob = enc.f; backoff = found->GetValue().backoff; return true; } @@ -105,6 +137,7 @@ template class TemplateHashedSearch : public Has } bool LookupLongest(WordIndex word, float &prob, Node &node) const { + // Sign bit is always on because longest n-grams do not extend left. node = CombineWordHash(node, word); typename Longest::ConstIterator found; if (!longest.Find(node, found)) return false; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 05059ffb..6479813b 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -2,26 +2,25 @@ #include "lm/search_trie.hh" #include "lm/bhiksha.hh" +#include "lm/binary_format.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" #include "lm/quantize.hh" -#include "lm/read_arpa.hh" #include "lm/trie.hh" +#include "lm/trie_sort.hh" #include "lm/vocab.hh" #include "lm/weights.hh" #include "lm/word_index.hh" #include "util/ersatz_progress.hh" -#include "util/file_piece.hh" -#include "util/have.hh" #include "util/proxy_iterator.hh" #include "util/scoped.hh" +#include "util/sized_iterator.hh" #include -#include #include #include -#include +#include #include #include #include @@ -29,575 +28,221 @@ #include #include #include -#include -#include -#include namespace lm { namespace ngram { namespace trie { namespace { -/* An entry is a n-gram with probability. It consists of: - * WordIndex[order] - * float probability - * backoff probability (omitted for highest order n-gram) - * These are stored consecutively in memory. We want to sort them. - * - * The problem is the length depends on order (but all n-grams being compared - * have the same order). Allocating each entry on the heap (i.e. std::vector - * or std::string) then sorting pointers is the normal solution. But that's - * too memory inefficient. A lot of this code is just here to force std::sort - * to work with records where length is specified at runtime (and avoid using - * Boost for LM code). I could have used qsort, but the point is to also - * support __gnu_cxx:parallel_sort which doesn't have a qsort version. - */ - -class EntryIterator { - public: - EntryIterator() {} - - EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast(ptr)), size_(size) {} - - bool operator==(const EntryIterator &other) const { - return ptr_ == other.ptr_; - } - bool operator<(const EntryIterator &other) const { - return ptr_ < other.ptr_; - } - EntryIterator &operator+=(std::ptrdiff_t amount) { - ptr_ += amount * size_; - return *this; - } - std::ptrdiff_t operator-(const EntryIterator &other) const { - return (ptr_ - other.ptr_) / size_; - } - - const void *Data() const { return ptr_; } - void *Data() { return ptr_; } - std::size_t EntrySize() const { return size_; } - - private: - uint8_t *ptr_; - std::size_t size_; -}; - -class EntryProxy { - public: - EntryProxy() {} - - EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} - - operator std::string() const { - return std::string(reinterpret_cast(inner_.Data()), inner_.EntrySize()); - } - - EntryProxy &operator=(const EntryProxy &from) { - memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); - return *this; - } - - EntryProxy &operator=(const std::string &from) { - memcpy(inner_.Data(), from.data(), inner_.EntrySize()); - return *this; - } - - const WordIndex *Indices() const { - return reinterpret_cast(inner_.Data()); - } - - private: - friend class util::ProxyIterator; - - typedef std::string value_type; - - typedef EntryIterator InnerIterator; - InnerIterator &Inner() { return inner_; } - const InnerIterator &Inner() const { return inner_; } - InnerIterator inner_; -}; - -typedef util::ProxyIterator NGramIter; - -// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. -class PartialViewProxy { - public: - PartialViewProxy() : attention_size_(0), inner_() {} - - PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} - - operator std::string() const { - return std::string(reinterpret_cast(inner_.Data()), attention_size_); - } - - PartialViewProxy &operator=(const PartialViewProxy &from) { - memcpy(inner_.Data(), from.inner_.Data(), attention_size_); - return *this; - } - - PartialViewProxy &operator=(const std::string &from) { - memcpy(inner_.Data(), from.data(), attention_size_); - return *this; - } - - const WordIndex *Indices() const { - return reinterpret_cast(inner_.Data()); - } - - private: - friend class util::ProxyIterator; - - typedef std::string value_type; - - const std::size_t attention_size_; - - typedef EntryIterator InnerIterator; - InnerIterator &Inner() { return inner_; } - const InnerIterator &Inner() const { return inner_; } - InnerIterator inner_; -}; - -typedef util::ProxyIterator PartialIter; - -template class CompareRecords : public std::binary_function { - public: - explicit CompareRecords(unsigned char order) : order_(order) {} - - bool operator()(const Proxy &first, const Proxy &second) const { - return Compare(first.Indices(), second.Indices()); - } - bool operator()(const Proxy &first, const std::string &second) const { - return Compare(first.Indices(), reinterpret_cast(second.data())); - } - bool operator()(const std::string &first, const Proxy &second) const { - return Compare(reinterpret_cast(first.data()), second.Indices()); - } - bool operator()(const std::string &first, const std::string &second) const { - return Compare(reinterpret_cast(first.data()), reinterpret_cast(second.data())); - } - - private: - bool Compare(const WordIndex *first, const WordIndex *second) const { - const WordIndex *end = first + order_; - for (; first != end; ++first, ++second) { - if (*first < *second) return true; - if (*first > *second) return false; - } - return false; - } - - unsigned char order_; -}; - -FILE *OpenOrThrow(const char *name, const char *mode) { - FILE *ret = fopen(name, mode); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); - return ret; -} - -void WriteOrThrow(FILE *to, const void *data, size_t size) { - assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); -} - void ReadOrThrow(FILE *from, void *data, size_t size) { - if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size); -} - -const std::size_t kCopyBufSize = 512; -void CopyOrThrow(FILE *from, FILE *to, size_t size) { - char buf[std::min(size, kCopyBufSize)]; - for (size_t i = 0; i < size; i += kCopyBufSize) { - std::size_t amount = std::min(size - i, kCopyBufSize); - ReadOrThrow(from, buf, amount); - WriteOrThrow(to, buf, amount); - } + UTIL_THROW_IF(1 != std::fread(data, size, 1, from), util::ErrnoException, "Short read"); } -void CopyRestOrThrow(FILE *from, FILE *to) { - char buf[kCopyBufSize]; - size_t amount; - while ((amount = fread(buf, 1, kCopyBufSize, from))) { - WriteOrThrow(to, buf, amount); +int Compare(unsigned char order, const void *first_void, const void *second_void) { + const WordIndex *first = reinterpret_cast(first_void), *second = reinterpret_cast(second_void); + const WordIndex *end = first + order; + for (; first != end; ++first, ++second) { + if (*first < *second) return -1; + if (*first > *second) return 1; } - if (!feof(from)) UTIL_THROW(util::ErrnoException, "Short read"); -} - -void RemoveOrThrow(const char *name) { - if (std::remove(name)) UTIL_THROW(util::ErrnoException, "Could not remove " << name); + return 0; } -std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) { - const std::size_t entry_size = sizeof(WordIndex) * order + weights_size; - const std::size_t prefix_size = sizeof(WordIndex) * (order - 1); - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << '_' << batch; - std::string ret(assembled.str()); - util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w")); - // Compress entries that being with the same (order-1) words. - for (const uint8_t *group_begin = static_cast(mem_begin); group_begin != static_cast(mem_end);) { - const uint8_t *group_end; - for (group_end = group_begin + entry_size; - (group_end != static_cast(mem_end)) && !memcmp(group_begin, group_end, prefix_size); - group_end += entry_size) {} - WriteOrThrow(out.get(), group_begin, prefix_size); - WordIndex group_size = (group_end - group_begin) / entry_size; - WriteOrThrow(out.get(), &group_size, sizeof(group_size)); - for (const uint8_t *i = group_begin; i != group_end; i += entry_size) { - WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex)); - WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size); - } - group_begin = group_end; - } - return ret; -} +struct ProbPointer { + unsigned char array; + uint64_t index; +}; -class SortedFileReader { +// Array of n-grams and float indices. +class BackoffMessages { public: - SortedFileReader() : ended_(false) {} - - void Init(const std::string &name, unsigned char order) { - file_.reset(OpenOrThrow(name.c_str(), "r")); - header_.resize(order - 1); - NextHeader(); + void Init(std::size_t entry_size) { + current_ = NULL; + allocated_ = NULL; + entry_size_ = entry_size; } - // Preceding words. - const WordIndex *Header() const { - return &*header_.begin(); - } - const std::vector &HeaderVector() const { return header_;} - - std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); } - - void NextHeader() { - if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get())) { - if (feof(file_.get())) { - ended_ = true; - } else { - UTIL_THROW(util::ErrnoException, "Short read of counts"); + void Add(const WordIndex *to, ProbPointer index) { + while (current_ + entry_size_ > allocated_) { + std::size_t allocated_size = allocated_ - (uint8_t*)backing_.get(); + Resize(std::max(allocated_size * 2, entry_size_)); + } + memcpy(current_, to, entry_size_ - sizeof(ProbPointer)); + *reinterpret_cast(current_ + entry_size_ - sizeof(ProbPointer)) = index; + current_ += entry_size_; + } + + void Apply(float *const *const base, FILE *unigrams) { + FinishedAdding(); + if (current_ == allocated_) return; + rewind(unigrams); + ProbBackoff weights; + WordIndex unigram = 0; + ReadOrThrow(unigrams, &weights, sizeof(weights)); + for (; current_ != allocated_; current_ += entry_size_) { + const WordIndex &cur_word = *reinterpret_cast(current_); + for (; unigram < cur_word; ++unigram) { + ReadOrThrow(unigrams, &weights, sizeof(weights)); } + if (!HasExtension(weights.backoff)) { + weights.backoff = kExtensionBackoff; + UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed."); + WriteOrThrow(unigrams, &weights, sizeof(weights)); + } + const ProbPointer &write_to = *reinterpret_cast(current_ + sizeof(WordIndex)); + base[write_to.array][write_to.index] += weights.backoff; } + backing_.reset(); + } + + void Apply(float *const *const base, RecordReader &reader) { + FinishedAdding(); + if (current_ == allocated_) return; + // We'll also use the same buffer to record messages to blanks that they extend. + WordIndex *extend_out = reinterpret_cast(current_); + const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex); + for (reader.Rewind(); reader && (current_ != allocated_); ) { + switch (Compare(order, reader.Data(), current_)) { + case -1: + ++reader; + break; + case 1: + // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. + for (const WordIndex *w = reinterpret_cast(current_); w != reinterpret_cast(current_) + order; ++w, ++extend_out) *extend_out = *w; + current_ += entry_size_; + break; + case 0: + float &backoff = reinterpret_cast((uint8_t*)reader.Data() + order * sizeof(WordIndex))->backoff; + if (!HasExtension(backoff)) { + backoff = kExtensionBackoff; + reader.Overwrite(&backoff, sizeof(float)); + } else { + const ProbPointer &write_to = *reinterpret_cast(current_ + entry_size_ - sizeof(ProbPointer)); + base[write_to.array][write_to.index] += backoff; + } + current_ += entry_size_; + break; + } + } + // Now this is a list of blanks that extend right. + entry_size_ = sizeof(WordIndex) * order; + Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get())); + current_ = (uint8_t*)backing_.get(); } - WordIndex ReadCount() { - WordIndex ret; - ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); - return ret; - } - - WordIndex ReadWord() { - WordIndex ret; - ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); - return ret; - } - - template void ReadWeights(Weights &weights) { - ReadOrThrow(file_.get(), &weights, sizeof(Weights)); + // Call after Apply + bool Extends(unsigned char order, const WordIndex *words) { + if (current_ == allocated_) return false; + assert(order * sizeof(WordIndex) == entry_size_); + while (true) { + switch(Compare(order, words, current_)) { + case 1: + current_ += entry_size_; + if (current_ == allocated_) return false; + break; + case -1: + return false; + case 0: + return true; + } + } } - bool Ended() const { - return ended_; + private: + void FinishedAdding() { + Resize(current_ - (uint8_t*)backing_.get()); + current_ = (uint8_t*)backing_.get(); } - void Rewind() { - rewind(file_.get()); - ended_ = false; - NextHeader(); + void Resize(std::size_t to) { + std::size_t current = current_ - (uint8_t*)backing_.get(); + backing_.call_realloc(to); + current_ = (uint8_t*)backing_.get() + current; + allocated_ = (uint8_t*)backing_.get() + to; } - FILE *File() { return file_.get(); } - - private: - util::scoped_FILE file_; + util::scoped_malloc backing_; - std::vector header_; + uint8_t *current_, *allocated_; - bool ended_; + std::size_t entry_size_; }; -void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) { - WriteOrThrow(to, from.Header(), from.HeaderBytes()); - WordIndex count = from.ReadCount(); - WriteOrThrow(to, &count, sizeof(WordIndex)); - - CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); -} - -void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) { - SortedFileReader first, second; - first.Init(first_name.c_str(), order); - RemoveOrThrow(first_name.c_str()); - second.Init(second_name.c_str(), order); - RemoveOrThrow(second_name.c_str()); - util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); - while (!first.Ended() && !second.Ended()) { - if (first.HeaderVector() < second.HeaderVector()) { - CopyFullRecord(first, out_file.get(), weights_size); - first.NextHeader(); - continue; - } - if (first.HeaderVector() > second.HeaderVector()) { - CopyFullRecord(second, out_file.get(), weights_size); - second.NextHeader(); - continue; - } - // Merge at the entry level. - WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes()); - WordIndex first_count = first.ReadCount(), second_count = second.ReadCount(); - WordIndex total_count = first_count + second_count; - WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex)); - - WordIndex first_word = first.ReadWord(), second_word = second.ReadWord(); - WordIndex first_index = 0, second_index = 0; - while (true) { - if (first_word < second_word) { - WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); - CopyOrThrow(first.File(), out_file.get(), weights_size); - if (++first_index == first_count) break; - first_word = first.ReadWord(); - } else { - WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); - CopyOrThrow(second.File(), out_file.get(), weights_size); - if (++second_index == second_count) break; - second_word = second.ReadWord(); - } - } - if (first_index == first_count) { - WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); - CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); - } else { - WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); - CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); - } - first.NextHeader(); - second.NextHeader(); - } - - for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) { - CopyFullRecord(remaining, out_file.get(), weights_size); - } -} - -const char *kContextSuffix = "_contexts"; - -void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { - const size_t context_size = sizeof(WordIndex) * (order - 1); - // Sort just the contexts using the same memory. - PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); - PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); - - std::sort(context_begin, context_end, CompareRecords(order - 1)); - - std::string name(ngram_file_name + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); - - // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. - if (context_begin == context_end) return; - PartialIter i(context_begin); - WriteOrThrow(out.get(), i->Indices(), context_size); - const WordIndex *previous = i->Indices(); - ++i; - for (; i != context_end; ++i) { - if (memcmp(previous, i->Indices(), context_size)) { - WriteOrThrow(out.get(), i->Indices(), context_size); - previous = i->Indices(); - } - } -} +const float kBadProb = std::numeric_limits::infinity(); -class ContextReader { +class SRISucks { public: - ContextReader() : valid_(false) {} - - ContextReader(const char *name, unsigned char order) { - Reset(name, order); - } - - void Reset(const char *name, unsigned char order) { - file_.reset(OpenOrThrow(name, "r")); - length_ = sizeof(WordIndex) * static_cast(order); - words_.resize(order); - valid_ = true; - ++*this; - } - - ContextReader &operator++() { - if (1 != fread(&*words_.begin(), length_, 1, file_.get())) { - if (!feof(file_.get())) - UTIL_THROW(util::ErrnoException, "Short read"); - valid_ = false; + SRISucks() { + for (BackoffMessages *i = messages_; i != messages_ + kMaxOrder - 1; ++i) + i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1)); + } + + void Send(unsigned char begin, unsigned char order, const WordIndex *to, float prob_basis) { + assert(prob_basis != kBadProb); + ProbPointer pointer; + pointer.array = order - 1; + pointer.index = values_[order - 1].size(); + for (unsigned char i = begin; i < order; ++i) { + messages_[i - 1].Add(to, pointer); } - return *this; + values_[order - 1].push_back(prob_basis); } - const WordIndex *operator*() const { return &*words_.begin(); } - - operator bool() const { return valid_; } - - FILE *GetFile() { return file_.get(); } - - private: - util::scoped_FILE file_; - - size_t length_; - - std::vector words_; - - bool valid_; -}; - -void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) { - const size_t context_size = sizeof(WordIndex) * (order - 1); - std::string first_name(first_base + kContextSuffix); - std::string second_name(second_base + kContextSuffix); - ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1); - RemoveOrThrow(first_name.c_str()); - RemoveOrThrow(second_name.c_str()); - std::string out_name(out_base + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w")); - while (first && second) { - for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) { - if (f == *first + order - 1) { - // Equal. - WriteOrThrow(out.get(), *first, context_size); - ++first; - ++second; - break; + void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) { + for (unsigned char i = 0; i < kMaxOrder - 1; ++i) { + it_[i] = &*values_[i].begin(); } - if (*f < *s) { - // First lower - WriteOrThrow(out.get(), *first, context_size); - ++first; - break; - } else if (*f > *s) { - WriteOrThrow(out.get(), *second, context_size); - ++second; - break; + messages_[0].Apply(it_, unigram_file); + BackoffMessages *messages = messages_ + 1; + const RecordReader *end = reader + total_order - 2 /* exclude unigrams and longest order */; + for (; reader != end; ++messages, ++reader) { + messages->Apply(it_, *reader); } } - } - ContextReader &remaining = first ? first : second; - if (!remaining) return; - WriteOrThrow(out.get(), *remaining, context_size); - CopyRestOrThrow(remaining.GetFile(), out.get()); -} -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { - ReadNGramHeader(f, order); - const size_t count = counts[order - 1]; - // Size of weights. Does it include backoff? - const size_t words_size = sizeof(WordIndex) * order; - const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); - const size_t entry_size = words_size + weights_size; - const size_t batch_size = std::min(count, mem.size() / entry_size); - uint8_t *const begin = reinterpret_cast(mem.get()); - std::deque files; - for (std::size_t batch = 0, done = 0; done < count; ++batch) { - uint8_t *out = begin; - uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; - if (order == counts.size()) { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); - } - } else { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); - } + ProbBackoff GetBlank(unsigned char total_order, unsigned char order, const WordIndex *indices) { + assert(order > 1); + ProbBackoff ret; + ret.prob = *(it_[order - 1]++); + ret.backoff = ((order != total_order - 1) && messages_[order - 1].Extends(order, indices)) ? kExtensionBackoff : kNoExtensionBackoff; + return ret; } - // Sort full records by full n-gram. - EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - // parallel_sort uses too much RAM - std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); - files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); - WriteContextFile(begin, out_end, files.back(), entry_size, order); - - done += (out_end - begin) / entry_size; - } - // All individual files created. Merge them. - - std::size_t merge_count = 0; - while (files.size() > 1) { - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); - files.push_back(assembled.str()); - MergeSortedFiles(files[0], files[1], files.back(), weights_size, order); - MergeContextFiles(files[0], files[1], files.back(), order); - files.pop_front(); - files.pop_front(); - } - if (!files.empty()) { - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << "_merged"; - std::string merged_name(assembled.str()); - if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); - std::string context_name = files[0] + kContextSuffix; - merged_name += kContextSuffix; - if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); - } -} - -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { - PositiveProbWarn warn(config.positive_log_probability); - { - std::string unigram_name = file_prefix + "unigrams"; - util::scoped_fd unigram_file; - // In case appears. - size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); - Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); - CheckSpecials(config, vocab); - if (!vocab.SawUnk()) ++counts[0]; - } + const std::vector &Values(unsigned char order) const { + return values_[order - 1]; + } - // Only use as much buffer as we need. - size_t buffer_use = 0; - for (unsigned int order = 2; order < counts.size(); ++order) { - buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); - } - buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); - buffer = std::min(buffer, buffer_use); + private: + // This used to be one array. Then I needed to separate it by order for quantization to work. + std::vector values_[kMaxOrder - 1]; + BackoffMessages messages_[kMaxOrder - 1]; - util::scoped_memory mem; - mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); - if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + float *it_[kMaxOrder - 1]; +}; - for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); - } - ReadEnd(f); -} +class FindBlanks { + public: + FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) + : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {} -bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) { - for (; words != words_end; ++words, ++header) { - if (*words != *header) { - //assert(*words <= *header); - return false; + float UnigramProb(WordIndex index) const { + return unigrams_[index].prob; } - } - return true; -} -// Phase to count n-grams, including blanks inserted because they were pruned but have extensions -class JustCount { - public: - template JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) - : counts_(counts), longest_counts_(counts + order - 1) {} - - void Unigrams(WordIndex begin, WordIndex end) { - counts_[0] += end - begin; + void Unigram(WordIndex /*index*/) { + ++counts_[0]; } - void MiddleBlank(const unsigned char mid_idx, WordIndex /* idx */) { - ++counts_[mid_idx + 1]; + void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) { + sri_.Send(lower, order, indices + 1, prob_basis); + ++counts_[order - 1]; } - void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) { - ++counts_[mid_idx + 1]; + void Middle(const unsigned char order, const void * /*data*/) { + ++counts_[order - 1]; } - void Longest(WordIndex /*key*/, Prob /*prob*/) { + void Longest(const void * /*data*/) { ++*longest_counts_; } @@ -608,167 +253,156 @@ class JustCount { private: uint64_t *const counts_, *const longest_counts_; + + const ProbBackoff *unigrams_; + + SRISucks &sri_; }; // Phase to actually write n-grams to the trie. template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), unigrams_(unigrams), middle_(middle), longest_(longest), - bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)) {} + bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)), + order_(order), + sri_(sri) {} - void Unigrams(WordIndex begin, WordIndex end) { - uint64_t next = bigram_pack_.InsertIndex(); - for (UnigramValue *i = unigrams_ + begin; i < unigrams_ + end; ++i) { - i->next = next; - } + float UnigramProb(WordIndex index) const { return unigrams_[index].weights.prob; } + + void Unigram(WordIndex word) { + unigrams_[word].next = bigram_pack_.InsertIndex(); } - void MiddleBlank(const unsigned char mid_idx, WordIndex key) { - middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff); + void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) { + ProbBackoff weights = sri_.GetBlank(order_, order, indices); + middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff); } - void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) { - // Order (mid_idx+2). - ContextReader &context = contexts_[mid_idx + 1]; - if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) { + void Middle(const unsigned char order, const void *data) { + RecordReader &context = contexts_[order - 1]; + const WordIndex *words = reinterpret_cast(data); + ProbBackoff weights = *reinterpret_cast(words + order); + if (context && !memcmp(data, context.Data(), sizeof(WordIndex) * order)) { SetExtension(weights.backoff); ++context; } - middle_[mid_idx].Insert(key, weights.prob, weights.backoff); + middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff); } - void Longest(WordIndex key, Prob prob) { - longest_.Insert(key, prob.prob); + void Longest(const void *data) { + const WordIndex *words = reinterpret_cast(data); + longest_.Insert(words[order_ - 1], reinterpret_cast(words + order_)->prob); } void Cleanup() {} private: - ContextReader *contexts_; + RecordReader *contexts_; UnigramValue *const unigrams_; BitPackedMiddle *const middle_; BitPackedLongest &longest_; BitPacked &bigram_pack_; + const unsigned char order_; + SRISucks &sri_; }; -template class RecursiveInsert { - public: - template RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : - doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { - } +struct Gram { + Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {} - // Outer unigram loop. - void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) { - util::ErsatzProgress progress(progress_out, message, unigram_count + 1); - for (words_[0] = 0; ; ++words_[0]) { - progress.Set(words_[0]); - WordIndex min_continue = unigram_count; - for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) { - if (other->Ended()) continue; - min_continue = std::min(min_continue, other->Header()[0]); - } - // This will write at unigram_count. This is by design so that the next pointers will make sense. - doing_.Unigrams(words_[0], min_continue + 1); - if (min_continue == unigram_count) break; - words_[0] = min_continue; - Middle(0); - } - doing_.Cleanup(); - } + const WordIndex *begin, *end; - private: - void Middle(const unsigned char mid_idx) { - // (mid_idx + 2)-gram. - if (mid_idx == order_minus_2_) { - Longest(); - return; - } - // Orders [2, order) + // For queue, this is the direction we want. + bool operator<(const Gram &other) const { + return std::lexicographical_compare(other.begin, other.end, begin, end); + } +}; - SortedFileReader &reader = inputs_[mid_idx]; +template class BlankManager { + public: + BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) { + for (float *i = basis_; i != basis_ + kMaxOrder - 1; ++i) *i = kBadProb; + } - if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) { - // This order doesn't have a header match, but longer ones might. - MiddleAllBlank(mid_idx); - return; + void Visit(const WordIndex *to, unsigned char length, float prob) { + basis_[length - 1] = prob; + unsigned char overlap = std::min(length - 1, been_length_); + const WordIndex *cur; + WordIndex *pre; + for (cur = to, pre = been_; cur != to + overlap; ++cur, ++pre) { + if (*pre != *cur) break; } - - // There is a header match. - WordIndex count = reader.ReadCount(); - WordIndex current = reader.ReadWord(); - while (count) { - WordIndex min_continue = std::numeric_limits::max(); - for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) - min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); - } - while (true) { - if (current > min_continue) { - doing_.MiddleBlank(mid_idx, min_continue); - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); - break; - } - ProbBackoff weights; - reader.ReadWeights(weights); - doing_.Middle(mid_idx, words_, current, weights); - --count; - if (current == min_continue) { - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); - if (count) current = reader.ReadWord(); - break; - } - if (!count) break; - current = reader.ReadWord(); - } + if (cur == to + length - 1) { + *pre = *cur; + been_length_ = length; + return; } - // Count is now zero. Finish off remaining blanks. - MiddleAllBlank(mid_idx); - reader.NextHeader(); - } - - void MiddleAllBlank(const unsigned char mid_idx) { - while (true) { - WordIndex min_continue = std::numeric_limits::max(); - for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) - min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); - } - if (min_continue == std::numeric_limits::max()) return; - doing_.MiddleBlank(mid_idx, min_continue); - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); + // There are blanks to insert starting with order blank. + unsigned char blank = cur - to + 1; + UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); + const float *lower_basis; + for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {} + unsigned char based_on = lower_basis - basis_ + 1; + for (; cur != to + length - 1; ++blank, ++cur, ++pre) { + assert(*lower_basis != kBadProb); + doing_.MiddleBlank(blank, to, based_on, *lower_basis); + *pre = *cur; + // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. + basis_[blank - 1] = kBadProb; } + been_length_ = length; } - void Longest() { - SortedFileReader &reader = *(inputs_end_ - 1); - if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return; - WordIndex count = reader.ReadCount(); - for (WordIndex i = 0; i < count; ++i) { - WordIndex word = reader.ReadWord(); - Prob prob; - reader.ReadWeights(prob); - doing_.Longest(word, prob); - } - reader.NextHeader(); - return; - } + private: + const unsigned char total_order_; - Doing doing_; + WordIndex been_[kMaxOrder]; + unsigned char been_length_; - SortedFileReader *inputs_; - SortedFileReader *inputs_end_; + float basis_[kMaxOrder]; + + Doing &doing_; +}; - WordIndex words_[kMaxOrder]; +template void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { + util::ErsatzProgress progress(progress_out, message, unigram_count + 1); + unsigned int unigram = 0; + std::priority_queue grams; + grams.push(Gram(&unigram, 1)); + for (unsigned char i = 2; i <= total_order; ++i) { + if (input[i-2]) grams.push(Gram(reinterpret_cast(input[i-2].Data()), i)); + } - const unsigned char order_minus_2_; -}; + BlankManager blank(total_order, doing); + + while (true) { + Gram top = grams.top(); + grams.pop(); + unsigned char order = top.end - top.begin; + if (order == 1) { + blank.Visit(&unigram, 1, doing.UnigramProb(unigram)); + doing.Unigram(unigram); + progress.Set(unigram); + if (++unigram == unigram_count + 1) break; + grams.push(top); + } else { + if (order == total_order) { + blank.Visit(top.begin, order, reinterpret_cast(top.end)->prob); + doing.Longest(top.begin); + } else { + blank.Visit(top.begin, order, reinterpret_cast(top.end)->prob); + doing.Middle(order, top.begin); + } + RecordReader &reader = input[order - 2]; + if (++reader) grams.push(top); + } + } + assert(grams.empty()); + doing.Cleanup(); +} void SanityCheckCounts(const std::vector &initial, const std::vector &fixed) { if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]); @@ -778,120 +412,122 @@ void SanityCheckCounts(const std::vector &initial, const std::vector void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - ProbBackoff weights; - std::vector probs, backoffs; - probs.reserve(count); +template void TrainQuantizer(uint8_t order, uint64_t count, const std::vector &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { + std::vector probs(additional), backoffs; + probs.reserve(count + additional.size()); backoffs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); - ++progress; - } + for (reader.Rewind(); reader; ++reader) { + const ProbBackoff &weights = *reinterpret_cast(reinterpret_cast(reader.Data()) + sizeof(WordIndex) * order); + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; } quant.Train(order, probs, backoffs); } -template void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - Prob weights; +template void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { std::vector probs, backoffs; probs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - ++progress; - } + for (reader.Rewind(); reader; ++reader) { + const Prob &weights = *reinterpret_cast(reinterpret_cast(reader.Data()) + sizeof(WordIndex) * order); + probs.push_back(weights.prob); + ++progress; } quant.TrainProb(order, probs); } +void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { + // Fill unigram probabilities. + try { + rewind(file); + for (WordIndex i = 0; i < unigram_count; ++i) { + ReadOrThrow(file, &unigrams[i].weights, sizeof(ProbBackoff)); + if (contexts && *reinterpret_cast(contexts.Data()) == i) { + SetExtension(unigrams[i].weights.backoff); + ++contexts; + } + } + } catch (util::Exception &e) { + e << " while re-reading unigram probabilities"; + throw; + } +} + } // namespace template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { - std::vector inputs(counts.size() - 1); - std::vector contexts(counts.size() - 1); + RecordReader inputs[kMaxOrder - 1]; + RecordReader contexts[kMaxOrder - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; assembled << file_prefix << static_cast(i) << "_merged"; - inputs[i-2].Init(assembled.str(), i); - RemoveOrThrow(assembled.str().c_str()); + inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); + util::RemoveOrThrow(assembled.str().c_str()); assembled << kContextSuffix; - contexts[i-2].Reset(assembled.str().c_str(), i-1); - RemoveOrThrow(assembled.str().c_str()); + contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex)); + util::RemoveOrThrow(assembled.str().c_str()); } + SRISucks sri; std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); + std::string temp(file_prefix); temp += "unigrams"; + util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str())); + util::scoped_memory unigrams; + MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); + FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast(unigrams.get()), sri); + RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); } - for (std::vector::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { - if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading"); + for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) { + if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; + util::scoped_FILE unigram_file; + { + std::string name(file_prefix + "unigrams"); + unigram_file.reset(OpenOrThrow(name.c_str(), "r")); + util::RemoveOrThrow(name.c_str()); + } + sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + for (unsigned char i = 2; i <= counts.size(); ++i) { + inputs[i-2].Rewind(); + } if (Quant::kTrain) { util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); for (unsigned char i = 2; i < counts.size(); ++i) { - TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); quant.FinishedLoading(config); } + UnigramValue *unigrams = out.unigram.Raw(); + PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams); + unigram_file.reset(); + for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); } - UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - inserter.Apply(config.messages, "Building trie", fixed_counts[0]); - } - - // Fill unigram probabilities. - try { - std::string name(file_prefix + "unigrams"); - util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); - for (WordIndex i = 0; i < counts[0]; ++i) { - ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); - if (contexts[0] && **contexts[0] == i) { - SetExtension(unigrams[i].weights.backoff); - ++contexts[0]; - } - } - RemoveOrThrow(name.c_str()); - } catch (util::Exception &e) { - e << " while re-reading unigram probabilities"; - throw; + WriteEntries writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri); + RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. for (unsigned char order = 2; order <= counts.size(); ++order) { - const ContextReader &context = contexts[order - 2]; + const RecordReader &context = contexts[order - 2]; if (context) { FormatLoadException e; - e << "An " << static_cast(order) << "-gram has the context (i.e. all but the last word):"; - for (const WordIndex *i = *context; i != *context + order - 1; ++i) { + e << "An " << static_cast(order) << "-gram has context"; + const WordIndex *ctx = reinterpret_cast(context.Data()); + for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) { e << ' ' << *i; } e << " so this context must appear in the model as a " << static_cast(order - 1) << "-gram but it does not"; @@ -945,6 +581,14 @@ template void TrieSearch::LoadedBin longest.LoadedBinary(); } +namespace { +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} +} // namespace + template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 2f39c09f..c3e02a98 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -1,10 +1,16 @@ #ifndef LM_SEARCH_TRIE__ #define LM_SEARCH_TRIE__ -#include "lm/binary_format.hh" +#include "lm/config.hh" +#include "lm/model_type.hh" +#include "lm/return.hh" #include "lm/trie.hh" #include "lm/weights.hh" +#include "util/file_piece.hh" + +#include + #include namespace lm { @@ -30,6 +36,8 @@ template class TrieSearch { static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); + static const unsigned int kVersion = 0; + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); @@ -57,12 +65,16 @@ template class TrieSearch { void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - unigram.Find(word, prob, backoff, node); + void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { + unigram.Find(word, ret.prob, backoff, node); + ret.independent_left = (node.begin == node.end); + ret.extend_left = static_cast(word); } - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); + bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { + if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false; + ret.independent_left = (node.begin == node.end); + return true; } bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { @@ -76,14 +88,25 @@ template class TrieSearch { bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { // TODO: don't decode backoff. assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + FullScoreReturn ignored; + float ignored_backoff; + LookupUnigram(*begin, ignored_backoff, node, ignored); for (const WordIndex *i = begin + 1; i < end; ++i) { if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; } return true; } + Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { + if (extend_length == 1) { + float ignored; + Node ret; + unigram.Find(static_cast(extend_pointer), prob, ignored, ret); + return ret; + } + return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob); + } + private: friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 8c536e66..4e60b184 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -86,7 +86,7 @@ template void BitPackedMiddle::Inse ++insert_index_; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; @@ -94,6 +94,9 @@ template bool BitPackedMiddle::Find uint64_t index = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; + + pointer = at_pointer; + quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 53612064..a9f5e417 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -94,10 +94,18 @@ template class BitPackedMiddle : public BitPacked { void LoadedBinary() { bhiksha_.LoadedBinary(); } - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; + bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const; bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; + NodeRange ReadEntry(uint64_t pointer, float &prob) { + quant_.ReadProb(base_, pointer, prob); + NodeRange ret; + // pointer/total_bits_ should always round down. + bhiksha_.ReadNext(base_, pointer + quant_.TotalBits(), pointer / total_bits_, total_bits_, ret); + return ret; + } + private: Quant quant_; Bhiksha bhiksha_; diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc new file mode 100644 index 00000000..01c4e490 --- /dev/null +++ b/klm/lm/trie_sort.cc @@ -0,0 +1,261 @@ +#include "lm/trie_sort.hh" + +#include "lm/config.hh" +#include "lm/lm_exception.hh" +#include "lm/read_arpa.hh" +#include "lm/vocab.hh" +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/file_piece.hh" +#include "util/mmap.hh" +#include "util/proxy_iterator.hh" +#include "util/sized_iterator.hh" + +#include +#include +#include +#include +#include +#include + +namespace lm { +namespace ngram { +namespace trie { + +const char *kContextSuffix = "_contexts"; + +FILE *OpenOrThrow(const char *name, const char *mode) { + FILE *ret = fopen(name, mode); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); + return ret; +} + +void WriteOrThrow(FILE *to, const void *data, size_t size) { + assert(size); + if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +} + +namespace { + +typedef util::SizedIterator NGramIter; + +// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. +class PartialViewProxy { + public: + PartialViewProxy() : attention_size_(0), inner_() {} + + PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} + + operator std::string() const { + return std::string(reinterpret_cast(inner_.Data()), attention_size_); + } + + PartialViewProxy &operator=(const PartialViewProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), attention_size_); + return *this; + } + + PartialViewProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), attention_size_); + return *this; + } + + const void *Data() const { return inner_.Data(); } + void *Data() { return inner_.Data(); } + + private: + friend class util::ProxyIterator; + + typedef std::string value_type; + + const std::size_t attention_size_; + + typedef util::SizedInnerIterator InnerIterator; + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +typedef util::ProxyIterator PartialIter; + +std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order) { + std::stringstream assembled; + assembled << file_prefix << static_cast(order) << '_' << batch; + std::string ret(assembled.str()); + util::scoped_fd out(util::CreateOrThrow(ret.c_str())); + util::WriteOrThrow(out.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); + return ret; +} + +void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { + const size_t context_size = sizeof(WordIndex) * (order - 1); + // Sort just the contexts using the same memory. + PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); + PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); + + std::sort(context_begin, context_end, util::SizedCompare(EntryCompare(order - 1))); + + std::string name(ngram_file_name + kContextSuffix); + util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); + + // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. + if (context_begin == context_end) return; + PartialIter i(context_begin); + WriteOrThrow(out.get(), i->Data(), context_size); + const void *previous = i->Data(); + ++i; + for (; i != context_end; ++i) { + if (memcmp(previous, i->Data(), context_size)) { + WriteOrThrow(out.get(), i->Data(), context_size); + previous = i->Data(); + } + } +} + +struct ThrowCombine { + void operator()(std::size_t /*entry_size*/, const void * /*first*/, const void * /*second*/, FILE * /*out*/) const { + UTIL_THROW(FormatLoadException, "Duplicate n-gram detected."); + } +}; + +// Useful for context files that just contain records with no value. +struct FirstCombine { + void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { + WriteOrThrow(out, first, entry_size); + } +}; + +template void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order, const Combine &combine = ThrowCombine()) { + std::size_t entry_size = sizeof(WordIndex) * order + weights_size; + RecordReader first, second; + first.Init(first_name.c_str(), entry_size); + util::RemoveOrThrow(first_name.c_str()); + second.Init(second_name.c_str(), entry_size); + util::RemoveOrThrow(second_name.c_str()); + util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); + EntryCompare less(order); + while (first && second) { + if (less(first.Data(), second.Data())) { + WriteOrThrow(out_file.get(), first.Data(), entry_size); + ++first; + } else if (less(second.Data(), first.Data())) { + WriteOrThrow(out_file.get(), second.Data(), entry_size); + ++second; + } else { + combine(entry_size, first.Data(), second.Data(), out_file.get()); + ++first; ++second; + } + } + for (RecordReader &remains = (first ? second : first); remains; ++remains) { + WriteOrThrow(out_file.get(), remains.Data(), entry_size); + } +} + +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { + ReadNGramHeader(f, order); + const size_t count = counts[order - 1]; + // Size of weights. Does it include backoff? + const size_t words_size = sizeof(WordIndex) * order; + const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); + const size_t entry_size = words_size + weights_size; + const size_t batch_size = std::min(count, mem.size() / entry_size); + uint8_t *const begin = reinterpret_cast(mem.get()); + std::deque files; + for (std::size_t batch = 0, done = 0; done < count; ++batch) { + uint8_t *out = begin; + uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; + if (order == counts.size()) { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + } + } else { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + } + } + // Sort full records by full n-gram. + util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); + // parallel_sort uses too much RAM + std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare(EntryCompare(order))); + files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order)); + WriteContextFile(begin, out_end, files.back(), entry_size, order); + + done += (out_end - begin) / entry_size; + } + + // All individual files created. Merge them. + + std::size_t merge_count = 0; + while (files.size() > 1) { + std::stringstream assembled; + assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); + files.push_back(assembled.str()); + MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); + MergeSortedFiles(files[0], files[1], files.back(), 0, order, FirstCombine()); + files.pop_front(); + files.pop_front(); + } + if (!files.empty()) { + std::stringstream assembled; + assembled << file_prefix << static_cast(order) << "_merged"; + std::string merged_name(assembled.str()); + if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); + std::string context_name = files[0] + kContextSuffix; + merged_name += kContextSuffix; + if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); + } +} + +} // namespace + +void RecordReader::Init(const std::string &name, std::size_t entry_size) { + file_.reset(OpenOrThrow(name.c_str(), "r+")); + data_.reset(malloc(entry_size)); + UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer"); + remains_ = true; + entry_size_ = entry_size; + ++*this; +} + +void RecordReader::Overwrite(const void *start, std::size_t amount) { + long internal = (uint8_t*)start - (uint8_t*)data_.get(); + UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); + WriteOrThrow(file_.get(), start, amount); + long forward = entry_size_ - internal - amount; + if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); +} + +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + PositiveProbWarn warn(config.positive_log_probability); + { + std::string unigram_name = file_prefix + "unigrams"; + util::scoped_fd unigram_file; + // In case appears. + size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); + Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); + CheckSpecials(config, vocab); + if (!vocab.SawUnk()) ++counts[0]; + } + + // Only use as much buffer as we need. + size_t buffer_use = 0; + for (unsigned int order = 2; order < counts.size(); ++order) { + buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); + } + buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); + buffer = std::min(buffer, buffer_use); + + util::scoped_memory mem; + mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); + if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + + for (unsigned char order = 2; order <= counts.size(); ++order) { + ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); + } + ReadEnd(f); +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh new file mode 100644 index 00000000..a6916483 --- /dev/null +++ b/klm/lm/trie_sort.hh @@ -0,0 +1,94 @@ +#ifndef LM_TRIE_SORT__ +#define LM_TRIE_SORT__ + +#include "lm/word_index.hh" + +#include "util/file.hh" +#include "util/scoped.hh" + +#include +#include +#include +#include + +#include + +namespace util { class FilePiece; } + +// Step of trie builder: create sorted files. +namespace lm { +namespace ngram { +class SortedVocabulary; +class Config; + +namespace trie { + +extern const char *kContextSuffix; +FILE *OpenOrThrow(const char *name, const char *mode); +void WriteOrThrow(FILE *to, const void *data, size_t size); + +class EntryCompare : public std::binary_function { + public: + explicit EntryCompare(unsigned char order) : order_(order) {} + + bool operator()(const void *first_void, const void *second_void) const { + const WordIndex *first = static_cast(first_void); + const WordIndex *second = static_cast(second_void); + const WordIndex *end = first + order_; + for (; first != end; ++first, ++second) { + if (*first < *second) return true; + if (*first > *second) return false; + } + return false; + } + private: + unsigned char order_; +}; + +class RecordReader { + public: + RecordReader() : remains_(true) {} + + void Init(const std::string &name, std::size_t entry_size); + + void *Data() { return data_.get(); } + const void *Data() const { return data_.get(); } + + RecordReader &operator++() { + std::size_t ret = fread(data_.get(), entry_size_, 1, file_.get()); + if (!ret) { + UTIL_THROW_IF(!feof(file_.get()), util::ErrnoException, "Error reading temporary file"); + remains_ = false; + } + return *this; + } + + operator bool() const { return remains_; } + + void Rewind() { + rewind(file_.get()); + remains_ = true; + ++*this; + } + + std::size_t EntrySize() const { return entry_size_; } + + void Overwrite(const void *start, std::size_t amount); + + private: + util::scoped_malloc data_; + + bool remains_; + + std::size_t entry_size_; + + util::scoped_FILE file_; +}; + +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab); + +} // namespace trie +} // namespace ngram +} // namespace lm + +#endif // LM_TRIE_SORT__ diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 08627efd..6a5a0196 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -1,37 +1,13 @@ #ifndef LM_VIRTUAL_INTERFACE__ #define LM_VIRTUAL_INTERFACE__ +#include "lm/return.hh" #include "lm/word_index.hh" #include "util/string_piece.hh" #include namespace lm { - -/* Structure returned by scoring routines. */ -struct FullScoreReturn { - // log10 probability - float prob; - - /* The length of n-gram matched. Do not use this for recombination. - * Consider a model containing only the following n-grams: - * -1 foo - * -3.14 bar - * -2.718 baz -5 - * -6 foo bar - * - * If you score ``bar'' then ngram_length is 1 and recombination state is the - * empty string because bar has zero backoff and does not extend to the - * right. - * If you score ``foo'' then ngram_length is 1 and recombination state is - * ``foo''. - * - * Ideally, keep output states around and compare them. Failing that, - * get out_state.ValidLength() and use that length for recombination. - */ - unsigned char ngram_length; -}; - namespace base { template class ModelFacade; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 04979d51..03b0767a 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -1,5 +1,6 @@ #include "lm/vocab.hh" +#include "lm/binary_format.hh" #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/config.hh" @@ -56,16 +57,6 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { } } -void WriteOrThrow(int fd, const void *data_void, std::size_t size) { - const uint8_t *data = static_cast(data_void); - while (size) { - ssize_t ret = write(fd, data, size); - if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); - data += ret; - size -= ret; - } -} - } // namespace WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {} @@ -80,7 +71,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { void WriteWordsWrapper::Write(int fd) { if ((off_t)-1 == lseek(fd, 0, SEEK_END)) UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words"); - WriteOrThrow(fd, buffer_.data(), buffer_.size()); + util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); } SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} @@ -146,15 +137,28 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { SetSpecial(Index(""), Index(""), 0); } +namespace { +const unsigned int kProbingVocabularyVersion = 0; +} // namespace + +namespace detail { +struct ProbingVocabularyHeader { + // Lowest unused vocab id. This is also the number of words, including . + unsigned int version; + WordIndex bound; +}; +} // namespace detail + ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { - return Lookup::Size(entries, config.probing_multiplier); + return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { - lookup_ = Lookup(start, allocated); - available_ = 1; + header_ = static_cast(start); + lookup_ = Lookup(static_cast(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated); + bound_ = 1; saw_unk_ = false; } @@ -172,20 +176,24 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { saw_unk_ = true; return 0; } else { - if (enumerate_) enumerate_->Add(available_, str); - lookup_.Insert(Lookup::Packing::Make(hashed, available_)); - return available_++; + if (enumerate_) enumerate_->Add(bound_, str); + lookup_.Insert(Lookup::Packing::Make(hashed, bound_)); + return bound_++; } } void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { lookup_.FinishedInserting(); + header_->bound = bound_; + header_->version = kProbingVocabularyVersion; SetSpecial(Index(""), Index(""), 0); } void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { + UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); lookup_.LoadedBinary(); - available_ = ReadWords(fd, to); + ReadWords(fd, to); + bound_ = header_->bound; SetSpecial(Index(""), Index(""), 0); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 9d218fff..41e97052 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -25,6 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len); inline uint64_t HashForVocab(const StringPiece &str) { return HashForVocab(str.data(), str.length()); } +class ProbingVocabularyHeader; } // namespace detail class WriteWordsWrapper : public EnumerateVocab { @@ -113,10 +114,7 @@ class ProbingVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()). - // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. - // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. - // Specifically, the binary file format does not currently indicate whether is in count or not. - WordIndex Bound() const { return available_; } + WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); @@ -141,11 +139,13 @@ class ProbingVocabulary : public base::Vocabulary { Lookup lookup_; - WordIndex available_; + WordIndex bound_; bool saw_unk_; EnumerateVocab *enumerate_; + + detail::ProbingVocabularyHeader *header_; }; void MissingUnknown(const Config &config) throw(SpecialWordMissingException); -- cgit v1.2.3