diff options
Diffstat (limited to 'klm/lm/left.hh')
-rw-r--r-- | klm/lm/left.hh | 251 |
1 files changed, 251 insertions, 0 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh new file mode 100644 index 00000000..bb3f5539 --- /dev/null +++ b/klm/lm/left.hh @@ -0,0 +1,251 @@ +/* Efficient left and right language model state for sentence fragments. + * Intended usage: + * Store ChartState with every chart entry. + * To do a rule application: + * 1. Make a ChartState object for your new entry. + * 2. Construct RuleScore. + * 3. Going from left to right, call Terminal or NonTerminal. + * For terminals, just pass the vocab id. + * For non-terminals, pass that non-terminal's ChartState. + * If your decoder expects scores inclusive of subtree scores (i.e. you + * label entries with the highest-scoring path), pass the non-terminal's + * score as prob. + * If your decoder expects relative scores and will walk the chart later, + * pass prob = 0.0. + * In other words, the only effect of prob is that it gets added to the + * returned log probability. + * 4. Call Finish. It returns the log probability. + * + * There's a couple more details: + * Do not pass <s> to Terminal as it is formally not a word in the sentence, + * only context. Instead, call BeginSentence. If called, it should be the + * first call after RuleScore is constructed (since <s> is always the + * leftmost). + * + * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal. + * + * Hashing and sorting comparison operators are provided. All state objects + * are POD. If you intend to use memcmp on raw state objects, you must call + * ZeroRemaining first, as the value of array entries beyond length is + * otherwise undefined. + * + * Usage is of course not limited to chart decoding. Anything that generates + * sentence fragments missing left context could benefit. For example, a + * phrase-based decoder could pre-score phrases, storing ChartState with each + * phrase, even if hypotheses are generated left-to-right. + */ + +#ifndef LM_LEFT__ +#define LM_LEFT__ + +#include "lm/max_order.hh" +#include "lm/model.hh" +#include "lm/return.hh" + +#include "util/murmur_hash.hh" + +#include <algorithm> + +namespace lm { +namespace ngram { + +struct Left { + bool operator==(const Left &other) const { + return + (length == other.length) && + pointers[length - 1] == other.pointers[length - 1]; + } + + int Compare(const Left &other) const { + if (length != other.length) return length < other.length ? -1 : 1; + if (pointers[length - 1] > other.pointers[length - 1]) return 1; + if (pointers[length - 1] < other.pointers[length - 1]) return -1; + return 0; + } + + bool operator<(const Left &other) const { + if (length != other.length) return length < other.length; + return pointers[length - 1] < other.pointers[length - 1]; + } + + void ZeroRemaining() { + for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) + *i = 0; + } + + unsigned char length; + uint64_t pointers[kMaxOrder - 1]; +}; + +inline size_t hash_value(const Left &left) { + return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]); +} + +struct ChartState { + bool operator==(const ChartState &other) { + return (left == other.left) && (right == other.right) && (full == other.full); + } + + int Compare(const ChartState &other) const { + int lres = left.Compare(other.left); + if (lres) return lres; + int rres = right.Compare(other.right); + if (rres) return rres; + return (int)full - (int)other.full; + } + + bool operator<(const ChartState &other) const { + return Compare(other) == -1; + } + + void ZeroRemaining() { + left.ZeroRemaining(); + right.ZeroRemaining(); + } + + Left left; + bool full; + State right; +}; + +inline size_t hash_value(const ChartState &state) { + size_t hashes[2]; + hashes[0] = hash_value(state.left); + hashes[1] = hash_value(state.right); + return util::MurmurHashNative(hashes, sizeof(size_t), state.full); +} + +template <class M> class RuleScore { + public: + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { + out.left.length = 0; + out.right.length = 0; + } + + void BeginSentence() { + out_.right = model_.BeginSentenceState(); + // out_.left is empty. + left_done_ = true; + } + + void Terminal(WordIndex word) { + State copy(out_.right); + FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + out_.left.pointers[out_.left.length++] = ret.extend_left; + if (out_.right.length != copy.length + 1) + left_done_ = true; + } + + // Faster version of NonTerminal for the case where the rule begins with a non-terminal. + void BeginNonTerminal(const ChartState &in, float prob) { + prob_ = prob; + out_ = in; + left_done_ = in.full; + } + + void NonTerminal(const ChartState &in, float prob) { + prob_ += prob; + + if (!in.left.length) { + if (in.full) { + for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + } + return; + } + + if (!out_.right.length) { + out_.right = in.right; + if (left_done_) return; + if (out_.left.length) { + left_done_ = true; + } else { + out_.left = in.left; + left_done_ = in.full; + } + return; + } + + float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; + float *back = backoffs, *back2 = backoffs2; + unsigned char next_use; + FullScoreReturn ret; + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + unsigned char extend_length = 2; + for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + std::swap(back, back2); + } + + if (in.full) { + for (const float *i = back; i != back + next_use; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + return; + } + + // Right state was minimized, so it's already independent of the new words to the left. + if (in.right.length < in.left.length) { + out_.right = in.right; + return; + } + + // Shift exisiting words down. + for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + *(i + in.right.length) = *i; + } + // Add words from in.right. + std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); + std::copy(back, back + next_use, out_.right.backoff + in.right.length); + out_.right.length = in.right.length + next_use; + } + + float Finish() { + // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. + out_.full = left_done_ || (out_.left.length == model_.Order() - 1); + return prob_; + } + + private: + void ProcessRet(const FullScoreReturn &ret) { + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + out_.left.pointers[out_.left.length++] = ret.extend_left; + } + + const M &model_; + + ChartState &out_; + + bool left_done_; + + float prob_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_LEFT__ |