summaryrefslogtreecommitdiff
path: root/klm/lm/left.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/left.hh')
-rw-r--r--klm/lm/left.hh70
1 files changed, 63 insertions, 7 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index 837be765..effa0560 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -1,3 +1,40 @@
+/* Efficient left and right language model state for sentence fragments.
+ * Intended usage:
+ * Store ChartState with every chart entry.
+ * To do a rule application:
+ * 1. Make a ChartState object for your new entry.
+ * 2. Construct RuleScore.
+ * 3. Going from left to right, call Terminal or NonTerminal.
+ * For terminals, just pass the vocab id.
+ * For non-terminals, pass that non-terminal's ChartState.
+ * If your decoder expects scores inclusive of subtree scores (i.e. you
+ * label entries with the highest-scoring path), pass the non-terminal's
+ * score as prob.
+ * If your decoder expects relative scores and will walk the chart later,
+ * pass prob = 0.0.
+ * In other words, the only effect of prob is that it gets added to the
+ * returned log probability.
+ * 4. Call Finish. It returns the log probability.
+ *
+ * There's a couple more details:
+ * Do not pass <s> to Terminal as it is formally not a word in the sentence,
+ * only context. Instead, call BeginSentence. If called, it should be the
+ * first call after RuleScore is constructed (since <s> is always the
+ * leftmost).
+ *
+ * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal.
+ *
+ * Hashing and sorting comparison operators are provided. All state objects
+ * are POD. If you intend to use memcmp on raw state objects, you must call
+ * ZeroRemaining first, as the value of array entries beyond length is
+ * otherwise undefined.
+ *
+ * Usage is of course not limited to chart decoding. Anything that generates
+ * sentence fragments missing left context could benefit. For example, a
+ * phrase-based decoder could pre-score phrases, storing ChartState with each
+ * phrase, even if hypotheses are generated left-to-right.
+ */
+
#ifndef LM_LEFT__
#define LM_LEFT__
@@ -5,6 +42,8 @@
#include "lm/model.hh"
#include "lm/return.hh"
+#include "util/murmur_hash.hh"
+
#include <algorithm>
namespace lm {
@@ -18,23 +57,30 @@ struct Left {
}
int Compare(const Left &other) const {
- if (length != other.length) {
- return (int)length - (int)other.length;
- }
+ if (length != other.length) return length < other.length ? -1 : 1;
if (pointers[length - 1] > other.pointers[length - 1]) return 1;
if (pointers[length - 1] < other.pointers[length - 1]) return -1;
return 0;
}
+ bool operator<(const Left &other) const {
+ if (length != other.length) return length < other.length;
+ return pointers[length - 1] < other.pointers[length - 1];
+ }
+
void ZeroRemaining() {
for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)
*i = 0;
}
- uint64_t pointers[kMaxOrder - 1];
unsigned char length;
+ uint64_t pointers[kMaxOrder - 1];
};
+inline size_t hash_value(const Left &left) {
+ return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]);
+}
+
struct ChartState {
bool operator==(const ChartState &other) {
return (left == other.left) && (right == other.right) && (full == other.full);
@@ -48,16 +94,27 @@ struct ChartState {
return (int)full - (int)other.full;
}
+ bool operator<(const ChartState &other) const {
+ return Compare(other) == -1;
+ }
+
void ZeroRemaining() {
left.ZeroRemaining();
right.ZeroRemaining();
}
Left left;
- State right;
bool full;
+ State right;
};
+inline size_t hash_value(const ChartState &state) {
+ size_t hashes[2];
+ hashes[0] = hash_value(state.left);
+ hashes[1] = hash_value(state.right);
+ return util::MurmurHashNative(hashes, sizeof(size_t), state.full);
+}
+
template <class M> class RuleScore {
public:
explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) {
@@ -73,8 +130,7 @@ template <class M> class RuleScore {
void Terminal(WordIndex word) {
State copy(out_.right);
- FullScoreReturn ret = model_.FullScore(copy, word, out_.right);
- ProcessRet(ret);
+ ProcessRet(model_.FullScore(copy, word, out_.right));
if (out_.right.length != copy.length + 1) left_done_ = true;
}