From 149232c38eec558ddb1097698d1570aacb67b59f Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 16 May 2012 13:24:08 -0700 Subject: Big kenlm change includes lower order models for probing only. And other stuff. --- klm/lm/Jamfile | 2 +- klm/lm/Makefile.am | 1 + klm/lm/binary_format.cc | 2 +- klm/lm/build_binary.cc | 97 ++++++++++------ klm/lm/config.cc | 1 + klm/lm/config.hh | 22 +++- klm/lm/left.hh | 110 +++++-------------- klm/lm/left_test.cc | 83 +++++++++----- klm/lm/max_order.hh | 2 +- klm/lm/model.cc | 192 +++++++++++++++++--------------- klm/lm/model.hh | 93 ++++++---------- klm/lm/model_test.cc | 42 ++++--- klm/lm/model_type.hh | 13 ++- klm/lm/ngram_query.cc | 18 +-- klm/lm/ngram_query.hh | 47 ++------ klm/lm/quantize.cc | 20 +++- klm/lm/quantize.hh | 164 ++++++++++++++------------- klm/lm/read_arpa.cc | 8 +- klm/lm/read_arpa.hh | 14 ++- klm/lm/return.hh | 3 + klm/lm/search_hashed.cc | 243 +++++++++++++++++++++++++++++------------ klm/lm/search_hashed.hh | 229 ++++++++++++++++++-------------------- klm/lm/search_trie.cc | 38 ++++--- klm/lm/search_trie.hh | 71 ++++++------ klm/lm/state.hh | 123 +++++++++++++++++++++ klm/lm/trie.cc | 61 ++++------- klm/lm/trie.hh | 61 ++++++----- klm/lm/value.hh | 157 ++++++++++++++++++++++++++ klm/lm/value_build.cc | 58 ++++++++++ klm/lm/value_build.hh | 97 ++++++++++++++++ klm/lm/vocab.cc | 2 +- klm/lm/vocab.hh | 6 +- klm/lm/weights.hh | 5 + klm/util/Jamfile | 2 +- klm/util/Makefile.am | 3 +- klm/util/bit_packing.hh | 7 ++ klm/util/ersatz_progress.cc | 8 +- klm/util/ersatz_progress.hh | 4 +- klm/util/file.cc | 10 -- klm/util/file.hh | 3 - klm/util/file_piece.cc | 17 ++- klm/util/file_piece.hh | 10 +- klm/util/have.hh | 10 +- klm/util/mmap.cc | 14 +++ klm/util/murmur_hash.cc | 11 +- klm/util/murmur_hash.hh | 6 +- klm/util/probing_hash_table.hh | 21 ++++ klm/util/usage.cc | 46 ++++++++ klm/util/usage.hh | 8 ++ 49 files changed, 1460 insertions(+), 805 deletions(-) create mode 100644 klm/lm/state.hh create mode 100644 klm/lm/value.hh create mode 100644 klm/lm/value_build.cc create mode 100644 klm/lm/value_build.hh create mode 100644 klm/util/usage.cc create mode 100644 klm/util/usage.hh diff --git a/klm/lm/Jamfile b/klm/lm/Jamfile index b84dbb35..b1971d88 100644 --- a/klm/lm/Jamfile +++ b/klm/lm/Jamfile @@ -1,4 +1,4 @@ -lib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quantize.cc read_arpa.cc search_hashed.cc search_trie.cc trie.cc trie_sort.cc virtual_interface.cc vocab.cc ../util//kenutil : .. : : .. ../util//kenutil ; +lib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quantize.cc read_arpa.cc search_hashed.cc search_trie.cc trie.cc trie_sort.cc value_build.cc virtual_interface.cc vocab.cc ../util//kenutil : .. : : .. ../util//kenutil ; import testing ; diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 54fd7f68..a12c5f03 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -24,6 +24,7 @@ libklm_a_SOURCES = \ search_trie.cc \ trie.cc \ trie_sort.cc \ + value_build.cc \ virtual_interface.cc \ vocab.cc diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 4796f6d1..a56e998e 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -57,7 +57,7 @@ struct Sanity { } }; -const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; +const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; std::size_t TotalHeaderSize(unsigned char order) { return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 8cbb69d0..c4a01cb4 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -66,16 +66,28 @@ uint8_t ParseBitCount(const char *from) { return val; } +void ParseFileList(const char *from, std::vector &to) { + to.clear(); + while (true) { + const char *i; + for (i = from; *i && *i != ' '; ++i) {} + to.push_back(std::string(from, i - from)); + if (!*i) break; + from = i + 1; + } +} + void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t sizes[5]; + std::size_t sizes[6]; sizes[0] = ProbingModel::Size(counts, config); - sizes[1] = TrieModel::Size(counts, config); - sizes[2] = QuantTrieModel::Size(counts, config); - sizes[3] = ArrayTrieModel::Size(counts, config); - sizes[4] = QuantArrayTrieModel::Size(counts, config); + sizes[1] = RestProbingModel::Size(counts, config); + sizes[2] = TrieModel::Size(counts, config); + sizes[3] = QuantTrieModel::Size(counts, config); + sizes[4] = ArrayTrieModel::Size(counts, config); + sizes[5] = QuantArrayTrieModel::Size(counts, config); std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); std::size_t divide; @@ -99,10 +111,11 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { for (long int i = 0; i < length - 2; ++i) std::cout << ' '; std::cout << prefix << "B\n" "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" - "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" - "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" - "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" - "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; + "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r -p " << config.probing_multiplier << "\n" + "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n" + "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" + "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" + "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; } void ProbingQuantizationUnsupported() { @@ -118,10 +131,10 @@ int main(int argc, char *argv[]) { using namespace lm::ngram; try { - bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false; + bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false; lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:si")) != -1) { + while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:sir:")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -164,6 +177,11 @@ int main(int argc, char *argv[]) { case 'i': config.positive_log_probability = lm::SILENT; break; + case 'r': + rest = true; + ParseFileList(optarg, config.rest_lower_files); + config.rest_function = Config::REST_LOWER; + break; default: Usage(argv[0]); } @@ -174,35 +192,48 @@ int main(int argc, char *argv[]) { } if (optind + 1 == argc) { ShowSizes(argv[optind], config); - } else if (optind + 2 == argc) { + return 0; + } + const char *model_type; + const char *from_file; + + if (optind + 2 == argc) { + model_type = "probing"; + from_file = argv[optind]; config.write_mmap = argv[optind + 1]; - if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); - ProbingModel(argv[optind], config); } else if (optind + 3 == argc) { - const char *model_type = argv[optind]; - const char *from_file = argv[optind + 1]; + model_type = argv[optind]; + from_file = argv[optind + 1]; config.write_mmap = argv[optind + 2]; - if (!strcmp(model_type, "probing")) { - if (!set_write_method) config.write_method = Config::WRITE_AFTER; - if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); + } else { + Usage(argv[0]); + } + if (!strcmp(model_type, "probing")) { + if (!set_write_method) config.write_method = Config::WRITE_AFTER; + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); + if (rest) { + RestProbingModel(from_file, config); + } else { ProbingModel(from_file, config); - } else if (!strcmp(model_type, "trie")) { - if (!set_write_method) config.write_method = Config::WRITE_MMAP; - if (quantize) { - if (bhiksha) { - QuantArrayTrieModel(from_file, config); - } else { - QuantTrieModel(from_file, config); - } + } + } else if (!strcmp(model_type, "trie")) { + if (rest) { + std::cerr << "Rest + trie is not supported yet." << std::endl; + return 1; + } + if (!set_write_method) config.write_method = Config::WRITE_MMAP; + if (quantize) { + if (bhiksha) { + QuantArrayTrieModel(from_file, config); } else { - if (bhiksha) { - ArrayTrieModel(from_file, config); - } else { - TrieModel(from_file, config); - } + QuantTrieModel(from_file, config); } } else { - Usage(argv[0]); + if (bhiksha) { + ArrayTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } } else { Usage(argv[0]); diff --git a/klm/lm/config.cc b/klm/lm/config.cc index dbe762b3..f9d988ca 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -19,6 +19,7 @@ Config::Config() : write_mmap(NULL), write_method(WRITE_AFTER), include_vocab(true), + rest_function(REST_MAX), prob_bits(8), backoff_bits(8), pointer_bhiksha_bits(22), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 01b75632..739cee9c 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -1,11 +1,13 @@ #ifndef LM_CONFIG__ #define LM_CONFIG__ -#include - #include "lm/lm_exception.hh" #include "util/mmap.hh" +#include +#include +#include + /* Configuration for ngram model. Separate header to reduce pollution. */ namespace lm { @@ -63,23 +65,33 @@ struct Config { const char *temporary_directory_prefix; // Level of complaining to do when loading from ARPA instead of binary format. - typedef enum {ALL, EXPENSIVE, NONE} ARPALoadComplain; + enum ARPALoadComplain {ALL, EXPENSIVE, NONE}; ARPALoadComplain arpa_complain; // While loading an ARPA file, also write out this binary format file. Set // to NULL to disable. const char *write_mmap; - typedef enum { + enum WriteMethod { WRITE_MMAP, // Map the file directly. WRITE_AFTER // Write after we're done. - } WriteMethod; + }; WriteMethod write_method; // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; + // Left rest options. Only used when the model includes rest costs. + enum RestFunction { + REST_MAX, // Maximum of any score to the left + REST_LOWER, // Use lower-order files given below. + }; + RestFunction rest_function; + // Only used for REST_LOWER. + std::vector rest_lower_files; + + // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used diff --git a/klm/lm/left.hh b/klm/lm/left.hh index a07f9803..c00af88a 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -39,7 +39,7 @@ #define LM_LEFT__ #include "lm/max_order.hh" -#include "lm/model.hh" +#include "lm/state.hh" #include "lm/return.hh" #include "util/murmur_hash.hh" @@ -49,72 +49,6 @@ namespace lm { namespace ngram { -struct Left { - bool operator==(const Left &other) const { - return - (length == other.length) && - pointers[length - 1] == other.pointers[length - 1]; - } - - int Compare(const Left &other) const { - if (length != other.length) return length < other.length ? -1 : 1; - if (pointers[length - 1] > other.pointers[length - 1]) return 1; - if (pointers[length - 1] < other.pointers[length - 1]) return -1; - return 0; - } - - bool operator<(const Left &other) const { - if (length != other.length) return length < other.length; - return pointers[length - 1] < other.pointers[length - 1]; - } - - void ZeroRemaining() { - for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) - *i = 0; - } - - unsigned char length; - uint64_t pointers[kMaxOrder - 1]; -}; - -inline size_t hash_value(const Left &left) { - return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]); -} - -struct ChartState { - bool operator==(const ChartState &other) { - return (left == other.left) && (right == other.right) && (full == other.full); - } - - int Compare(const ChartState &other) const { - int lres = left.Compare(other.left); - if (lres) return lres; - int rres = right.Compare(other.right); - if (rres) return rres; - return (int)full - (int)other.full; - } - - bool operator<(const ChartState &other) const { - return Compare(other) == -1; - } - - void ZeroRemaining() { - left.ZeroRemaining(); - right.ZeroRemaining(); - } - - Left left; - bool full; - State right; -}; - -inline size_t hash_value(const ChartState &state) { - size_t hashes[2]; - hashes[0] = hash_value(state.left); - hashes[1] = hash_value(state.right); - return util::MurmurHashNative(hashes, sizeof(size_t) * 2, state.full); -} - template class RuleScore { public: explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { @@ -131,29 +65,30 @@ template class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); - prob_ += ret.prob; - if (left_done_) return; + if (left_done_) { prob_ += ret.prob; return; } if (ret.independent_left) { + prob_ += ret.prob; left_done_ = true; return; } out_.left.pointers[out_.left.length++] = ret.extend_left; + prob_ += ret.rest; if (out_.right.length != copy.length + 1) left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. - void BeginNonTerminal(const ChartState &in, float prob) { + void BeginNonTerminal(const ChartState &in, float prob = 0.0) { prob_ = prob; out_ = in; - left_done_ = in.full; + left_done_ = in.left.full; } - void NonTerminal(const ChartState &in, float prob) { + void NonTerminal(const ChartState &in, float prob = 0.0) { prob_ += prob; if (!in.left.length) { - if (in.full) { + if (in.left.full) { for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; left_done_ = true; out_.right = in.right; @@ -163,12 +98,15 @@ template class RuleScore { if (!out_.right.length) { out_.right = in.right; - if (left_done_) return; + if (left_done_) { + prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1); + return; + } if (out_.left.length) { left_done_ = true; } else { out_.left = in.left; - left_done_ = in.full; + left_done_ = in.left.full; } return; } @@ -186,7 +124,7 @@ template class RuleScore { std::swap(back, back2); } - if (in.full) { + if (in.left.full) { for (const float *i = back; i != back + next_use; ++i) prob_ += *i; left_done_ = true; out_.right = in.right; @@ -213,10 +151,17 @@ template class RuleScore { float Finish() { // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. - out_.full = left_done_ || (out_.left.length == model_.Order() - 1); + out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1); return prob_; } + void Reset() { + prob_ = 0.0; + left_done_ = false; + out_.left.length = 0; + out_.right.length = 0; + } + private: bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { ProcessRet(model_.ExtendLeft( @@ -228,8 +173,9 @@ template class RuleScore { if (next_use != out_.right.length) { left_done_ = true; if (!next_use) { - out_.right = in.right; // Early exit. + out_.right = in.right; + prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1); return true; } } @@ -238,13 +184,17 @@ template class RuleScore { } void ProcessRet(const FullScoreReturn &ret) { - prob_ += ret.prob; - if (left_done_) return; + if (left_done_) { + prob_ += ret.prob; + return; + } if (ret.independent_left) { + prob_ += ret.prob; left_done_ = true; return; } out_.left.pointers[out_.left.length++] = ret.extend_left; + prob_ += ret.rest; } const M &model_; diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc index c85e5efa..b23e6a0f 100644 --- a/klm/lm/left_test.cc +++ b/klm/lm/left_test.cc @@ -24,7 +24,7 @@ template void Short(const M &m) { Term("loin"); BOOST_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001); } - BOOST_CHECK(base.full); + BOOST_CHECK(base.left.full); BOOST_CHECK_EQUAL(2, base.left.length); BOOST_CHECK_EQUAL(1, base.right.length); VCheck("loin", base.right.words[0]); @@ -40,7 +40,7 @@ template void Short(const M &m) { BOOST_CHECK_EQUAL(3, more_left.left.length); BOOST_CHECK_EQUAL(1, more_left.right.length); VCheck("loin", more_left.right.words[0]); - BOOST_CHECK(more_left.full); + BOOST_CHECK(more_left.left.full); ChartState shorter; { @@ -52,7 +52,7 @@ template void Short(const M &m) { BOOST_CHECK_EQUAL(1, shorter.left.length); BOOST_CHECK_EQUAL(1, shorter.right.length); VCheck("loin", shorter.right.words[0]); - BOOST_CHECK(shorter.full); + BOOST_CHECK(shorter.left.full); } template void Charge(const M &m) { @@ -66,7 +66,7 @@ template void Charge(const M &m) { BOOST_CHECK_EQUAL(1, base.left.length); BOOST_CHECK_EQUAL(1, base.right.length); VCheck("more", base.right.words[0]); - BOOST_CHECK(base.full); + BOOST_CHECK(base.left.full); ChartState extend; { @@ -78,7 +78,7 @@ template void Charge(const M &m) { BOOST_CHECK_EQUAL(2, extend.left.length); BOOST_CHECK_EQUAL(1, extend.right.length); VCheck("more", extend.right.words[0]); - BOOST_CHECK(extend.full); + BOOST_CHECK(extend.left.full); ChartState tobos; { @@ -91,9 +91,9 @@ template void Charge(const M &m) { BOOST_CHECK_EQUAL(1, tobos.right.length); } -template float LeftToRight(const M &m, const std::vector &words) { +template float LeftToRight(const M &m, const std::vector &words, bool begin_sentence = false) { float ret = 0.0; - State right = m.NullContextState(); + State right = begin_sentence ? m.BeginSentenceState() : m.NullContextState(); for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { State copy(right); ret += m.Score(copy, *i, right); @@ -101,12 +101,12 @@ template float LeftToRight(const M &m, const std::vector &w return ret; } -template float RightToLeft(const M &m, const std::vector &words) { +template float RightToLeft(const M &m, const std::vector &words, bool begin_sentence = false) { float ret = 0.0; ChartState state; state.left.length = 0; state.right.length = 0; - state.full = false; + state.left.full = false; for (std::vector::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) { ChartState copy(state); RuleScore score(m, state); @@ -114,10 +114,17 @@ template float RightToLeft(const M &m, const std::vector &w score.NonTerminal(copy, ret); ret = score.Finish(); } + if (begin_sentence) { + ChartState copy(state); + RuleScore score(m, state); + score.BeginSentence(); + score.NonTerminal(copy, ret); + ret = score.Finish(); + } return ret; } -template float TreeMiddle(const M &m, const std::vector &words) { +template float TreeMiddle(const M &m, const std::vector &words, bool begin_sentence = false) { std::vector > states(words.size()); for (unsigned int i = 0; i < words.size(); ++i) { RuleScore score(m, states[i].first); @@ -137,7 +144,19 @@ template float TreeMiddle(const M &m, const std::vector &wo } std::swap(states, upper); } - return states.empty() ? 0 : states.back().second; + + if (states.empty()) return 0.0; + + if (begin_sentence) { + ChartState ignored; + RuleScore score(m, ignored); + score.BeginSentence(); + score.NonTerminal(states.front().first, states.front().second); + return score.Finish(); + } else { + return states.front().second; + } + } template void LookupVocab(const M &m, const StringPiece &str, std::vector &out) { @@ -148,16 +167,15 @@ template void LookupVocab(const M &m, const StringPiece &str, std::vec } #define TEXT_TEST(str) \ -{ \ - std::vector words; \ LookupVocab(m, str, words); \ - float expect = LeftToRight(m, words); \ - BOOST_CHECK_CLOSE(expect, RightToLeft(m, words), 0.001); \ - BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words), 0.001); \ -} + expect = LeftToRight(m, words, rest); \ + BOOST_CHECK_CLOSE(expect, RightToLeft(m, words, rest), 0.001); \ + BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words, rest), 0.001); \ // Build sentences, or parts thereof, from right to left. -template void GrowBig(const M &m) { +template void GrowBig(const M &m, bool rest = false) { + std::vector words; + float expect; TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown "); TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown "); TEXT_TEST("on a little more loin also would consider higher to look good"); @@ -171,6 +189,14 @@ template void GrowBig(const M &m) { TEXT_TEST("consider higher"); } +template void GrowSmall(const M &m, bool rest = false) { + std::vector words; + float expect; + TEXT_TEST("in biarritz watching considering looking . "); + TEXT_TEST("in biarritz watching considering looking ."); + TEXT_TEST("in biarritz"); +} + template void AlsoWouldConsiderHigher(const M &m) { ChartState also; { @@ -210,7 +236,7 @@ template void AlsoWouldConsiderHigher(const M &m) { } BOOST_CHECK_EQUAL(1, consider.left.length); BOOST_CHECK_EQUAL(1, consider.right.length); - BOOST_CHECK(!consider.full); + BOOST_CHECK(!consider.left.full); ChartState higher; float higher_score; @@ -222,7 +248,7 @@ template void AlsoWouldConsiderHigher(const M &m) { BOOST_CHECK_CLOSE(-1.509559, higher_score, 0.001); BOOST_CHECK_EQUAL(1, higher.left.length); BOOST_CHECK_EQUAL(1, higher.right.length); - BOOST_CHECK(!higher.full); + BOOST_CHECK(!higher.left.full); VCheck("higher", higher.right.words[0]); BOOST_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001); @@ -234,7 +260,7 @@ template void AlsoWouldConsiderHigher(const M &m) { BOOST_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001); } BOOST_CHECK_EQUAL(2, consider_higher.left.length); - BOOST_CHECK(!consider_higher.full); + BOOST_CHECK(!consider_higher.left.full); ChartState full; { @@ -246,12 +272,6 @@ template void AlsoWouldConsiderHigher(const M &m) { BOOST_CHECK_EQUAL(4, full.right.length); } -template void GrowSmall(const M &m) { - TEXT_TEST("in biarritz watching considering looking . "); - TEXT_TEST("in biarritz watching considering looking ."); - TEXT_TEST("in biarritz"); -} - #define CHECK_SCORE(str, val) \ { \ float got = val; \ @@ -315,7 +335,7 @@ template void FullGrow(const M &m) { CHECK_SCORE("looking . ", l2_scores[1] = score.Finish()); } BOOST_CHECK_EQUAL(l2[1].left.length, 1); - BOOST_CHECK(l2[1].full); + BOOST_CHECK(l2[1].left.full); ChartState top; { @@ -362,6 +382,13 @@ BOOST_AUTO_TEST_CASE(ArrayTrieAll) { Everything(); } +BOOST_AUTO_TEST_CASE(RestProbing) { + Config config; + config.messages = NULL; + RestProbingModel m(FileLocation(), config); + GrowBig(m, true); +} + } // namespace } // namespace ngram } // namespace lm diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 71cd23dd..aff9de27 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -6,7 +6,7 @@ namespace ngram { // Having this limit means that State can be // (kMaxOrder - 1) * sizeof(float) bytes instead of // sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead -const unsigned char kMaxOrder = 6; +const unsigned char kMaxOrder = 5; } // namespace ngram } // namespace lm diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 478ebed1..c081788c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -38,10 +38,13 @@ template GenericModel::Ge State begin_sentence = State(); begin_sentence.length = 1; begin_sentence.words[0] = vocab_.BeginSentence(); - begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff; + typename Search::Node ignored_node; + bool ignored_independent_left; + uint64_t ignored_extend_left; + begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); State null_context = State(); null_context.length = 0; - P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); + P::Init(begin_sentence, null_context, vocab_, search_.Order()); } template void GenericModel::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { @@ -50,6 +53,9 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, const Config &config) { // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. util::FilePiece f(backing_.file.release(), file, config.messages); @@ -79,8 +85,8 @@ template void GenericModel FullScoreReturn GenericModel(start)) return ret; + + bool independent_left; + uint64_t extend_left; + typename Search::Node node; if (start <= 1) { - ret.prob += search_.unigram.Lookup(*context_rbegin).backoff; + ret.prob += search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff(); start = 2; - } - typename Search::Node node; - if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { + } else if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { return ret; } - float backoff; // i is the order of the backoff we're looking for. - typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2; - for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) { - if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break; - ret.prob += backoff; + unsigned char order_minus_2 = 0; + for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++order_minus_2) { + typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left)); + if (!p.Found()) break; + ret.prob += p.Backoff(); } return ret; } @@ -134,17 +142,20 @@ template void GenericModel FullScoreReturn GenericModel(extend_pointer), node, ret.independent_left, ret.extend_left)); + ret.rest = ptr.Rest(); + ret.prob = ptr.Prob(); + assert(!ret.independent_left); } else { - ret.ngram_length = P::Order(); + typename Search::MiddlePointer ptr(search_.Unpack(extend_pointer, extend_length, node)); + ret.rest = ptr.Rest(); + ret.prob = ptr.Prob(); + ret.extend_left = extend_pointer; + // If this function is called, then it does depend on left words. + ret.independent_left = false; } - ret.independent_left = true; + float subtract_me = ret.rest; + ret.ngram_length = extend_length; + next_use = extend_length; + ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret); + next_use -= extend_length; + // Charge backoffs. + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; ret.prob -= subtract_me; + ret.rest -= subtract_me; return ret; } @@ -215,66 +212,83 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) { * new_word. */ template FullScoreReturn GenericModel::ScoreExceptBackoff( - const WordIndex *context_rbegin, - const WordIndex *context_rend, + const WordIndex *const context_rbegin, + const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const { FullScoreReturn ret; // ret.ngram_length contains the last known non-blank ngram length. ret.ngram_length = 1; - float *backoff_out(out_state.backoff); typename Search::Node node; - search_.LookupUnigram(new_word, *backoff_out, node, ret); + typename Search::UnigramPointer uni(search_.LookupUnigram(new_word, node, ret.independent_left, ret.extend_left)); + out_state.backoff[0] = uni.Backoff(); + ret.prob = uni.Prob(); + ret.rest = uni.Rest(); + // This is the length of the context that should be used for continuation to the right. - out_state.length = HasExtension(*backoff_out) ? 1 : 0; + out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; // We'll write the word anyway since it will probably be used and does no harm being there. out_state.words[0] = new_word; if (context_rbegin == context_rend) return ret; - ++backoff_out; - - // Ok start by looking up the bigram. - const WordIndex *hist_iter = context_rbegin; - typename Search::MiddleIter mid_iter(search_.MiddleBegin()); - for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { - if (hist_iter == context_rend) { - // Ran out of history. Typically no backoff, but this could be a blank. - CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. - return ret; - } - if (mid_iter == search_.MiddleEnd()) break; + ResumeScore(context_rbegin, context_rend, 0, node, out_state.backoff + 1, out_state.length, ret); + CopyRemainingHistory(context_rbegin, out_state); + return ret; +} - if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *hist_iter, *backoff_out, node, ret)) { - // Didn't find an ngram using hist_iter. - CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. - ret.independent_left = true; - return ret; - } - ret.ngram_length = hist_iter - context_rbegin + 2; +template void GenericModel::ResumeScore(const WordIndex *hist_iter, const WordIndex *const context_rend, unsigned char order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const { + for (; ; ++order_minus_2, ++hist_iter, ++backoff_out) { + if (hist_iter == context_rend) return; + if (ret.independent_left) return; + if (order_minus_2 == P::Order() - 2) break; + + typename Search::MiddlePointer pointer(search_.LookupMiddle(order_minus_2, *hist_iter, node, ret.independent_left, ret.extend_left)); + if (!pointer.Found()) return; + *backoff_out = pointer.Backoff(); + ret.prob = pointer.Prob(); + ret.rest = pointer.Rest(); + ret.ngram_length = order_minus_2 + 2; if (HasExtension(*backoff_out)) { - out_state.length = ret.ngram_length; + next_use = ret.ngram_length; } } - - // It passed every lookup in search_.middle. All that's left is to check search_.longest. - if (!ret.independent_left && search_.LookupLongest(*hist_iter, ret.prob, node)) { - // It's an P::Order()-gram. + ret.independent_left = true; + typename Search::LongestPointer longest(search_.LookupLongest(*hist_iter, node)); + if (longest.Found()) { + ret.prob = longest.Prob(); + ret.rest = ret.prob; // There is no blank in longest_. ret.ngram_length = P::Order(); } - // This handles (N-1)-grams and N-grams. - CopyRemainingHistory(context_rbegin, out_state); - ret.independent_left = true; +} + +template float GenericModel::InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const { + float ret; + typename Search::Node node; + if (first_length == 1) { + if (pointers_begin >= pointers_end) return 0.0; + bool independent_left; + uint64_t extend_left; + typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast(*pointers_begin), node, independent_left, extend_left)); + ret = ptr.Prob() - ptr.Rest(); + ++first_length; + ++pointers_begin; + } else { + ret = 0.0; + } + for (const uint64_t *i = pointers_begin; i < pointers_end; ++i, ++first_length) { + typename Search::MiddlePointer ptr(search_.Unpack(*i, first_length, node)); + ret += ptr.Prob() - ptr.Rest(); + } return ret; } -template class GenericModel; // HASH_PROBING -template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, ProbingVocabulary>; +template class GenericModel, ProbingVocabulary>; +template class GenericModel, SortedVocabulary>; template class GenericModel, SortedVocabulary>; -template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; template class GenericModel, SortedVocabulary>; } // namespace detail diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 6ea62a78..be872178 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -9,6 +9,8 @@ #include "lm/quantize.hh" #include "lm/search_hashed.hh" #include "lm/search_trie.hh" +#include "lm/state.hh" +#include "lm/value.hh" #include "lm/vocab.hh" #include "lm/weights.hh" @@ -23,48 +25,6 @@ namespace util { class FilePiece; } namespace lm { namespace ngram { - -// This is a POD but if you want memcmp to return the same as operator==, call -// ZeroRemaining first. -class State { - public: - bool operator==(const State &other) const { - if (length != other.length) return false; - return !memcmp(words, other.words, length * sizeof(WordIndex)); - } - - // Three way comparison function. - int Compare(const State &other) const { - if (length != other.length) return length < other.length ? -1 : 1; - return memcmp(words, other.words, length * sizeof(WordIndex)); - } - - bool operator<(const State &other) const { - if (length != other.length) return length < other.length; - return memcmp(words, other.words, length * sizeof(WordIndex)) < 0; - } - - // Call this before using raw memcmp. - void ZeroRemaining() { - for (unsigned char i = length; i < kMaxOrder - 1; ++i) { - words[i] = 0; - backoff[i] = 0.0; - } - } - - unsigned char Length() const { return length; } - - // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. - // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. - WordIndex words[kMaxOrder - 1]; - float backoff[kMaxOrder - 1]; - unsigned char length; -}; - -inline size_t hash_value(const State &state) { - return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); -} - namespace detail { // Should return the same results as SRI. @@ -119,8 +79,7 @@ template class GenericModel : public base::Mod /* More efficient version of FullScore where a partial n-gram has already * been scored. - * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if - * the n-gram does not end up extending further left, then 0 is returned. + * NOTE: THE RETURNED .rest AND .prob ARE RELATIVE TO THE .rest RETURNED BEFORE. */ FullScoreReturn ExtendLeft( // Additional context in reverse order. This will update add_rend to @@ -136,12 +95,24 @@ template class GenericModel : public base::Mod // Amount of additional content that should be considered by the next call. unsigned char &next_use) const; + /* Return probabilities minus rest costs for an array of pointers. The + * first length should be the length of the n-gram to which pointers_begin + * points. + */ + float UnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const { + // Compiler should optimize this if away. + return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0; + } + private: friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel &to); static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); - FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; + FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const; + + // Score bigrams and above. Do not include backoff. + void ResumeScore(const WordIndex *context_rbegin, const WordIndex *const context_rend, unsigned char starting_order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const; // Appears after Size in the cc file. void SetupMemory(void *start, const std::vector &counts, const Config &config); @@ -150,32 +121,38 @@ template class GenericModel : public base::Mod void InitializeFromARPA(const char *file, const Config &config); + float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const; + Backing &MutableBacking() { return backing_; } Backing backing_; VocabularyT vocab_; - typedef typename Search::Middle Middle; - Search search_; }; } // namespace detail -// These must also be instantiated in the cc file. -typedef ::lm::ngram::ProbingVocabulary Vocabulary; -typedef detail::GenericModel ProbingModel; // HASH_PROBING -// Default implementation. No real reason for it to be the default. -typedef ProbingModel Model; +// Instead of typedef, inherit. This allows the Model etc to be forward declared. +// Oh the joys of C and C++. +#define LM_COMMA() , +#define LM_NAME_MODEL(name, from)\ +class name : public from {\ + public:\ + name(const char *file, const Config &config = Config()) : from(file, config) {}\ +}; -// Smaller implementation. -typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED -typedef detail::GenericModel, SortedVocabulary> ArrayTrieModel; +LM_NAME_MODEL(ProbingModel, detail::GenericModel LM_COMMA() ProbingVocabulary>); +LM_NAME_MODEL(RestProbingModel, detail::GenericModel LM_COMMA() ProbingVocabulary>); +LM_NAME_MODEL(TrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); +LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); +LM_NAME_MODEL(QuantTrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); +LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); -typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED -typedef detail::GenericModel, SortedVocabulary> QuantArrayTrieModel; +// Default implementation. No real reason for it to be the default. +typedef ::lm::ngram::ProbingVocabulary Vocabulary; +typedef ProbingModel Model; } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 461704d4..8a122c60 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -30,7 +30,15 @@ const char *TestNoUnkLocation() { return "test_nounk.arpa"; } return boost::unit_test::framework::master_test_suite().argv[2]; +} +template State GetState(const Model &model, const char *word, const State &in) { + WordIndex context[in.length + 1]; + context[0] = model.GetVocabulary().Index(word); + std::copy(in.words, in.words + in.length, context + 1); + State ret; + model.GetState(context, context + in.length + 1, ret); + return ret; } #define StartTest(word, ngram, score, indep_left) \ @@ -42,14 +50,7 @@ const char *TestNoUnkLocation() { BOOST_CHECK_EQUAL(static_cast(ngram), ret.ngram_length); \ BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.length); \ BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \ - {\ - WordIndex context[state.length + 1]; \ - context[0] = model.GetVocabulary().Index(word); \ - std::copy(state.words, state.words + state.length, context + 1); \ - State get_state; \ - model.GetState(context, context + state.length + 1, get_state); \ - BOOST_CHECK_EQUAL(out, get_state); \ - } + BOOST_CHECK_EQUAL(out, GetState(model, word, state)); #define AppendTest(word, ngram, score, indep_left) \ StartTest(word, ngram, score, indep_left) \ @@ -182,7 +183,7 @@ template void ExtendLeftTest(const M &model) { FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use)); BOOST_CHECK_EQUAL(0, next_use); BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left); - BOOST_CHECK_CLOSE(0.0, extend_none.prob, 0.001); + BOOST_CHECK_CLOSE(little.prob - little.rest, extend_none.prob, 0.001); BOOST_CHECK_EQUAL(1, extend_none.ngram_length); const WordIndex a = model.GetVocabulary().Index("a"); @@ -191,7 +192,7 @@ template void ExtendLeftTest(const M &model) { FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use)); BOOST_CHECK_EQUAL(1, next_use); BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); - BOOST_CHECK_CLOSE(-0.09132547 - kLittleProb, extend_a.prob, 0.001); + BOOST_CHECK_CLOSE(-0.09132547 - little.rest, extend_a.prob, 0.001); BOOST_CHECK_EQUAL(2, extend_a.ngram_length); BOOST_CHECK(!extend_a.independent_left); @@ -199,7 +200,7 @@ template void ExtendLeftTest(const M &model) { FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use)); BOOST_CHECK_EQUAL(1, next_use); BOOST_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001); - BOOST_CHECK_CLOSE(-0.0283603 - -0.09132547, extend_on.prob, 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - (extend_a.rest + little.rest), extend_on.prob, 0.001); BOOST_CHECK_EQUAL(3, extend_on.ngram_length); BOOST_CHECK(!extend_on.independent_left); @@ -209,7 +210,7 @@ template void ExtendLeftTest(const M &model) { BOOST_CHECK_EQUAL(2, next_use); BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); BOOST_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001); - BOOST_CHECK_CLOSE(-0.0283603 - kLittleProb, extend_both.prob, 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - little.rest, extend_both.prob, 0.001); BOOST_CHECK_EQUAL(3, extend_both.ngram_length); BOOST_CHECK(!extend_both.independent_left); BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left); @@ -399,7 +400,10 @@ template void BinaryTest() { } BOOST_AUTO_TEST_CASE(write_and_read_probing) { - BinaryTest(); + BinaryTest(); +} +BOOST_AUTO_TEST_CASE(write_and_read_rest_probing) { + BinaryTest(); } BOOST_AUTO_TEST_CASE(write_and_read_trie) { BinaryTest(); @@ -414,6 +418,18 @@ BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { BinaryTest(); } +BOOST_AUTO_TEST_CASE(rest_max) { + Config config; + config.arpa_complain = Config::NONE; + config.messages = NULL; + + RestProbingModel model(TestLocation(), config); + State state, out; + FullScoreReturn ret(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("."), state)); + BOOST_CHECK_CLOSE(-0.2705918, ret.rest, 0.001); + BOOST_CHECK_CLOSE(-0.01916512, model.FullScore(state, model.GetVocabulary().EndSentence(), out).rest, 0.001); +} + } // namespace } // namespace ngram } // namespace lm diff --git a/klm/lm/model_type.hh b/klm/lm/model_type.hh index 5057ed25..8b35c793 100644 --- a/klm/lm/model_type.hh +++ b/klm/lm/model_type.hh @@ -6,10 +6,17 @@ namespace ngram { /* Not the best numbering system, but it grew this way for historical reasons * and I want to preserve existing binary files. */ -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; +typedef enum {PROBING=0, REST_PROBING=1, TRIE=2, QUANT_TRIE=3, ARRAY_TRIE=4, QUANT_ARRAY_TRIE=5} ModelType; -const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); -const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); +// Historical names. +const ModelType HASH_PROBING = PROBING; +const ModelType TRIE_SORTED = TRIE; +const ModelType QUANT_TRIE_SORTED = QUANT_TRIE; +const ModelType ARRAY_TRIE_SORTED = ARRAY_TRIE; +const ModelType QUANT_ARRAY_TRIE_SORTED = QUANT_ARRAY_TRIE; + +const static ModelType kQuantAdd = static_cast(QUANT_TRIE - TRIE); +const static ModelType kArrayAdd = static_cast(ARRAY_TRIE - TRIE); } // namespace ngram } // namespace lm diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 8f7a0e1c..49757d9a 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -12,22 +12,24 @@ int main(int argc, char *argv[]) { ModelType model_type; if (RecognizeBinary(argv[1], model_type)) { switch(model_type) { - case HASH_PROBING: + case PROBING: Query(argv[1], sentence_context, std::cin, std::cout); break; - case TRIE_SORTED: + case REST_PROBING: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case TRIE: Query(argv[1], sentence_context, std::cin, std::cout); break; - case QUANT_TRIE_SORTED: + case QUANT_TRIE: Query(argv[1], sentence_context, std::cin, std::cout); break; - case ARRAY_TRIE_SORTED: + case ARRAY_TRIE: Query(argv[1], sentence_context, std::cin, std::cout); break; - case QUANT_ARRAY_TRIE_SORTED: + case QUANT_ARRAY_TRIE: Query(argv[1], sentence_context, std::cin, std::cout); break; - case HASH_SORTED: default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; abort(); @@ -35,8 +37,8 @@ int main(int argc, char *argv[]) { } else { Query(argv[1], sentence_context, std::cin, std::cout); } - - PrintUsage("Total time including destruction:\n"); + std::cerr << "Total time including destruction:\n"; + util::PrintUsage(std::cerr); } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh index 4990df22..dfcda170 100644 --- a/klm/lm/ngram_query.hh +++ b/klm/lm/ngram_query.hh @@ -3,51 +3,20 @@ #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "util/usage.hh" #include -#include #include +#include +#include #include -#include -#if !defined(_WIN32) && !defined(_WIN64) -#include -#include -#endif - namespace lm { namespace ngram { -#if !defined(_WIN32) && !defined(_WIN64) -float FloatSec(const struct timeval &tv) { - return static_cast(tv.tv_sec) + (static_cast(tv.tv_usec) / 1000000000.0); -} -#endif - -void PrintUsage(const char *message) { -#if !defined(_WIN32) && !defined(_WIN64) - struct rusage usage; - if (getrusage(RUSAGE_SELF, &usage)) { - perror("getrusage"); - return; - } - std::cerr << message; - std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n'; - - // Linux doesn't set memory usage :-(. - std::ifstream status("/proc/self/status", std::ios::in); - std::string line; - while (getline(status, line)) { - if (!strncmp(line.c_str(), "VmRSS:\t", 7)) { - std::cerr << "rss " << (line.c_str() + 7) << '\n'; - break; - } - } -#endif -} - template void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { - PrintUsage("Loading statistics:\n"); + std::cerr << "Loading statistics:\n"; + util::PrintUsage(std::cerr); typename Model::State state, out; lm::FullScoreReturn ret; std::string word; @@ -84,13 +53,13 @@ template void Query(const Model &model, bool sentence_context, std out_stream << "=" << model.GetVocabulary().EndSentence() << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; } out_stream << "Total: " << total << " OOV: " << oov << '\n'; - } - PrintUsage("After queries:\n"); + } + std::cerr << "After queries:\n"; + util::PrintUsage(std::cerr); } template void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { Config config; -// config.load_method = util::LAZY; M model(file, config); Query(model, sentence_context, in_stream, out_stream); } diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index a8e0cb21..b58c3f3f 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -47,9 +47,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector(static_cast(start) + 8); +void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) { prob_bits_ = config.prob_bits; backoff_bits_ = config.backoff_bits; // We need the reserved values. @@ -57,25 +55,35 @@ void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { if (config.backoff_bits == 0) UTIL_THROW(ConfigException, "You can't quantize backoff to zero"); if (config.prob_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing probability supports at most 25 bits. Currently you have requested " << static_cast(config.prob_bits) << " bits."); if (config.backoff_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing backoff supports at most 25 bits. Currently you have requested " << static_cast(config.backoff_bits) << " bits."); + // Reserve 8 byte header for bit counts. + actual_base_ = static_cast(base); + float *start = reinterpret_cast(actual_base_ + 8); + for (unsigned char i = 0; i < order - 2; ++i) { + tables_[i][0] = Bins(prob_bits_, start); + start += (1ULL << prob_bits_); + tables_[i][1] = Bins(backoff_bits_, start); + start += (1ULL << backoff_bits_); + } + longest_ = tables_[order - 2][0] = Bins(prob_bits_, start); } void SeparatelyQuantize::Train(uint8_t order, std::vector &prob, std::vector &backoff) { TrainProb(order, prob); // Backoff - float *centers = start_ + TableStart(order) + ProbTableLength(); + float *centers = tables_[order - 2][1].Populate(); *(centers++) = kNoExtensionBackoff; *(centers++) = kExtensionBackoff; MakeBins(backoff, centers, (1ULL << backoff_bits_) - 2); } void SeparatelyQuantize::TrainProb(uint8_t order, std::vector &prob) { - float *centers = start_ + TableStart(order); + float *centers = tables_[order - 2][0].Populate(); MakeBins(prob, centers, (1ULL << prob_bits_)); } void SeparatelyQuantize::FinishedLoading(const Config &config) { - uint8_t *actual_base = reinterpret_cast(start_) - 8; + uint8_t *actual_base = actual_base_; *(actual_base++) = kSeparatelyQuantizeVersion; // version *(actual_base++) = config.prob_bits; *(actual_base++) = config.backoff_bits; diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 6d130a57..3e9153e3 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -3,6 +3,7 @@ #include "lm/blank.hh" #include "lm/config.hh" +#include "lm/max_order.hh" #include "lm/model_type.hh" #include "util/bit_packing.hh" @@ -27,37 +28,60 @@ class DontQuantize { static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } - struct Middle { - void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { - util::WriteNonPositiveFloat31(base, bit_offset, prob); - util::WriteFloat32(base, bit_offset + 31, backoff); - } - void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { - prob = util::ReadNonPositiveFloat31(base, bit_offset); - backoff = util::ReadFloat32(base, bit_offset + 31); - } - void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { - prob = util::ReadNonPositiveFloat31(base, bit_offset); - } - void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { - backoff = util::ReadFloat32(base, bit_offset + 31); - } - uint8_t TotalBits() const { return 63; } + class MiddlePointer { + public: + MiddlePointer(const DontQuantize & /*quant*/, unsigned char /*order_minus_2*/, util::BitAddress address) : address_(address) {} + + MiddlePointer() : address_(NULL, 0) {} + + bool Found() const { + return address_.base != NULL; + } + + float Prob() const { + return util::ReadNonPositiveFloat31(address_.base, address_.offset); + } + + float Backoff() const { + return util::ReadFloat32(address_.base, address_.offset + 31); + } + + float Rest() const { return Prob(); } + + void Write(float prob, float backoff) { + util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); + util::WriteFloat32(address_.base, address_.offset + 31, backoff); + } + + private: + util::BitAddress address_; }; - struct Longest { - void Write(void *base, uint64_t bit_offset, float prob) const { - util::WriteNonPositiveFloat31(base, bit_offset, prob); - } - void Read(const void *base, uint64_t bit_offset, float &prob) const { - prob = util::ReadNonPositiveFloat31(base, bit_offset); - } - uint8_t TotalBits() const { return 31; } + class LongestPointer { + public: + explicit LongestPointer(const DontQuantize &/*quant*/, util::BitAddress address) : address_(address) {} + + LongestPointer() : address_(NULL, 0) {} + + bool Found() const { + return address_.base != NULL; + } + + float Prob() const { + return util::ReadNonPositiveFloat31(address_.base, address_.offset); + } + + void Write(float prob) { + util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); + } + + private: + util::BitAddress address_; }; DontQuantize() {} - void SetupMemory(void * /*start*/, const Config & /*config*/) {} + void SetupMemory(void * /*start*/, unsigned char /*order*/, const Config & /*config*/) {} static const bool kTrain = false; // These should never be called because kTrain is false. @@ -65,9 +89,6 @@ class DontQuantize { void TrainProb(uint8_t, std::vector &/*prob*/) {} void FinishedLoading(const Config &) {} - - Middle Mid(uint8_t /*order*/) const { return Middle(); } - Longest Long(uint8_t /*order*/) const { return Longest(); } }; class SeparatelyQuantize { @@ -77,7 +98,9 @@ class SeparatelyQuantize { // Sigh C++ default constructor Bins() {} - Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + + float *Populate() { return begin_; } uint64_t EncodeProb(float value) const { return Encode(value, 0); @@ -98,13 +121,13 @@ class SeparatelyQuantize { private: uint64_t Encode(float value, size_t reserved) const { - const float *above = std::lower_bound(begin_ + reserved, end_, value); + const float *above = std::lower_bound(static_cast(begin_) + reserved, end_, value); if (above == begin_ + reserved) return reserved; if (above == end_) return end_ - begin_ - 1; return above - begin_ - (value - *(above - 1) < *above - value); } - const float *begin_; + float *begin_; const float *end_; uint8_t bits_; uint64_t mask_; @@ -125,65 +148,61 @@ class SeparatelyQuantize { static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } static uint8_t LongestBits(const Config &config) { return config.prob_bits; } - class Middle { + class MiddlePointer { public: - Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) : - total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {} + MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {} - void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { - util::WriteInt57(base, bit_offset, total_bits_, - (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); - } + MiddlePointer() : address_(NULL, 0) {} - void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { - prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask())); - } + bool Found() const { return address_.base != NULL; } - void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { - uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); - prob = prob_.Decode(both >> backoff_.Bits()); - backoff = backoff_.Decode(both & backoff_.Mask()); + float Prob() const { + return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask())); } - void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { - backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask())); + float Backoff() const { + return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask())); } - uint8_t TotalBits() const { - return total_bits_; + float Rest() const { return Prob(); } + + void Write(float prob, float backoff) const { + util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(), + (ProbBins().EncodeProb(prob) << BackoffBins().Bits()) | BackoffBins().EncodeBackoff(backoff)); } private: - const uint8_t total_bits_; - const uint64_t total_mask_; - const Bins prob_; - const Bins backoff_; + const Bins &ProbBins() const { return bins_[0]; } + const Bins &BackoffBins() const { return bins_[1]; } + const Bins *bins_; + + util::BitAddress address_; }; - class Longest { + class LongestPointer { public: - // Sigh C++ default constructor - Longest() {} + LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {} + + LongestPointer() : address_(NULL, 0) {} - Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {} + bool Found() const { return address_.base != NULL; } - void Write(void *base, uint64_t bit_offset, float prob) const { - util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob)); + void Write(float prob) const { + util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob)); } - void Read(const void *base, uint64_t bit_offset, float &prob) const { - prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask())); + float Prob() const { + return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask())); } - uint8_t TotalBits() const { return prob_.Bits(); } - private: - Bins prob_; + const Bins *table_; + util::BitAddress address_; }; SeparatelyQuantize() {} - void SetupMemory(void *start, const Config &config); + void SetupMemory(void *start, unsigned char order, const Config &config); static const bool kTrain = true; // Assumes 0.0 is removed from backoff. @@ -193,18 +212,17 @@ class SeparatelyQuantize { void FinishedLoading(const Config &config); - Middle Mid(uint8_t order) const { - const float *table = start_ + TableStart(order); - return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength()); - } + const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; } - Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); } + const Bins &LongestTable() const { return longest_; } private: - size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast(order - 2); } - size_t ProbTableLength() const { return (1ULL << prob_bits_); } + Bins tables_[kMaxOrder - 1][2]; + + Bins longest_; + + uint8_t *actual_base_; - float *start_; uint8_t prob_bits_, backoff_bits_; }; diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 05f761be..2d9a337d 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -83,7 +83,7 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) { } } -void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) { +void ReadBackoff(util::FilePiece &in, float &backoff) { // Always make zero negative. // Negative zero means that no (n+1)-gram has this n-gram as context. // Therefore the hypothesis state can be shorter. Of course, many n-grams @@ -91,12 +91,12 @@ void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) { // back and set the backoff to positive zero in these cases. switch (in.get()) { case '\t': - weights.backoff = in.ReadFloat(); - if (weights.backoff == ngram::kExtensionBackoff) weights.backoff = ngram::kNoExtensionBackoff; + backoff = in.ReadFloat(); + if (backoff == ngram::kExtensionBackoff) backoff = ngram::kNoExtensionBackoff; if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); break; case '\n': - weights.backoff = ngram::kNoExtensionBackoff; + backoff = ngram::kNoExtensionBackoff; break; default: UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff"); diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index ab996bde..234d130c 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -16,7 +16,13 @@ void ReadARPACounts(util::FilePiece &in, std::vector &number); void ReadNGramHeader(util::FilePiece &in, unsigned int length); void ReadBackoff(util::FilePiece &in, Prob &weights); -void ReadBackoff(util::FilePiece &in, ProbBackoff &weights); +void ReadBackoff(util::FilePiece &in, float &backoff); +inline void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) { + ReadBackoff(in, weights.backoff); +} +inline void ReadBackoff(util::FilePiece &in, RestWeights &weights) { + ReadBackoff(in, weights.backoff); +} void ReadEnd(util::FilePiece &in); @@ -35,7 +41,7 @@ class PositiveProbWarn { WarningAction action_; }; -template void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) { +template void Read1Gram(util::FilePiece &f, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) { try { float prob = f.ReadFloat(); if (prob > 0.0) { @@ -43,7 +49,7 @@ template void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff prob = 0.0; } if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); - ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))]; + Weights &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))]; value.prob = prob; ReadBackoff(f, value); } catch(util::Exception &e) { @@ -53,7 +59,7 @@ template void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff } // Return true if a positive log probability came out. -template void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) { +template void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) { ReadNGramHeader(f, 1); for (std::size_t i = 0; i < count; ++i) { Read1Gram(f, vocab, unigrams, warn); diff --git a/klm/lm/return.hh b/klm/lm/return.hh index 1b55091b..622320ce 100644 --- a/klm/lm/return.hh +++ b/klm/lm/return.hh @@ -33,6 +33,9 @@ struct FullScoreReturn { */ bool independent_left; uint64_t extend_left; // Defined only if independent_left + + // Rest cost for extension to the left. + float rest; }; } // namespace lm diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 1d6fb5be..13942309 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -3,7 +3,9 @@ #include "lm/binary_format.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" +#include "lm/model.hh" #include "lm/read_arpa.hh" +#include "lm/value.hh" #include "lm/vocab.hh" #include "util/bit_packing.hh" @@ -14,6 +16,8 @@ namespace lm { namespace ngram { +class ProbingModel; + namespace { /* These are passed to ReadNGrams so that n-grams with zero backoff that appear as context will still be used in state. */ @@ -37,9 +41,9 @@ template class ActivateLowerMiddle { Middle &modify_; }; -class ActivateUnigram { +template class ActivateUnigram { public: - explicit ActivateUnigram(ProbBackoff *unigram) : modify_(unigram) {} + explicit ActivateUnigram(Weights *unigram) : modify_(unigram) {} void operator()(const WordIndex *vocab_ids, const unsigned int /*n*/) { // assert(n == 2); @@ -47,43 +51,124 @@ class ActivateUnigram { } private: - ProbBackoff *modify_; + Weights *modify_; }; -template void FixSRI(int lower, float negative_lower_prob, unsigned int n, const uint64_t *keys, const WordIndex *vocab_ids, ProbBackoff *unigrams, std::vector &middle) { - ProbBackoff blank; - blank.backoff = kNoExtensionBackoff; - // Fix SRI's stupidity. - // Note that negative_lower_prob is the negative of the probability (so it's currently >= 0). We still want the sign bit off to indicate left extension, so I just do -= on the backoffs. - blank.prob = negative_lower_prob; - // An entry was found at lower (order lower + 2). - // We need to insert blanks starting at lower + 1 (order lower + 3). - unsigned int fix = static_cast(lower + 1); - uint64_t backoff_hash = detail::CombineWordHash(static_cast(vocab_ids[1]), vocab_ids[2]); - if (fix == 0) { - // Insert a missing bigram. - blank.prob -= unigrams[vocab_ids[1]].backoff; - SetExtension(unigrams[vocab_ids[1]].backoff); - // Bigram including a unigram's backoff - middle[0].Insert(detail::ProbBackoffEntry::Make(keys[0], blank)); - fix = 1; - } else { - for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); +// Find the lower order entry, inserting blanks along the way as necessary. +template void FindLower( + const std::vector &keys, + typename Value::Weights &unigram, + std::vector > &middle, + std::vector &between) { + typename util::ProbingHashTable::MutableIterator iter; + typename Value::ProbingEntry entry; + // Backoff will always be 0.0. We'll get the probability and rest in another pass. + entry.value.backoff = kNoExtensionBackoff; + // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. + for (int lower = keys.size() - 2; ; --lower) { + if (lower == -1) { + between.push_back(&unigram); + return; + } + entry.key = keys[lower]; + bool found = middle[lower].FindOrInsert(entry, iter); + between.push_back(&iter->value); + if (found) return; + } +} + +// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. +template void AdjustLower( + const Added &added, + const Build &build, + std::vector &between, + const unsigned int n, + const std::vector &vocab_ids, + typename Build::Value::Weights *unigrams, + std::vector > &middle) { + typedef typename Build::Value Value; + if (between.size() == 1) { + build.MarkExtends(*between.front(), added); + return; + } + typedef util::ProbingHashTable Middle; + float prob = -fabs(between.back()->prob); + // Order of the n-gram on which probabilities are based. + unsigned char basis = n - between.size(); + assert(basis != 0); + typename Build::Value::Weights **change = &between.back(); + // Skip the basis. + --change; + if (basis == 1) { + // Hallucinate a bigram based on a unigram's backoff and a unigram probability. + float &backoff = unigrams[vocab_ids[1]].backoff; + SetExtension(backoff); + prob += backoff; + (*change)->prob = prob; + build.SetRest(&*vocab_ids.begin(), 2, **change); + basis = 2; + --change; } - // fix >= 1. Insert trigrams and above. - for (; fix <= n - 3; ++fix) { + uint64_t backoff_hash = static_cast(vocab_ids[1]); + for (unsigned char i = 2; i <= basis; ++i) { + backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); + } + for (; basis < n - 1; ++basis, --change) { typename Middle::MutableIterator gotit; - if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) { + if (middle[basis - 2].UnsafeMutableFind(backoff_hash, gotit)) { float &backoff = gotit->value.backoff; SetExtension(backoff); - blank.prob -= backoff; + prob += backoff; } - middle[fix].Insert(detail::ProbBackoffEntry::Make(keys[fix], blank)); - backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]); + (*change)->prob = prob; + build.SetRest(&*vocab_ids.begin(), basis + 1, **change); + backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[basis+1]); + } + + typename std::vector::const_iterator i(between.begin()); + build.MarkExtends(**i, added); + const typename Value::Weights *longer = *i; + // Everything has probability but is not marked as extending. + for (++i; i != between.end(); ++i) { + build.MarkExtends(**i, *longer); + longer = *i; } } -template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector &middle, Activate activate, Store &store, PositiveProbWarn &warn) { +// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds. +template void MarkLower( + const std::vector &keys, + const Build &build, + typename Build::Value::Weights &unigram, + std::vector > &middle, + int start_order, + const typename Build::Value::Weights &longer) { + if (start_order == 0) return; + typename util::ProbingHashTable::MutableIterator iter; + // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code. + for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) { + if (even_lower == -1) { + build.MarkExtends(unigram, longer); + return; + } + middle[even_lower].UnsafeMutableFind(keys[even_lower], iter); + if (!build.MarkExtends(iter->value, longer)) return; + } +} + +template void ReadNGrams( + util::FilePiece &f, + const unsigned int n, + const size_t count, + const ProbingVocabulary &vocab, + const Build &build, + typename Build::Value::Weights *unigrams, + std::vector > &middle, + Activate activate, + Store &store, + PositiveProbWarn &warn) { + typedef typename Build::Value Value; + typedef util::ProbingHashTable Middle; assert(n >= 2); ReadNGramHeader(f, n); @@ -91,38 +176,25 @@ template void ReadNGrams( // vocab ids of words in reverse order. std::vector vocab_ids(n); std::vector keys(n-1); - typename Store::Entry::Value value; - typename Middle::MutableIterator found; + typename Store::Entry entry; + std::vector between; for (size_t i = 0; i < count; ++i) { - ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn); + ReadNGram(f, n, vocab, &*vocab_ids.begin(), entry.value, warn); + build.SetRest(&*vocab_ids.begin(), n, entry.value); keys[0] = detail::CombineWordHash(static_cast(vocab_ids.front()), vocab_ids[1]); for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. - util::SetSign(value.prob); - store.Insert(Store::Entry::Make(keys[n-2], value)); - // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. - int lower; - util::FloatEnc fix_prob; - for (lower = n - 3; ; --lower) { - if (lower == -1) { - fix_prob.f = unigrams[vocab_ids.front()].prob; - fix_prob.i &= ~util::kSignBit; - unigrams[vocab_ids.front()].prob = fix_prob.f; - break; - } - if (middle[lower].UnsafeMutableFind(keys[lower], found)) { - // Turn off sign bit to indicate that it extends left. - fix_prob.f = found->value.prob; - fix_prob.i &= ~util::kSignBit; - found->value.prob = fix_prob.f; - // We don't need to recurse further down because this entry already set the bits for lower entries. - break; - } - } - if (lower != static_cast(n) - 3) FixSRI(lower, fix_prob.f, n, &*keys.begin(), &*vocab_ids.begin(), unigrams, middle); + util::SetSign(entry.value.prob); + entry.key = keys[n-2]; + + store.Insert(entry); + between.clear(); + FindLower(keys, unigrams[vocab_ids.front()], middle, between); + AdjustLower(entry.value, build, between, n, vocab_ids, unigrams, middle); + if (Build::kMarkEvenLower) MarkLower(keys, build, unigrams[vocab_ids.front()], middle, n - between.size() - 1, *between.back()); activate(&*vocab_ids.begin(), n); } @@ -132,9 +204,9 @@ template void ReadNGrams( } // namespace namespace detail { -template uint8_t *TemplateHashedSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { +template uint8_t *HashedSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { std::size_t allocated = Unigram::Size(counts[0]); - unigram = Unigram(start, allocated); + unigram_ = Unigram(start, counts[0], allocated); start += allocated; for (unsigned int n = 2; n < counts.size(); ++n) { allocated = Middle::Size(counts[n - 1], config.probing_multiplier); @@ -142,31 +214,63 @@ template uint8_t *TemplateHashedSearch template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { +template void HashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) { // TODO: fix sorted. SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config); PositiveProbWarn warn(config.positive_log_probability); - - Read1Grams(f, counts[0], vocab, unigram.Raw(), warn); + Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn); CheckSpecials(config, vocab); + DispatchBuild(f, counts, config, vocab, warn); +} + +template <> void HashedSearch::DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { + NoRestBuild build; + ApplyBuild(f, counts, config, vocab, warn, build); +} + +template <> void HashedSearch::DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { + switch (config.rest_function) { + case Config::REST_MAX: + { + MaxRestBuild build; + ApplyBuild(f, counts, config, vocab, warn, build); + } + break; + case Config::REST_LOWER: + { + LowerRestBuild build(config, counts.size(), vocab); + ApplyBuild(f, counts, config, vocab, warn, build); + } + break; + } +} + +template template void HashedSearch::ApplyBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { + for (WordIndex i = 0; i < counts[0]; ++i) { + build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]); + } try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); + ReadNGrams, Middle>( + f, 2, counts[1], vocab, build, unigram_.Raw(), middle_, ActivateUnigram(unigram_.Raw()), middle_[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); + ReadNGrams, Middle>( + f, n, counts[n-1], vocab, build, unigram_.Raw(), middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle(middle_.back()), longest, warn); + ReadNGrams, Longest>( + f, counts.size(), counts[counts.size() - 1], vocab, build, unigram_.Raw(), middle_, ActivateLowerMiddle(middle_.back()), longest_, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), longest, warn); + ReadNGrams, Longest>( + f, counts.size(), counts[counts.size() - 1], vocab, build, unigram_.Raw(), middle_, ActivateUnigram(unigram_.Raw()), longest_, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); @@ -174,17 +278,16 @@ template template void TemplateHashe ReadEnd(f); } -template void TemplateHashedSearch::LoadedBinary() { - unigram.LoadedBinary(); +template void HashedSearch::LoadedBinary() { + unigram_.LoadedBinary(); for (typename std::vector::iterator i = middle_.begin(); i != middle_.end(); ++i) { i->LoadedBinary(); } - longest.LoadedBinary(); + longest_.LoadedBinary(); } -template class TemplateHashedSearch; - -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); +template class HashedSearch; +template class HashedSearch; } // namespace detail } // namespace ngram diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 4352c72d..7e8c1220 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -19,6 +19,7 @@ namespace util { class FilePiece; } namespace lm { namespace ngram { struct Backing; +class ProbingVocabulary; namespace detail { inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { @@ -26,54 +27,48 @@ inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { return ret; } -struct HashedSearch { - typedef uint64_t Node; - - class Unigram { - public: - Unigram() {} - - Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast(start)) {} - - static std::size_t Size(uint64_t count) { - return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate - } - - const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; } +#pragma pack(push) +#pragma pack(4) +struct ProbEntry { + uint64_t key; + Prob value; + typedef uint64_t Key; + typedef Prob Value; + uint64_t GetKey() const { + return key; + } +}; - ProbBackoff &Unknown() { return unigram_[0]; } +#pragma pack(pop) - void LoadedBinary() {} +class LongestPointer { + public: + explicit LongestPointer(const float &to) : to_(&to) {} - // For building. - ProbBackoff *Raw() { return unigram_; } + LongestPointer() : to_(NULL) {} - private: - ProbBackoff *unigram_; - }; + bool Found() const { + return to_ != NULL; + } - Unigram unigram; + float Prob() const { + return *to_; + } - void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const { - const ProbBackoff &entry = unigram.Lookup(word); - util::FloatEnc val; - val.f = entry.prob; - ret.independent_left = (val.i & util::kSignBit); - ret.extend_left = static_cast(word); - val.i |= util::kSignBit; - ret.prob = val.f; - backoff = entry.backoff; - next = static_cast(word); - } + private: + const float *to_; }; -template class TemplateHashedSearch : public HashedSearch { +template class HashedSearch { public: - typedef MiddleT Middle; + typedef uint64_t Node; - typedef LongestT Longest; - Longest longest; + typedef typename Value::ProbingProxy UnigramPointer; + typedef typename Value::ProbingProxy MiddlePointer; + typedef ::lm::ngram::detail::LongestPointer LongestPointer; + static const ModelType kModelType = Value::kProbingModelType; + static const bool kDifferentRest = Value::kDifferentRest; static const unsigned int kVersion = 0; // TODO: move probing_multiplier here with next binary file format update. @@ -89,64 +84,55 @@ template class TemplateHashedSearch : public Has uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config); - template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing); + void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing); - typedef typename std::vector::const_iterator MiddleIter; + void LoadedBinary(); - MiddleIter MiddleBegin() const { return middle_.begin(); } - MiddleIter MiddleEnd() const { return middle_.end(); } + unsigned char Order() const { + return middle_.size() + 2; + } - Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { - util::FloatEnc val; - if (extend_length == 1) { - val.f = unigram.Lookup(static_cast(extend_pointer)).prob; - } else { - typename Middle::ConstIterator found; - if (!middle_[extend_length - 2].Find(extend_pointer, found)) { - std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; - abort(); - } - val.f = found->value.prob; - } - val.i |= util::kSignBit; - prob = val.f; - return extend_pointer; + typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); } + + UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const { + extend_left = static_cast(word); + next = extend_left; + UnigramPointer ret(unigram_.Lookup(word)); + independent_left = ret.IndependentLeft(); + return ret; } - bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { - node = CombineWordHash(node, word); +#pragma GCC diagnostic ignored "-Wuninitialized" + MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { + node = extend_pointer; typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - util::FloatEnc enc; - enc.f = found->value.prob; - ret.independent_left = (enc.i & util::kSignBit); - ret.extend_left = node; - enc.i |= util::kSignBit; - ret.prob = enc.f; - backoff = found->value.backoff; - return true; + bool got = middle_[extend_length - 2].Find(extend_pointer, found); + assert(got); + (void)got; + return MiddlePointer(found->value); } - void LoadedBinary(); - - bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { + MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - backoff = found->value.backoff; - return true; + if (!middle_[order_minus_2].Find(node, found)) { + independent_left = true; + return MiddlePointer(); + } + extend_pointer = node; + MiddlePointer ret(found->value); + independent_left = ret.IndependentLeft(); + return ret; } - bool LookupLongest(WordIndex word, float &prob, Node &node) const { + LongestPointer LookupLongest(WordIndex word, const Node &node) const { // Sign bit is always on because longest n-grams do not extend left. - node = CombineWordHash(node, word); typename Longest::ConstIterator found; - if (!longest.Find(node, found)) return false; - prob = found->value.prob; - return true; + if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer(); + return LongestPointer(found->value.prob); } - // Geenrate a node without necessarily checking that it actually exists. + // Generate a node without necessarily checking that it actually exists. // Optionally return false if it's know to not exist. bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { assert(begin != end); @@ -158,55 +144,54 @@ template class TemplateHashedSearch : public Has } private: - std::vector middle_; -}; + // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. + void DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); -/* These look like perfect candidates for a template, right? Ancient gcc (4.1 - * on RedHat stale linux) doesn't pack templates correctly. ProbBackoffEntry - * is a multiple of 8 bytes anyway. ProbEntry is 12 bytes so it's set to pack. - */ -struct ProbBackoffEntry { - uint64_t key; - ProbBackoff value; - typedef uint64_t Key; - typedef ProbBackoff Value; - uint64_t GetKey() const { - return key; - } - static ProbBackoffEntry Make(uint64_t key, ProbBackoff value) { - ProbBackoffEntry ret; - ret.key = key; - ret.value = value; - return ret; - } -}; + template void ApplyBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); -#pragma pack(push) -#pragma pack(4) -struct ProbEntry { - uint64_t key; - Prob value; - typedef uint64_t Key; - typedef Prob Value; - uint64_t GetKey() const { - return key; - } - static ProbEntry Make(uint64_t key, Prob value) { - ProbEntry ret; - ret.key = key; - ret.value = value; - return ret; - } -}; + class Unigram { + public: + Unigram() {} -#pragma pack(pop) + Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : + unigram_(static_cast(start)) +#ifdef DEBUG + , count_(count) +#endif + {} + + static std::size_t Size(uint64_t count) { + return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate + } + + const typename Value::Weights &Lookup(WordIndex index) const { +#ifdef DEBUG + assert(index < count_); +#endif + return unigram_[index]; + } + + typename Value::Weights &Unknown() { return unigram_[0]; } + void LoadedBinary() {} -struct ProbingHashedSearch : public TemplateHashedSearch< - util::ProbingHashTable, - util::ProbingHashTable > { + // For building. + typename Value::Weights *Raw() { return unigram_; } + + private: + typename Value::Weights *unigram_; +#ifdef DEBUG + uint64_t count_; +#endif + }; + + Unigram unigram_; + + typedef util::ProbingHashTable Middle; + std::vector middle_; - static const ModelType kModelType = HASH_PROBING; + typedef util::ProbingHashTable Longest; + Longest longest_; }; } // namespace detail diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index ffadfa94..18e80d5a 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -273,8 +273,9 @@ class FindBlanks { // Phase to actually write n-grams to the trie. template class WriteEntries { public: - WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : + WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), + quant_(quant), unigrams_(unigrams), middle_(middle), longest_(longest), @@ -290,7 +291,7 @@ template class WriteEntries { void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) { ProbBackoff weights = sri_.GetBlank(order_, order, indices); - middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff); + typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(indices[order - 1])).Write(weights.prob, weights.backoff); } void Middle(const unsigned char order, const void *data) { @@ -301,21 +302,22 @@ template class WriteEntries { SetExtension(weights.backoff); ++context; } - middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff); + typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(words[order - 1])).Write(weights.prob, weights.backoff); } void Longest(const void *data) { const WordIndex *words = reinterpret_cast(data); - longest_.Insert(words[order_ - 1], reinterpret_cast(words + order_)->prob); + typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast(words + order_)->prob); } void Cleanup() {} private: RecordReader *contexts_; + const Quant &quant_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; - BitPackedLongest &longest_; + BitPackedMiddle *const middle_; + BitPackedLongest &longest_; BitPacked &bigram_pack_; const unsigned char order_; SRISucks &sri_; @@ -380,7 +382,7 @@ template class BlankManager { }; template void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { - util::ErsatzProgress progress(progress_out, message, unigram_count + 1); + util::ErsatzProgress progress(unigram_count + 1, progress_out, message); WordIndex unigram = 0; std::priority_queue grams; grams.push(Gram(&unigram, 1)); @@ -502,7 +504,7 @@ template void BuildTrie(SortedFiles &files, std::ve inputs[i-2].Rewind(); } if (Quant::kTrain) { - util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); + util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); for (unsigned char i = 2; i < counts.size(); ++i) { TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } @@ -510,7 +512,7 @@ template void BuildTrie(SortedFiles &files, std::ve quant.FinishedLoading(config); } - UnigramValue *unigrams = out.unigram.Raw(); + UnigramValue *unigrams = out.unigram_.Raw(); PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams); unigram_file.reset(); @@ -519,7 +521,7 @@ template void BuildTrie(SortedFiles &files, std::ve } // Fill entries except unigram probabilities. { - WriteEntries writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri); + WriteEntries writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); } @@ -544,14 +546,14 @@ template void BuildTrie(SortedFiles &files, std::ve for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { i->FinishedLoading((i+1)->InsertIndex(), config); } - (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config); + (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); } } template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { - quant_.SetupMemory(start, config); + quant_.SetupMemory(start, counts.size(), config); start += Quant::Size(counts.size(), config); - unigram.Init(start); + unigram_.Init(start); start += Unigram::Size(counts[0]); FreeMiddles(); middle_begin_ = static_cast(malloc(sizeof(Middle) * (counts.size() - 2))); @@ -565,23 +567,23 @@ template uint8_t *TrieSearch::Setup for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( middle_starts[i-2], - quant_.Mid(i), + quant_.MiddleBits(config), counts[i-1], counts[0], counts[i], - (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1]), + (i == counts.size() - 1) ? static_cast(longest_) : static_cast(middle_begin_[i-1]), config); } - longest.Init(start, quant_.Long(counts.size()), counts[0]); + longest_.Init(start, quant_.LongestBits(config), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } template void TrieSearch::LoadedBinary() { - unigram.LoadedBinary(); + unigram_.LoadedBinary(); for (Middle *i = middle_begin_; i != middle_end_; ++i) { i->LoadedBinary(); } - longest.LoadedBinary(); + longest_.LoadedBinary(); } template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 5155ca02..10b22ab1 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -28,13 +28,11 @@ template class TrieSearch { public: typedef NodeRange Node; - typedef ::lm::ngram::trie::Unigram Unigram; - Unigram unigram; - - typedef trie::BitPackedMiddle Middle; + typedef ::lm::ngram::trie::UnigramPointer UnigramPointer; + typedef typename Quant::MiddlePointer MiddlePointer; + typedef typename Quant::LongestPointer LongestPointer; - typedef trie::BitPackedLongest Longest; - Longest longest; + static const bool kDifferentRest = false; static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); @@ -62,55 +60,46 @@ template class TrieSearch { void LoadedBinary(); - typedef const Middle *MiddleIter; + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - const Middle *MiddleBegin() const { return middle_begin_; } - const Middle *MiddleEnd() const { return middle_end_; } + unsigned char Order() const { + return middle_end_ - middle_begin_ + 2; + } - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + ProbBackoff &UnknownUnigram() { return unigram_.Unknown(); } - void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { - unigram.Find(word, ret.prob, backoff, node); - ret.independent_left = (node.begin == node.end); - ret.extend_left = static_cast(word); + UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const { + extend_left = static_cast(word); + UnigramPointer ret(unigram_.Find(word, next)); + independent_left = (next.begin == next.end); + return ret; } - bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { - if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false; - ret.independent_left = (node.begin == node.end); - return true; + MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { + return MiddlePointer(quant_, extend_length - 2, middle_begin_[extend_length - 2].ReadEntry(extend_pointer, node)); } - bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { - return mid.FindNoProb(word, backoff, node); + MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_left) const { + util::BitAddress address(middle_begin_[order_minus_2].Find(word, node, extend_left)); + independent_left = (address.base == NULL) || (node.begin == node.end); + return MiddlePointer(quant_, order_minus_2, address); } - bool LookupLongest(WordIndex word, float &prob, const Node &node) const { - return longest.Find(word, prob, node); + LongestPointer LookupLongest(WordIndex word, const Node &node) const { + return LongestPointer(quant_, longest_.Find(word, node)); } bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - // TODO: don't decode backoff. assert(begin != end); - FullScoreReturn ignored; - float ignored_backoff; - LookupUnigram(*begin, ignored_backoff, node, ignored); + bool independent_left; + uint64_t ignored; + LookupUnigram(*begin, node, independent_left, ignored); for (const WordIndex *i = begin + 1; i < end; ++i) { - if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; + if (independent_left || !LookupMiddle(i - begin - 1, *i, node, independent_left, ignored).Found()) return false; } return true; } - Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { - if (extend_length == 1) { - float ignored; - Node ret; - unigram.Find(static_cast(extend_pointer), prob, ignored, ret); - return ret; - } - return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob); - } - private: friend void BuildTrie(SortedFiles &files, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); @@ -122,8 +111,16 @@ template class TrieSearch { free(middle_begin_); } + typedef trie::BitPackedMiddle Middle; + + typedef trie::BitPackedLongest Longest; + Longest longest_; + Middle *middle_begin_, *middle_end_; Quant quant_; + + typedef ::lm::ngram::trie::Unigram Unigram; + Unigram unigram_; }; } // namespace trie diff --git a/klm/lm/state.hh b/klm/lm/state.hh new file mode 100644 index 00000000..c7438414 --- /dev/null +++ b/klm/lm/state.hh @@ -0,0 +1,123 @@ +#ifndef LM_STATE__ +#define LM_STATE__ + +#include "lm/max_order.hh" +#include "lm/word_index.hh" +#include "util/murmur_hash.hh" + +#include + +namespace lm { +namespace ngram { + +// This is a POD but if you want memcmp to return the same as operator==, call +// ZeroRemaining first. +class State { + public: + bool operator==(const State &other) const { + if (length != other.length) return false; + return !memcmp(words, other.words, length * sizeof(WordIndex)); + } + + // Three way comparison function. + int Compare(const State &other) const { + if (length != other.length) return length < other.length ? -1 : 1; + return memcmp(words, other.words, length * sizeof(WordIndex)); + } + + bool operator<(const State &other) const { + if (length != other.length) return length < other.length; + return memcmp(words, other.words, length * sizeof(WordIndex)) < 0; + } + + // Call this before using raw memcmp. + void ZeroRemaining() { + for (unsigned char i = length; i < kMaxOrder - 1; ++i) { + words[i] = 0; + backoff[i] = 0.0; + } + } + + unsigned char Length() const { return length; } + + // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. + // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. + WordIndex words[kMaxOrder - 1]; + float backoff[kMaxOrder - 1]; + unsigned char length; +}; + +inline uint64_t hash_value(const State &state, uint64_t seed = 0) { + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed); +} + +struct Left { + bool operator==(const Left &other) const { + return + (length == other.length) && + pointers[length - 1] == other.pointers[length - 1] && + full == other.full; + } + + int Compare(const Left &other) const { + if (length < other.length) return -1; + if (length > other.length) return 1; + if (pointers[length - 1] > other.pointers[length - 1]) return 1; + if (pointers[length - 1] < other.pointers[length - 1]) return -1; + return (int)full - (int)other.full; + } + + bool operator<(const Left &other) const { + return Compare(other) == -1; + } + + void ZeroRemaining() { + for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) + *i = 0; + } + + uint64_t pointers[kMaxOrder - 1]; + unsigned char length; + bool full; +}; + +inline uint64_t hash_value(const Left &left) { + unsigned char add[2]; + add[0] = left.length; + add[1] = left.full; + return util::MurmurHashNative(add, 2, left.length ? left.pointers[left.length - 1] : 0); +} + +struct ChartState { + bool operator==(const ChartState &other) { + return (right == other.right) && (left == other.left); + } + + int Compare(const ChartState &other) const { + int lres = left.Compare(other.left); + if (lres) return lres; + return right.Compare(other.right); + } + + bool operator<(const ChartState &other) const { + return Compare(other) == -1; + } + + void ZeroRemaining() { + left.ZeroRemaining(); + right.ZeroRemaining(); + } + + Left left; + State right; +}; + +inline uint64_t hash_value(const ChartState &state) { + return hash_value(state.right, hash_value(state.left)); +} + + +} // namespace ngram +} // namespace lm + +#endif // LM_STATE__ diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 20075bb8..0f1ca574 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,7 +1,6 @@ #include "lm/trie.hh" #include "lm/bhiksha.hh" -#include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" #include "util/sorted_uniform.hh" @@ -58,91 +57,71 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } -template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : +template BitPackedMiddle::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : BitPacked(), - quant_(quant), + quant_bits_(quant_bits), // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. bhiksha_(base, entries + 1, max_next, config), next_source_(&next_source) { if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - BaseInit(reinterpret_cast(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits()); + BaseInit(reinterpret_cast(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant_bits_ + bhiksha_.InlineBits()); } -template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template util::BitAddress BitPackedMiddle::Insert(WordIndex word) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; util::WriteInt57(base_, at_pointer, word_bits_, word); at_pointer += word_bits_; - quant_.Write(base_, at_pointer, prob, backoff); - at_pointer += quant_.TotalBits(); + util::BitAddress ret(base_, at_pointer); + at_pointer += quant_bits_; uint64_t next = next_source_->InsertIndex(); bhiksha_.WriteNext(base_, at_pointer, insert_index_, next); - ++insert_index_; + return ret; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const { +template util::BitAddress BitPackedMiddle::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { - return false; + return util::BitAddress(NULL, 0); } pointer = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; + bhiksha_.ReadNext(base_, at_pointer + quant_bits_, pointer, total_bits_, range); - quant_.Read(base_, at_pointer, prob, backoff); - at_pointer += quant_.TotalBits(); - - bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range); - - return true; + return util::BitAddress(base_, at_pointer); } -template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { - uint64_t index; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; - uint64_t at_pointer = index * total_bits_; - at_pointer += word_bits_; - quant_.ReadBackoff(base_, at_pointer, backoff); - at_pointer += quant_.TotalBits(); - bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); - return true; -} - -template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { +template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); bhiksha_.FinishedLoading(config); } -template void BitPackedLongest::Insert(WordIndex index, float prob) { +util::BitAddress BitPackedLongest::Insert(WordIndex index) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; util::WriteInt57(base_, at_pointer, word_bits_, index); at_pointer += word_bits_; - quant_.Write(base_, at_pointer, prob); ++insert_index_; + return util::BitAddress(base_, at_pointer); } -template bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { +util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return util::BitAddress(NULL, 0); at_pointer = at_pointer * total_bits_ + word_bits_; - quant_.Read(base_, at_pointer, prob); - return true; + return util::BitAddress(base_, at_pointer); } -template class BitPackedMiddle; -template class BitPackedMiddle; -template class BitPackedMiddle; -template class BitPackedMiddle; -template class BitPackedLongest; -template class BitPackedLongest; +template class BitPackedMiddle; +template class BitPackedMiddle; } // namespace trie } // namespace ngram diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index ebe9910f..eff93292 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -1,12 +1,13 @@ #ifndef LM_TRIE__ #define LM_TRIE__ -#include +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/bit_packing.hh" #include -#include "lm/word_index.hh" -#include "lm/weights.hh" +#include namespace lm { namespace ngram { @@ -24,6 +25,22 @@ struct UnigramValue { uint64_t Next() const { return next; } }; +class UnigramPointer { + public: + explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {} + + UnigramPointer() : to_(NULL) {} + + bool Found() const { return to_ != NULL; } + + float Prob() const { return to_->prob; } + float Backoff() const { return to_->backoff; } + float Rest() const { return Prob(); } + + private: + const ProbBackoff *to_; +}; + class Unigram { public: Unigram() {} @@ -47,12 +64,11 @@ class Unigram { void LoadedBinary() {} - void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { + UnigramPointer Find(WordIndex word, NodeRange &next) const { UnigramValue *val = unigram_ + word; - prob = val->weights.prob; - backoff = val->weights.backoff; next.begin = val->next; next.end = (val+1)->next; + return UnigramPointer(val->weights); } private: @@ -81,40 +97,36 @@ class BitPacked { uint64_t insert_index_, max_vocab_; }; -template class BitPackedMiddle : public BitPacked { +template class BitPackedMiddle : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. - BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); + BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); - void Insert(WordIndex word, float prob, float backoff); + util::BitAddress Insert(WordIndex word); void FinishedLoading(uint64_t next_end, const Config &config); void LoadedBinary() { bhiksha_.LoadedBinary(); } - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const; - - bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; + util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const; - NodeRange ReadEntry(uint64_t pointer, float &prob) { + util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) { uint64_t addr = pointer * total_bits_; addr += word_bits_; - quant_.ReadProb(base_, addr, prob); - NodeRange ret; - bhiksha_.ReadNext(base_, addr + quant_.TotalBits(), pointer, total_bits_, ret); - return ret; + bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range); + return util::BitAddress(base_, addr); } private: - Quant quant_; + uint8_t quant_bits_; Bhiksha bhiksha_; const BitPacked *next_source_; }; -template class BitPackedLongest : public BitPacked { +class BitPackedLongest : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { return BaseSize(entries, max_vocab, quant_bits); @@ -122,19 +134,18 @@ template class BitPackedLongest : public BitPacked { BitPackedLongest() {} - void Init(void *base, const Quant &quant, uint64_t max_vocab) { - quant_ = quant; - BaseInit(base, max_vocab, quant_.TotalBits()); + void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) { + BaseInit(base, max_vocab, quant_bits); } void LoadedBinary() {} - void Insert(WordIndex word, float prob); + util::BitAddress Insert(WordIndex word); - bool Find(WordIndex word, float &prob, const NodeRange &node) const; + util::BitAddress Find(WordIndex word, const NodeRange &node) const; private: - Quant quant_; + uint8_t quant_bits_; }; } // namespace trie diff --git a/klm/lm/value.hh b/klm/lm/value.hh new file mode 100644 index 00000000..85e53f14 --- /dev/null +++ b/klm/lm/value.hh @@ -0,0 +1,157 @@ +#ifndef LM_VALUE__ +#define LM_VALUE__ + +#include "lm/model_type.hh" +#include "lm/value_build.hh" +#include "lm/weights.hh" +#include "util/bit_packing.hh" + +#include + +namespace lm { +namespace ngram { + +// Template proxy for probing unigrams and middle. +template class GenericProbingProxy { + public: + explicit GenericProbingProxy(const Weights &to) : to_(&to) {} + + GenericProbingProxy() : to_(0) {} + + bool Found() const { return to_ != 0; } + + float Prob() const { + util::FloatEnc enc; + enc.f = to_->prob; + enc.i |= util::kSignBit; + return enc.f; + } + + float Backoff() const { return to_->backoff; } + + bool IndependentLeft() const { + util::FloatEnc enc; + enc.f = to_->prob; + return enc.i & util::kSignBit; + } + + protected: + const Weights *to_; +}; + +// Basic proxy for trie unigrams. +template class GenericTrieUnigramProxy { + public: + explicit GenericTrieUnigramProxy(const Weights &to) : to_(&to) {} + + GenericTrieUnigramProxy() : to_(0) {} + + bool Found() const { return to_ != 0; } + float Prob() const { return to_->prob; } + float Backoff() const { return to_->backoff; } + float Rest() const { return Prob(); } + + protected: + const Weights *to_; +}; + +struct BackoffValue { + typedef ProbBackoff Weights; + static const ModelType kProbingModelType = PROBING; + + class ProbingProxy : public GenericProbingProxy { + public: + explicit ProbingProxy(const Weights &to) : GenericProbingProxy(to) {} + ProbingProxy() {} + float Rest() const { return Prob(); } + }; + + class TrieUnigramProxy : public GenericTrieUnigramProxy { + public: + explicit TrieUnigramProxy(const Weights &to) : GenericTrieUnigramProxy(to) {} + TrieUnigramProxy() {} + float Rest() const { return Prob(); } + }; + + struct ProbingEntry { + typedef uint64_t Key; + typedef Weights Value; + uint64_t key; + ProbBackoff value; + uint64_t GetKey() const { return key; } + }; + + struct TrieUnigramValue { + Weights weights; + uint64_t next; + uint64_t Next() const { return next; } + }; + + const static bool kDifferentRest = false; + + template void Callback(const Config &, unsigned int, typename Model::Vocabulary &, C &callback) { + NoRestBuild build; + callback(build); + } +}; + +struct RestValue { + typedef RestWeights Weights; + static const ModelType kProbingModelType = REST_PROBING; + + class ProbingProxy : public GenericProbingProxy { + public: + explicit ProbingProxy(const Weights &to) : GenericProbingProxy(to) {} + ProbingProxy() {} + float Rest() const { return to_->rest; } + }; + + class TrieUnigramProxy : public GenericTrieUnigramProxy { + public: + explicit TrieUnigramProxy(const Weights &to) : GenericTrieUnigramProxy(to) {} + TrieUnigramProxy() {} + float Rest() const { return to_->rest; } + }; + +// gcc 4.1 doesn't properly back dependent types :-(. +#pragma pack(push) +#pragma pack(4) + struct ProbingEntry { + typedef uint64_t Key; + typedef Weights Value; + Key key; + Value value; + Key GetKey() const { return key; } + }; + + struct TrieUnigramValue { + Weights weights; + uint64_t next; + uint64_t Next() const { return next; } + }; +#pragma pack(pop) + + const static bool kDifferentRest = true; + + template void Callback(const Config &config, unsigned int order, typename Model::Vocabulary &vocab, C &callback) { + switch (config.rest_function) { + case Config::REST_MAX: + { + MaxRestBuild build; + callback(build); + } + break; + case Config::REST_LOWER: + { + LowerRestBuild build(config, order, vocab); + callback(build); + } + break; + } + } +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_VALUE__ diff --git a/klm/lm/value_build.cc b/klm/lm/value_build.cc new file mode 100644 index 00000000..6124f8da --- /dev/null +++ b/klm/lm/value_build.cc @@ -0,0 +1,58 @@ +#include "lm/value_build.hh" + +#include "lm/model.hh" +#include "lm/read_arpa.hh" + +namespace lm { +namespace ngram { + +template LowerRestBuild::LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab) { + UTIL_THROW_IF(config.rest_lower_files.size() != order - 1, ConfigException, "This model has order " << order << " so there should be " << (order - 1) << " lower-order models for rest cost purposes."); + Config for_lower = config; + for_lower.rest_lower_files.clear(); + + // Unigram models aren't supported, so this is a custom loader. + // TODO: optimize the unigram loading? + { + util::FilePiece uni(config.rest_lower_files[0].c_str()); + std::vector number; + ReadARPACounts(uni, number); + UTIL_THROW_IF(number.size() != 1, FormatLoadException, "Expected the unigram model to have order 1, not " << number.size()); + ReadNGramHeader(uni, 1); + unigrams_.resize(number[0]); + unigrams_[0] = config.unknown_missing_logprob; + PositiveProbWarn warn; + for (uint64_t i = 0; i < number[0]; ++i) { + WordIndex w; + Prob entry; + ReadNGram(uni, 1, vocab, &w, entry, warn); + unigrams_[w] = entry.prob; + } + } + + try { + for (unsigned int i = 2; i < order; ++i) { + models_.push_back(new Model(config.rest_lower_files[i - 1].c_str(), for_lower)); + UTIL_THROW_IF(models_.back()->Order() != i, FormatLoadException, "Lower order file " << config.rest_lower_files[i-1] << " should have order " << i); + } + } catch (...) { + for (typename std::vector::const_iterator i = models_.begin(); i != models_.end(); ++i) { + delete *i; + } + models_.clear(); + throw; + } + + // TODO: force/check same vocab. +} + +template LowerRestBuild::~LowerRestBuild() { + for (typename std::vector::const_iterator i = models_.begin(); i != models_.end(); ++i) { + delete *i; + } +} + +template class LowerRestBuild; + +} // namespace ngram +} // namespace lm diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh new file mode 100644 index 00000000..687a41a0 --- /dev/null +++ b/klm/lm/value_build.hh @@ -0,0 +1,97 @@ +#ifndef LM_VALUE_BUILD__ +#define LM_VALUE_BUILD__ + +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/bit_packing.hh" + +#include + +namespace lm { +namespace ngram { + +class Config; +class BackoffValue; +class RestValue; + +class NoRestBuild { + public: + typedef BackoffValue Value; + + NoRestBuild() {} + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {} + + template bool MarkExtends(ProbBackoff &weights, const Second &) const { + util::UnsetSign(weights.prob); + return false; + } + + // Probing doesn't need to go back to unigram. + const static bool kMarkEvenLower = false; +}; + +class MaxRestBuild { + public: + typedef RestValue Value; + + MaxRestBuild() {} + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const { + weights.rest = weights.prob; + util::SetSign(weights.rest); + } + + bool MarkExtends(RestWeights &weights, const RestWeights &to) const { + util::UnsetSign(weights.prob); + if (weights.rest >= to.rest) return false; + weights.rest = to.rest; + return true; + } + bool MarkExtends(RestWeights &weights, const Prob &to) const { + util::UnsetSign(weights.prob); + if (weights.rest >= to.prob) return false; + weights.rest = to.prob; + return true; + } + + // Probing does need to go back to unigram. + const static bool kMarkEvenLower = true; +}; + +template class LowerRestBuild { + public: + typedef RestValue Value; + + LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab); + + ~LowerRestBuild(); + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const { + typename Model::State ignored; + if (n == 1) { + weights.rest = unigrams_[*vocab_ids]; + } else { + weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob; + } + } + + template bool MarkExtends(RestWeights &weights, const Second &) const { + util::UnsetSign(weights.prob); + return false; + } + + const static bool kMarkEvenLower = false; + + std::vector unigrams_; + + std::vector models_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_VALUE_BUILD__ diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 9fd698bb..5de68f16 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -196,7 +196,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { } } -void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { +void ProbingVocabulary::InternalFinishedLoading() { lookup_.FinishedInserting(); header_->bound = bound_; header_->version = kProbingVocabularyVersion; diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 06fdefe4..c3efcb4a 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -141,7 +141,9 @@ class ProbingVocabulary : public base::Vocabulary { WordIndex Insert(const StringPiece &str); - void FinishedLoading(ProbBackoff *reorder_vocab); + template void FinishedLoading(Weights * /*reorder_vocab*/) { + InternalFinishedLoading(); + } std::size_t UnkCountChangePadding() const { return 0; } @@ -150,6 +152,8 @@ class ProbingVocabulary : public base::Vocabulary { void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); private: + void InternalFinishedLoading(); + typedef util::ProbingHashTable Lookup; Lookup lookup_; diff --git a/klm/lm/weights.hh b/klm/lm/weights.hh index 1f38cf5e..bd5d8034 100644 --- a/klm/lm/weights.hh +++ b/klm/lm/weights.hh @@ -12,6 +12,11 @@ struct ProbBackoff { float prob; float backoff; }; +struct RestWeights { + float prob; + float backoff; + float rest; +}; } // namespace lm #endif // LM_WEIGHTS__ diff --git a/klm/util/Jamfile b/klm/util/Jamfile index b8c14347..3ee2c2c2 100644 --- a/klm/util/Jamfile +++ b/klm/util/Jamfile @@ -1,4 +1,4 @@ -lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc ../..//z : .. : : .. ; +lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc usage.cc ../..//z : .. : : .. ; import testing ; diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index a8d6299b..5ceccf2c 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -25,6 +25,7 @@ libklm_util_a_SOURCES = \ file.cc \ file_piece.cc \ mmap.cc \ - murmur_hash.cc + murmur_hash.cc \ + usage.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 73a5cb22..dcbd814c 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -174,6 +174,13 @@ struct BitsMask { uint64_t mask; }; +struct BitAddress { + BitAddress(void *in_base, uint64_t in_offset) : base(in_base), offset(in_offset) {} + + void *base; + uint64_t offset; +}; + } // namespace util #endif // UTIL_BIT_PACKING__ diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc index a82ce672..07b14e26 100644 --- a/klm/util/ersatz_progress.cc +++ b/klm/util/ersatz_progress.cc @@ -12,17 +12,17 @@ namespace { const unsigned char kWidth = 100; } ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits::max()), complete_(next_), out_(NULL) {} ErsatzProgress::~ErsatzProgress() { - if (!out_) return; - Finished(); + if (out_) Finished(); } -ErsatzProgress::ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete) +ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std::string &message) : current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) { if (!out_) { next_ = std::numeric_limits::max(); return; } - *out_ << message << "\n----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"; + if (!message.empty()) *out_ << message << '\n'; + *out_ << "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"; } void ErsatzProgress::Milestone() { diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index 92c345fe..f709dc51 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -1,7 +1,7 @@ #ifndef UTIL_ERSATZ_PROGRESS__ #define UTIL_ERSATZ_PROGRESS__ -#include +#include #include // Ersatz version of boost::progress so core language model doesn't depend on @@ -14,7 +14,7 @@ class ErsatzProgress { ErsatzProgress(); // Null means no output. The null value is useful for passing along the ostream pointer from another caller. - ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete); + explicit ErsatzProgress(std::size_t complete, std::ostream *to = &std::cerr, const std::string &message = ""); ~ErsatzProgress(); diff --git a/klm/util/file.cc b/klm/util/file.cc index de206bc8..1bd056fc 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -43,16 +43,6 @@ int OpenReadOrThrow(const char *name) { return ret; } -int CreateOrThrow(const char *name) { - int ret; -#if defined(_WIN32) || defined(_WIN64) - UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); -#else - UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name); -#endif - return ret; -} - uint64_t SizeFile(int fd) { #if defined(_WIN32) || defined(_WIN64) __int64 ret = _filelengthi64(fd); diff --git a/klm/util/file.hh b/klm/util/file.hh index 72c8ea76..5c57e2a9 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -65,10 +65,7 @@ class scoped_FILE { std::FILE *file_; }; -// Open for read only. int OpenReadOrThrow(const char *name); -// Create file if it doesn't exist, truncate if it does. Opened for write. -int CreateOrThrow(const char *name); // Return value for SizeFile when it can't size properly. const uint64_t kBadSize = (uint64_t)-1; diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 081e662b..7b6a01dd 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -18,31 +18,35 @@ #include #include +#ifdef HAVE_ZLIB +#include +#endif + namespace util { ParseNumberException::ParseNumberException(StringPiece value) throw() { *this << "Could not parse \"" << value << "\" into a number"; } +GZException::GZException(void *file) { #ifdef HAVE_ZLIB -GZException::GZException(gzFile file) { int num; - *this << gzerror( file, &num) << " from zlib"; -} + *this << gzerror(file, &num) << " from zlib"; #endif // HAVE_ZLIB +} // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; FilePiece::FilePiece(const char *name, std::ostream *show_progress, std::size_t min_buffer) : file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(SizePage()), - progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { + progress_(total_size_, total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name) { Initialize(name, show_progress, min_buffer); } FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std::size_t min_buffer) : file_(fd), total_size_(SizeFile(file_.get())), page_(SizePage()), - progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { + progress_(total_size_, total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name) { Initialize(name, show_progress, min_buffer); } @@ -149,8 +153,9 @@ template T FilePiece::ReadNumber() { SkipSpaces(); while (last_space_ < position_) { if (at_end_) { + if (position_ >= position_end_) throw EndOfFileException(); // Hallucinate a null off the end of the file. - std::string buffer(position_, position_end_); + std::string buffer(position_, position_end_ - position_); char *end; T ret; ParseNumber(buffer.c_str(), end, ret); diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index af93d8aa..b81ac0e2 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -13,10 +13,6 @@ #include -#ifdef HAVE_ZLIB -#include -#endif - namespace util { class ParseNumberException : public Exception { @@ -27,9 +23,7 @@ class ParseNumberException : public Exception { class GZException : public Exception { public: -#ifdef HAVE_ZLIB - explicit GZException(gzFile file); -#endif + explicit GZException(void *file); GZException() throw() {} ~GZException() throw() {} }; @@ -123,7 +117,7 @@ class FilePiece { std::string file_name_; #ifdef HAVE_ZLIB - gzFile gz_file_; + void *gz_file_; #endif // HAVE_ZLIB }; diff --git a/klm/util/have.hh b/klm/util/have.hh index f2f0cf90..b8181e99 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -3,13 +3,21 @@ #define UTIL_HAVE__ #ifndef HAVE_ZLIB +#if !defined(_WIN32) && !defined(_WIN64) #define HAVE_ZLIB #endif +#endif -// #define HAVE_ICU +#ifndef HAVE_ICU +//#define HAVE_ICU +#endif #ifndef HAVE_BOOST #define HAVE_BOOST #endif +#ifndef HAVE_THREADS +//#define HAVE_THREADS +#endif + #endif // UTIL_HAVE__ diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index 2db35b56..e0d2570b 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -171,6 +171,20 @@ void *MapZeroedWrite(int fd, std::size_t size) { return MapOrThrow(size, true, kFileFlags, false, fd, 0); } +namespace { + +int CreateOrThrow(const char *name) { + int ret; +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); +#else + UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name); +#endif + return ret; +} + +} // namespace + void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) { file.reset(CreateOrThrow(name)); try { diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index 6accc21a..4f519312 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -23,7 +23,7 @@ namespace util { // 64-bit hash for 64-bit platforms -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +uint64_t MurmurHash64A ( const void * key, std::size_t len, uint64_t seed ) { const uint64_t m = 0xc6a4a7935bd1e995ULL; const int r = 47; @@ -81,7 +81,7 @@ uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) // 64-bit hash for 32-bit platforms -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +uint64_t MurmurHash64B ( const void * key, std::size_t len, uint64_t seed ) { const unsigned int m = 0x5bd1e995; const int r = 24; @@ -150,17 +150,18 @@ uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) return h; } + // Trick to test for 64-bit architecture at compile time. namespace { -template uint64_t MurmurHashNativeBackend(const void * key, std::size_t len, unsigned int seed) { +template inline uint64_t MurmurHashNativeBackend(const void * key, std::size_t len, uint64_t seed) { return MurmurHash64A(key, len, seed); } -template <> uint64_t MurmurHashNativeBackend<4>(const void * key, std::size_t len, unsigned int seed) { +template <> inline uint64_t MurmurHashNativeBackend<4>(const void * key, std::size_t len, uint64_t seed) { return MurmurHash64B(key, len, seed); } } // namespace -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { +uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed) { return MurmurHashNativeBackend(key, len, seed); } diff --git a/klm/util/murmur_hash.hh b/klm/util/murmur_hash.hh index 638aaeb2..ae7e88de 100644 --- a/klm/util/murmur_hash.hh +++ b/klm/util/murmur_hash.hh @@ -5,9 +5,9 @@ namespace util { -uint64_t MurmurHash64A(const void * key, std::size_t len, unsigned int seed = 0); -uint64_t MurmurHash64B(const void * key, std::size_t len, unsigned int seed = 0); -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed = 0); +uint64_t MurmurHash64A(const void * key, std::size_t len, uint64_t seed = 0); +uint64_t MurmurHash64B(const void * key, std::size_t len, uint64_t seed = 0); +uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed = 0); } // namespace util diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index f466cebc..3354b68e 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -78,12 +78,33 @@ template bool FindOrInsert(const T &t, MutableIterator &out) { +#ifdef DEBUG + assert(initialized_); +#endif + for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) { + Key got(i->GetKey()); + if (equal_(got, t.GetKey())) { out = i; return true; } + if (equal_(got, invalid_)) { + UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); + *i = t; + out = i; + return false; + } + if (++i == end_) i = begin_; + } + } + void FinishedInserting() {} void LoadedBinary() {} // Don't change anything related to GetKey, template bool UnsafeMutableFind(const Key key, MutableIterator &out) { +#ifdef DEBUG + assert(initialized_); +#endif for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) { Key got(i->GetKey()); if (equal_(got, key)) { out = i; return true; } diff --git a/klm/util/usage.cc b/klm/util/usage.cc new file mode 100644 index 00000000..e5cf76f0 --- /dev/null +++ b/klm/util/usage.cc @@ -0,0 +1,46 @@ +#include "util/usage.hh" + +#include +#include + +#include +#include +#if !defined(_WIN32) && !defined(_WIN64) +#include +#include +#endif + +namespace util { + +namespace { +#if !defined(_WIN32) && !defined(_WIN64) +float FloatSec(const struct timeval &tv) { + return static_cast(tv.tv_sec) + (static_cast(tv.tv_usec) / 1000000.0); +} +#endif +} // namespace + +void PrintUsage(std::ostream &out) { +#if !defined(_WIN32) && !defined(_WIN64) + struct rusage usage; + if (getrusage(RUSAGE_SELF, &usage)) { + perror("getrusage"); + return; + } + out << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n'; + + // Linux doesn't set memory usage :-(. + std::ifstream status("/proc/self/status", std::ios::in); + std::string line; + while (getline(status, line)) { + if (!strncmp(line.c_str(), "VmRSS:\t", 7)) { + out << "VmRSS: " << (line.c_str() + 7) << '\n'; + break; + } else if (!strncmp(line.c_str(), "VmPeak:\t", 8)) { + out << "VmPeak: " << (line.c_str() + 8) << '\n'; + } + } +#endif +} + +} // namespace util diff --git a/klm/util/usage.hh b/klm/util/usage.hh new file mode 100644 index 00000000..d331ff74 --- /dev/null +++ b/klm/util/usage.hh @@ -0,0 +1,8 @@ +#ifndef UTIL_USAGE__ +#define UTIL_USAGE__ +#include + +namespace util { +void PrintUsage(std::ostream &to); +} // namespace util +#endif // UTIL_USAGE__ -- cgit v1.2.3