summaryrefslogtreecommitdiff
path: root/klm/lm/left.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/left.hh')
-rw-r--r--klm/lm/left.hh251
1 files changed, 251 insertions, 0 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
new file mode 100644
index 00000000..bb3f5539
--- /dev/null
+++ b/klm/lm/left.hh
@@ -0,0 +1,251 @@
+/* Efficient left and right language model state for sentence fragments.
+ * Intended usage:
+ * Store ChartState with every chart entry.
+ * To do a rule application:
+ * 1. Make a ChartState object for your new entry.
+ * 2. Construct RuleScore.
+ * 3. Going from left to right, call Terminal or NonTerminal.
+ * For terminals, just pass the vocab id.
+ * For non-terminals, pass that non-terminal's ChartState.
+ * If your decoder expects scores inclusive of subtree scores (i.e. you
+ * label entries with the highest-scoring path), pass the non-terminal's
+ * score as prob.
+ * If your decoder expects relative scores and will walk the chart later,
+ * pass prob = 0.0.
+ * In other words, the only effect of prob is that it gets added to the
+ * returned log probability.
+ * 4. Call Finish. It returns the log probability.
+ *
+ * There's a couple more details:
+ * Do not pass <s> to Terminal as it is formally not a word in the sentence,
+ * only context. Instead, call BeginSentence. If called, it should be the
+ * first call after RuleScore is constructed (since <s> is always the
+ * leftmost).
+ *
+ * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal.
+ *
+ * Hashing and sorting comparison operators are provided. All state objects
+ * are POD. If you intend to use memcmp on raw state objects, you must call
+ * ZeroRemaining first, as the value of array entries beyond length is
+ * otherwise undefined.
+ *
+ * Usage is of course not limited to chart decoding. Anything that generates
+ * sentence fragments missing left context could benefit. For example, a
+ * phrase-based decoder could pre-score phrases, storing ChartState with each
+ * phrase, even if hypotheses are generated left-to-right.
+ */
+
+#ifndef LM_LEFT__
+#define LM_LEFT__
+
+#include "lm/max_order.hh"
+#include "lm/model.hh"
+#include "lm/return.hh"
+
+#include "util/murmur_hash.hh"
+
+#include <algorithm>
+
+namespace lm {
+namespace ngram {
+
+struct Left {
+ bool operator==(const Left &other) const {
+ return
+ (length == other.length) &&
+ pointers[length - 1] == other.pointers[length - 1];
+ }
+
+ int Compare(const Left &other) const {
+ if (length != other.length) return length < other.length ? -1 : 1;
+ if (pointers[length - 1] > other.pointers[length - 1]) return 1;
+ if (pointers[length - 1] < other.pointers[length - 1]) return -1;
+ return 0;
+ }
+
+ bool operator<(const Left &other) const {
+ if (length != other.length) return length < other.length;
+ return pointers[length - 1] < other.pointers[length - 1];
+ }
+
+ void ZeroRemaining() {
+ for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)
+ *i = 0;
+ }
+
+ unsigned char length;
+ uint64_t pointers[kMaxOrder - 1];
+};
+
+inline size_t hash_value(const Left &left) {
+ return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]);
+}
+
+struct ChartState {
+ bool operator==(const ChartState &other) {
+ return (left == other.left) && (right == other.right) && (full == other.full);
+ }
+
+ int Compare(const ChartState &other) const {
+ int lres = left.Compare(other.left);
+ if (lres) return lres;
+ int rres = right.Compare(other.right);
+ if (rres) return rres;
+ return (int)full - (int)other.full;
+ }
+
+ bool operator<(const ChartState &other) const {
+ return Compare(other) == -1;
+ }
+
+ void ZeroRemaining() {
+ left.ZeroRemaining();
+ right.ZeroRemaining();
+ }
+
+ Left left;
+ bool full;
+ State right;
+};
+
+inline size_t hash_value(const ChartState &state) {
+ size_t hashes[2];
+ hashes[0] = hash_value(state.left);
+ hashes[1] = hash_value(state.right);
+ return util::MurmurHashNative(hashes, sizeof(size_t), state.full);
+}
+
+template <class M> class RuleScore {
+ public:
+ explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) {
+ out.left.length = 0;
+ out.right.length = 0;
+ }
+
+ void BeginSentence() {
+ out_.right = model_.BeginSentenceState();
+ // out_.left is empty.
+ left_done_ = true;
+ }
+
+ void Terminal(WordIndex word) {
+ State copy(out_.right);
+ FullScoreReturn ret(model_.FullScore(copy, word, out_.right));
+ prob_ += ret.prob;
+ if (left_done_) return;
+ if (ret.independent_left) {
+ left_done_ = true;
+ return;
+ }
+ out_.left.pointers[out_.left.length++] = ret.extend_left;
+ if (out_.right.length != copy.length + 1)
+ left_done_ = true;
+ }
+
+ // Faster version of NonTerminal for the case where the rule begins with a non-terminal.
+ void BeginNonTerminal(const ChartState &in, float prob) {
+ prob_ = prob;
+ out_ = in;
+ left_done_ = in.full;
+ }
+
+ void NonTerminal(const ChartState &in, float prob) {
+ prob_ += prob;
+
+ if (!in.left.length) {
+ if (in.full) {
+ for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i;
+ left_done_ = true;
+ out_.right = in.right;
+ }
+ return;
+ }
+
+ if (!out_.right.length) {
+ out_.right = in.right;
+ if (left_done_) return;
+ if (out_.left.length) {
+ left_done_ = true;
+ } else {
+ out_.left = in.left;
+ left_done_ = in.full;
+ }
+ return;
+ }
+
+ float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1];
+ float *back = backoffs, *back2 = backoffs2;
+ unsigned char next_use;
+ FullScoreReturn ret;
+ ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use));
+ if (!next_use) {
+ left_done_ = true;
+ out_.right = in.right;
+ return;
+ }
+ unsigned char extend_length = 2;
+ for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) {
+ ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use));
+ if (!next_use) {
+ left_done_ = true;
+ out_.right = in.right;
+ return;
+ }
+ std::swap(back, back2);
+ }
+
+ if (in.full) {
+ for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
+ left_done_ = true;
+ out_.right = in.right;
+ return;
+ }
+
+ // Right state was minimized, so it's already independent of the new words to the left.
+ if (in.right.length < in.left.length) {
+ out_.right = in.right;
+ return;
+ }
+
+ // Shift exisiting words down.
+ for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) {
+ *(i + in.right.length) = *i;
+ }
+ // Add words from in.right.
+ std::copy(in.right.words, in.right.words + in.right.length, out_.right.words);
+ // Assemble backoff composed on the existing state's backoff followed by the new state's backoff.
+ std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff);
+ std::copy(back, back + next_use, out_.right.backoff + in.right.length);
+ out_.right.length = in.right.length + next_use;
+ }
+
+ float Finish() {
+ // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
+ out_.full = left_done_ || (out_.left.length == model_.Order() - 1);
+ return prob_;
+ }
+
+ private:
+ void ProcessRet(const FullScoreReturn &ret) {
+ prob_ += ret.prob;
+ if (left_done_) return;
+ if (ret.independent_left) {
+ left_done_ = true;
+ return;
+ }
+ out_.left.pointers[out_.left.length++] = ret.extend_left;
+ }
+
+ const M &model_;
+
+ ChartState &out_;
+
+ bool left_done_;
+
+ float prob_;
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_LEFT__