diff options
author | Kenneth Heafield <kenlm@kheafield.com> | 2011-09-24 07:58:58 -0400 |
---|---|---|
committer | Kenneth Heafield <kenlm@kheafield.com> | 2011-09-24 07:58:58 -0400 |
commit | 23fc955f722906ac927df04106c1f0474ba8ca2d (patch) | |
tree | 90d1144de4b8b7441173847fcc4a4206f33f08fc /klm | |
parent | 9b1fbf198b3cee920a76e13f7a21c7fcf0cc5f56 (diff) |
Belated documentation
Diffstat (limited to 'klm')
-rw-r--r-- | klm/lm/left.hh | 70 | ||||
-rw-r--r-- | klm/lm/model.cc | 5 | ||||
-rw-r--r-- | klm/lm/model.hh | 25 |
3 files changed, 76 insertions, 24 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 837be765..effa0560 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -1,3 +1,40 @@ +/* Efficient left and right language model state for sentence fragments. + * Intended usage: + * Store ChartState with every chart entry. + * To do a rule application: + * 1. Make a ChartState object for your new entry. + * 2. Construct RuleScore. + * 3. Going from left to right, call Terminal or NonTerminal. + * For terminals, just pass the vocab id. + * For non-terminals, pass that non-terminal's ChartState. + * If your decoder expects scores inclusive of subtree scores (i.e. you + * label entries with the highest-scoring path), pass the non-terminal's + * score as prob. + * If your decoder expects relative scores and will walk the chart later, + * pass prob = 0.0. + * In other words, the only effect of prob is that it gets added to the + * returned log probability. + * 4. Call Finish. It returns the log probability. + * + * There's a couple more details: + * Do not pass <s> to Terminal as it is formally not a word in the sentence, + * only context. Instead, call BeginSentence. If called, it should be the + * first call after RuleScore is constructed (since <s> is always the + * leftmost). + * + * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal. + * + * Hashing and sorting comparison operators are provided. All state objects + * are POD. If you intend to use memcmp on raw state objects, you must call + * ZeroRemaining first, as the value of array entries beyond length is + * otherwise undefined. + * + * Usage is of course not limited to chart decoding. Anything that generates + * sentence fragments missing left context could benefit. For example, a + * phrase-based decoder could pre-score phrases, storing ChartState with each + * phrase, even if hypotheses are generated left-to-right. + */ + #ifndef LM_LEFT__ #define LM_LEFT__ @@ -5,6 +42,8 @@ #include "lm/model.hh" #include "lm/return.hh" +#include "util/murmur_hash.hh" + #include <algorithm> namespace lm { @@ -18,23 +57,30 @@ struct Left { } int Compare(const Left &other) const { - if (length != other.length) { - return (int)length - (int)other.length; - } + if (length != other.length) return length < other.length ? -1 : 1; if (pointers[length - 1] > other.pointers[length - 1]) return 1; if (pointers[length - 1] < other.pointers[length - 1]) return -1; return 0; } + bool operator<(const Left &other) const { + if (length != other.length) return length < other.length; + return pointers[length - 1] < other.pointers[length - 1]; + } + void ZeroRemaining() { for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) *i = 0; } - uint64_t pointers[kMaxOrder - 1]; unsigned char length; + uint64_t pointers[kMaxOrder - 1]; }; +inline size_t hash_value(const Left &left) { + return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]); +} + struct ChartState { bool operator==(const ChartState &other) { return (left == other.left) && (right == other.right) && (full == other.full); @@ -48,16 +94,27 @@ struct ChartState { return (int)full - (int)other.full; } + bool operator<(const ChartState &other) const { + return Compare(other) == -1; + } + void ZeroRemaining() { left.ZeroRemaining(); right.ZeroRemaining(); } Left left; - State right; bool full; + State right; }; +inline size_t hash_value(const ChartState &state) { + size_t hashes[2]; + hashes[0] = hash_value(state.left); + hashes[1] = hash_value(state.right); + return util::MurmurHashNative(hashes, sizeof(size_t), state.full); +} + template <class M> class RuleScore { public: explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { @@ -73,8 +130,7 @@ template <class M> class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); - FullScoreReturn ret = model_.FullScore(copy, word, out_.right); - ProcessRet(ret); + ProcessRet(model_.FullScore(copy, word, out_.right)); if (out_.right.length != copy.length + 1) left_done_ = true; } diff --git a/klm/lm/model.cc b/klm/lm/model.cc index ca581d8a..25f1ab7c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -14,11 +14,6 @@ namespace lm { namespace ngram { - -size_t hash_value(const State &state) { - return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); -} - namespace detail { template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType; diff --git a/klm/lm/model.hh b/klm/lm/model.hh index fe91af2e..c278acd6 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -12,6 +12,8 @@ #include "lm/vocab.hh" #include "lm/weights.hh" +#include "util/murmur_hash.hh" + #include <algorithm> #include <vector> @@ -28,21 +30,18 @@ class State { public: bool operator==(const State &other) const { if (length != other.length) return false; - const WordIndex *end = words + length; - for (const WordIndex *first = words, *second = other.words; - first != end; ++first, ++second) { - if (*first != *second) return false; - } - // If the histories are equal, so are the backoffs. - return true; + return !memcmp(words, other.words, length * sizeof(WordIndex)); } // Three way comparison function. int Compare(const State &other) const { - if (length == other.length) { - return memcmp(words, other.words, length * sizeof(WordIndex)); - } - return (length < other.length) ? -1 : 1; + if (length != other.length) return length < other.length ? -1 : 1; + return memcmp(words, other.words, length * sizeof(WordIndex)); + } + + bool operator<(const State &other) const { + if (length != other.length) return length < other.length; + return memcmp(words, other.words, length * sizeof(WordIndex)) < 0; } // Call this before using raw memcmp. @@ -62,7 +61,9 @@ class State { unsigned char length; }; -size_t hash_value(const State &state); +inline size_t hash_value(const State &state) { + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); +} namespace detail { |