diff options
Diffstat (limited to 'klm/lm/left.hh')
-rw-r--r-- | klm/lm/left.hh | 70 |
1 files changed, 63 insertions, 7 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 837be765..effa0560 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -1,3 +1,40 @@ +/* Efficient left and right language model state for sentence fragments. + * Intended usage: + * Store ChartState with every chart entry. + * To do a rule application: + * 1. Make a ChartState object for your new entry. + * 2. Construct RuleScore. + * 3. Going from left to right, call Terminal or NonTerminal. + * For terminals, just pass the vocab id. + * For non-terminals, pass that non-terminal's ChartState. + * If your decoder expects scores inclusive of subtree scores (i.e. you + * label entries with the highest-scoring path), pass the non-terminal's + * score as prob. + * If your decoder expects relative scores and will walk the chart later, + * pass prob = 0.0. + * In other words, the only effect of prob is that it gets added to the + * returned log probability. + * 4. Call Finish. It returns the log probability. + * + * There's a couple more details: + * Do not pass <s> to Terminal as it is formally not a word in the sentence, + * only context. Instead, call BeginSentence. If called, it should be the + * first call after RuleScore is constructed (since <s> is always the + * leftmost). + * + * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal. + * + * Hashing and sorting comparison operators are provided. All state objects + * are POD. If you intend to use memcmp on raw state objects, you must call + * ZeroRemaining first, as the value of array entries beyond length is + * otherwise undefined. + * + * Usage is of course not limited to chart decoding. Anything that generates + * sentence fragments missing left context could benefit. For example, a + * phrase-based decoder could pre-score phrases, storing ChartState with each + * phrase, even if hypotheses are generated left-to-right. + */ + #ifndef LM_LEFT__ #define LM_LEFT__ @@ -5,6 +42,8 @@ #include "lm/model.hh" #include "lm/return.hh" +#include "util/murmur_hash.hh" + #include <algorithm> namespace lm { @@ -18,23 +57,30 @@ struct Left { } int Compare(const Left &other) const { - if (length != other.length) { - return (int)length - (int)other.length; - } + if (length != other.length) return length < other.length ? -1 : 1; if (pointers[length - 1] > other.pointers[length - 1]) return 1; if (pointers[length - 1] < other.pointers[length - 1]) return -1; return 0; } + bool operator<(const Left &other) const { + if (length != other.length) return length < other.length; + return pointers[length - 1] < other.pointers[length - 1]; + } + void ZeroRemaining() { for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) *i = 0; } - uint64_t pointers[kMaxOrder - 1]; unsigned char length; + uint64_t pointers[kMaxOrder - 1]; }; +inline size_t hash_value(const Left &left) { + return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]); +} + struct ChartState { bool operator==(const ChartState &other) { return (left == other.left) && (right == other.right) && (full == other.full); @@ -48,16 +94,27 @@ struct ChartState { return (int)full - (int)other.full; } + bool operator<(const ChartState &other) const { + return Compare(other) == -1; + } + void ZeroRemaining() { left.ZeroRemaining(); right.ZeroRemaining(); } Left left; - State right; bool full; + State right; }; +inline size_t hash_value(const ChartState &state) { + size_t hashes[2]; + hashes[0] = hash_value(state.left); + hashes[1] = hash_value(state.right); + return util::MurmurHashNative(hashes, sizeof(size_t), state.full); +} + template <class M> class RuleScore { public: explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { @@ -73,8 +130,7 @@ template <class M> class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); - FullScoreReturn ret = model_.FullScore(copy, word, out_.right); - ProcessRet(ret); + ProcessRet(model_.FullScore(copy, word, out_.right)); if (out_.right.length != copy.length + 1) left_done_ = true; } |