From f111672dd611f78656fceb3df3729a290453ef56 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 21 Sep 2011 18:23:50 -0400 Subject: Updated kenlm. Includes left state support but not the cdec-side use of it. Updated binary format. --- klm/lm/left.hh | 181 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 klm/lm/left.hh (limited to 'klm/lm/left.hh') diff --git a/klm/lm/left.hh b/klm/lm/left.hh new file mode 100644 index 00000000..df69e97a --- /dev/null +++ b/klm/lm/left.hh @@ -0,0 +1,181 @@ +#ifndef LM_LEFT__ +#define LM_LEFT__ + +#include "lm/max_order.hh" +#include "lm/model.hh" +#include "lm/return.hh" + +#include + +namespace lm { +namespace ngram { + +struct Left { + bool operator==(const Left &other) const { + return + (length == other.length) && + pointers[length - 1] == other.pointers[length - 1]; + } + + int Compare(const Left &other) const { + if (length != other.length) { + return (int)length - (int)other.length; + } + if (pointers[length - 1] > other.pointers[length - 1]) return 1; + if (pointers[length - 1] < other.pointers[length - 1]) return -1; + return 0; + } + + uint64_t pointers[kMaxOrder - 1]; + unsigned char length; +}; + +struct ChartState { + bool operator==(const ChartState &other) { + return (left == other.left) && (right == other.right) && (full == other.full); + } + + int Compare(const ChartState &other) const { + int lres = left.Compare(other.left); + if (lres) return lres; + int rres = right.Compare(other.right); + if (rres) return rres; + return (int)full - (int)other.full; + } + + Left left; + State right; + bool full; +}; + +template class RuleScore { + public: + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + out.left.length = 0; + out.right.length = 0; + } + + void BeginSentence() { + out_.right = model_.BeginSentenceState(); + // out_.left is empty. + left_done_ = true; + } + + void Terminal(WordIndex word) { + State copy(out_.right); + FullScoreReturn ret = model_.FullScore(copy, word, out_.right); + ProcessRet(ret); + if (out_.right.length != copy.length + 1) left_done_ = true; + } + + // Faster version of NonTerminal for the case where the rule begins with a non-terminal. + void BeginNonTerminal(const ChartState &in, float prob) { + prob_ = prob; + out_ = in; + left_write_ = out_.left.pointers + out_.left.length; + left_done_ = in.full; + } + + void NonTerminal(const ChartState &in, float prob) { + prob_ += prob; + + if (!in.left.length) { + if (in.full) { + for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + } + return; + } + + if (!out_.right.length) { + out_.right = in.right; + if (left_done_) return; + if (left_write_ != out_.left.pointers) { + left_done_ = true; + } else { + out_.left = in.left; + left_write_ = out_.left.pointers + in.left.length; + left_done_ = in.full; + } + return; + } + + float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; + float *back = backoffs, *back2 = backoffs2; + unsigned char next_use; + FullScoreReturn ret; + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + unsigned char extend_length = 2; + for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + std::swap(back, back2); + } + + if (in.full) { + for (const float *i = back; i != back + next_use; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + return; + } + + // Right state was minimized, so it's already independent of the new words to the left. + if (in.right.length < in.left.length) { + out_.right = in.right; + return; + } + + // Shift exisiting words down. + for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + *(i + in.right.length) = *i; + } + // Add words from in.right. + std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); + std::copy(back, back + next_use, out_.right.backoff + in.right.length); + out_.right.length = in.right.length + next_use; + } + + float Finish() { + out_.left.length = left_write_ - out_.left.pointers; + out_.full = left_done_; + return prob_; + } + + private: + void ProcessRet(const FullScoreReturn &ret) { + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + *(left_write_++) = ret.extend_left; + } + + const M &model_; + + ChartState &out_; + + bool left_done_; + + uint64_t *left_write_; + + float prob_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_LEFT__ -- cgit v1.2.3 From fcbb924a575df56de53eacce886ebf9ccf3283ed Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 23 Sep 2011 16:09:56 -0400 Subject: Add ZeroRemaining --- klm/lm/left.hh | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'klm/lm/left.hh') diff --git a/klm/lm/left.hh b/klm/lm/left.hh index df69e97a..837be765 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -26,6 +26,11 @@ struct Left { return 0; } + void ZeroRemaining() { + for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) + *i = 0; + } + uint64_t pointers[kMaxOrder - 1]; unsigned char length; }; @@ -43,6 +48,11 @@ struct ChartState { return (int)full - (int)other.full; } + void ZeroRemaining() { + left.ZeroRemaining(); + right.ZeroRemaining(); + } + Left left; State right; bool full; -- cgit v1.2.3 From 2e5720a8e7141a75ae549c6be74f50bd18068ef1 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sat, 24 Sep 2011 07:58:58 -0400 Subject: Belated documentation --- klm/lm/left.hh | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++------ klm/lm/model.cc | 5 ----- klm/lm/model.hh | 25 +++++++++++---------- 3 files changed, 76 insertions(+), 24 deletions(-) (limited to 'klm/lm/left.hh') diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 837be765..effa0560 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -1,3 +1,40 @@ +/* Efficient left and right language model state for sentence fragments. + * Intended usage: + * Store ChartState with every chart entry. + * To do a rule application: + * 1. Make a ChartState object for your new entry. + * 2. Construct RuleScore. + * 3. Going from left to right, call Terminal or NonTerminal. + * For terminals, just pass the vocab id. + * For non-terminals, pass that non-terminal's ChartState. + * If your decoder expects scores inclusive of subtree scores (i.e. you + * label entries with the highest-scoring path), pass the non-terminal's + * score as prob. + * If your decoder expects relative scores and will walk the chart later, + * pass prob = 0.0. + * In other words, the only effect of prob is that it gets added to the + * returned log probability. + * 4. Call Finish. It returns the log probability. + * + * There's a couple more details: + * Do not pass to Terminal as it is formally not a word in the sentence, + * only context. Instead, call BeginSentence. If called, it should be the + * first call after RuleScore is constructed (since is always the + * leftmost). + * + * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal. + * + * Hashing and sorting comparison operators are provided. All state objects + * are POD. If you intend to use memcmp on raw state objects, you must call + * ZeroRemaining first, as the value of array entries beyond length is + * otherwise undefined. + * + * Usage is of course not limited to chart decoding. Anything that generates + * sentence fragments missing left context could benefit. For example, a + * phrase-based decoder could pre-score phrases, storing ChartState with each + * phrase, even if hypotheses are generated left-to-right. + */ + #ifndef LM_LEFT__ #define LM_LEFT__ @@ -5,6 +42,8 @@ #include "lm/model.hh" #include "lm/return.hh" +#include "util/murmur_hash.hh" + #include namespace lm { @@ -18,23 +57,30 @@ struct Left { } int Compare(const Left &other) const { - if (length != other.length) { - return (int)length - (int)other.length; - } + if (length != other.length) return length < other.length ? -1 : 1; if (pointers[length - 1] > other.pointers[length - 1]) return 1; if (pointers[length - 1] < other.pointers[length - 1]) return -1; return 0; } + bool operator<(const Left &other) const { + if (length != other.length) return length < other.length; + return pointers[length - 1] < other.pointers[length - 1]; + } + void ZeroRemaining() { for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) *i = 0; } - uint64_t pointers[kMaxOrder - 1]; unsigned char length; + uint64_t pointers[kMaxOrder - 1]; }; +inline size_t hash_value(const Left &left) { + return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]); +} + struct ChartState { bool operator==(const ChartState &other) { return (left == other.left) && (right == other.right) && (full == other.full); @@ -48,16 +94,27 @@ struct ChartState { return (int)full - (int)other.full; } + bool operator<(const ChartState &other) const { + return Compare(other) == -1; + } + void ZeroRemaining() { left.ZeroRemaining(); right.ZeroRemaining(); } Left left; - State right; bool full; + State right; }; +inline size_t hash_value(const ChartState &state) { + size_t hashes[2]; + hashes[0] = hash_value(state.left); + hashes[1] = hash_value(state.right); + return util::MurmurHashNative(hashes, sizeof(size_t), state.full); +} + template class RuleScore { public: explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { @@ -73,8 +130,7 @@ template class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); - FullScoreReturn ret = model_.FullScore(copy, word, out_.right); - ProcessRet(ret); + ProcessRet(model_.FullScore(copy, word, out_.right)); if (out_.right.length != copy.length + 1) left_done_ = true; } diff --git a/klm/lm/model.cc b/klm/lm/model.cc index ca581d8a..25f1ab7c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -14,11 +14,6 @@ namespace lm { namespace ngram { - -size_t hash_value(const State &state) { - return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); -} - namespace detail { template const ModelType GenericModel::kModelType = Search::kModelType; diff --git a/klm/lm/model.hh b/klm/lm/model.hh index fe91af2e..c278acd6 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -12,6 +12,8 @@ #include "lm/vocab.hh" #include "lm/weights.hh" +#include "util/murmur_hash.hh" + #include #include @@ -28,21 +30,18 @@ class State { public: bool operator==(const State &other) const { if (length != other.length) return false; - const WordIndex *end = words + length; - for (const WordIndex *first = words, *second = other.words; - first != end; ++first, ++second) { - if (*first != *second) return false; - } - // If the histories are equal, so are the backoffs. - return true; + return !memcmp(words, other.words, length * sizeof(WordIndex)); } // Three way comparison function. int Compare(const State &other) const { - if (length == other.length) { - return memcmp(words, other.words, length * sizeof(WordIndex)); - } - return (length < other.length) ? -1 : 1; + if (length != other.length) return length < other.length ? -1 : 1; + return memcmp(words, other.words, length * sizeof(WordIndex)); + } + + bool operator<(const State &other) const { + if (length != other.length) return length < other.length; + return memcmp(words, other.words, length * sizeof(WordIndex)) < 0; } // Call this before using raw memcmp. @@ -62,7 +61,9 @@ class State { unsigned char length; }; -size_t hash_value(const State &state); +inline size_t hash_value(const State &state) { + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); +} namespace detail { -- cgit v1.2.3 From 0e1ffb6c1528e44f63ae8bac466bd5163e973974 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Oct 2011 14:58:52 -0400 Subject: Trie fixes for SRI --- klm/lm/left.hh | 26 +++++++++++++++----------- klm/lm/search_trie.cc | 7 ++++++- klm/lm/search_trie.hh | 2 +- klm/lm/trie_sort.cc | 2 +- klm/util/scoped.hh | 2 +- 5 files changed, 24 insertions(+), 15 deletions(-) (limited to 'klm/lm/left.hh') diff --git a/klm/lm/left.hh b/klm/lm/left.hh index effa0560..bb3f5539 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -117,7 +117,7 @@ inline size_t hash_value(const ChartState &state) { template class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } @@ -130,15 +130,22 @@ template class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); - ProcessRet(model_.FullScore(copy, word, out_.right)); - if (out_.right.length != copy.length + 1) left_done_ = true; + FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + out_.left.pointers[out_.left.length++] = ret.extend_left; + if (out_.right.length != copy.length + 1) + left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob) { prob_ = prob; out_ = in; - left_write_ = out_.left.pointers + out_.left.length; left_done_ = in.full; } @@ -157,11 +164,10 @@ template class RuleScore { if (!out_.right.length) { out_.right = in.right; if (left_done_) return; - if (left_write_ != out_.left.pointers) { + if (out_.left.length) { left_done_ = true; } else { out_.left = in.left; - left_write_ = out_.left.pointers + in.left.length; left_done_ = in.full; } return; @@ -214,8 +220,8 @@ template class RuleScore { } float Finish() { - out_.left.length = left_write_ - out_.left.pointers; - out_.full = left_done_; + // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. + out_.full = left_done_ || (out_.left.length == model_.Order() - 1); return prob_; } @@ -227,7 +233,7 @@ template class RuleScore { left_done_ = true; return; } - *(left_write_++) = ret.extend_left; + out_.left.pointers[out_.left.length++] = ret.extend_left; } const M &model_; @@ -236,8 +242,6 @@ template class RuleScore { bool left_done_; - uint64_t *left_write_; - float prob_; }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 6479813b..5d8c70db 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -151,6 +151,11 @@ class BackoffMessages { private: void FinishedAdding() { Resize(current_ - (uint8_t*)backing_.get()); + // Sort requests in same order as files. + std::sort( + util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)), + util::SizedIterator(util::SizedProxy(current_, entry_size_)), + util::SizedCompare(EntryCompare((entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex)))); current_ = (uint8_t*)backing_.get(); } @@ -525,7 +530,7 @@ template void BuildTrie(const std::string &file_pre const RecordReader &context = contexts[order - 2]; if (context) { FormatLoadException e; - e << "An " << static_cast(order) << "-gram has context"; + e << "A " << static_cast(order) << "-gram has context"; const WordIndex *ctx = reinterpret_cast(context.Data()); for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) { e << ' ' << *i; diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index c3e02a98..33ae8cff 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -36,7 +36,7 @@ template class TrieSearch { static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); - static const unsigned int kVersion = 0; + static const unsigned int kVersion = 1; static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 86f28493..bb126f18 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -191,7 +191,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); - MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order, FirstCombine()); + MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine()); files.pop_front(); files.pop_front(); } diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index 12e6652b..93e2e817 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -50,7 +50,7 @@ class scoped_malloc { void call_realloc(std::size_t to) { void *ret; - UTIL_THROW_IF(!(ret = std::realloc(p_, to)), util::ErrnoException, "realloc to " << to << " bytes failed."); + UTIL_THROW_IF(!(ret = std::realloc(p_, to)) && to, util::ErrnoException, "realloc to " << to << " bytes failed."); p_ = ret; } -- cgit v1.2.3