summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/left.hh70
-rw-r--r--klm/lm/model.cc5
-rw-r--r--klm/lm/model.hh25
3 files changed, 76 insertions, 24 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index 837be765..effa0560 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -1,3 +1,40 @@
+/* Efficient left and right language model state for sentence fragments.
+ * Intended usage:
+ * Store ChartState with every chart entry.
+ * To do a rule application:
+ * 1. Make a ChartState object for your new entry.
+ * 2. Construct RuleScore.
+ * 3. Going from left to right, call Terminal or NonTerminal.
+ * For terminals, just pass the vocab id.
+ * For non-terminals, pass that non-terminal's ChartState.
+ * If your decoder expects scores inclusive of subtree scores (i.e. you
+ * label entries with the highest-scoring path), pass the non-terminal's
+ * score as prob.
+ * If your decoder expects relative scores and will walk the chart later,
+ * pass prob = 0.0.
+ * In other words, the only effect of prob is that it gets added to the
+ * returned log probability.
+ * 4. Call Finish. It returns the log probability.
+ *
+ * There's a couple more details:
+ * Do not pass <s> to Terminal as it is formally not a word in the sentence,
+ * only context. Instead, call BeginSentence. If called, it should be the
+ * first call after RuleScore is constructed (since <s> is always the
+ * leftmost).
+ *
+ * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal.
+ *
+ * Hashing and sorting comparison operators are provided. All state objects
+ * are POD. If you intend to use memcmp on raw state objects, you must call
+ * ZeroRemaining first, as the value of array entries beyond length is
+ * otherwise undefined.
+ *
+ * Usage is of course not limited to chart decoding. Anything that generates
+ * sentence fragments missing left context could benefit. For example, a
+ * phrase-based decoder could pre-score phrases, storing ChartState with each
+ * phrase, even if hypotheses are generated left-to-right.
+ */
+
#ifndef LM_LEFT__
#define LM_LEFT__
@@ -5,6 +42,8 @@
#include "lm/model.hh"
#include "lm/return.hh"
+#include "util/murmur_hash.hh"
+
#include <algorithm>
namespace lm {
@@ -18,23 +57,30 @@ struct Left {
}
int Compare(const Left &other) const {
- if (length != other.length) {
- return (int)length - (int)other.length;
- }
+ if (length != other.length) return length < other.length ? -1 : 1;
if (pointers[length - 1] > other.pointers[length - 1]) return 1;
if (pointers[length - 1] < other.pointers[length - 1]) return -1;
return 0;
}
+ bool operator<(const Left &other) const {
+ if (length != other.length) return length < other.length;
+ return pointers[length - 1] < other.pointers[length - 1];
+ }
+
void ZeroRemaining() {
for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)
*i = 0;
}
- uint64_t pointers[kMaxOrder - 1];
unsigned char length;
+ uint64_t pointers[kMaxOrder - 1];
};
+inline size_t hash_value(const Left &left) {
+ return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]);
+}
+
struct ChartState {
bool operator==(const ChartState &other) {
return (left == other.left) && (right == other.right) && (full == other.full);
@@ -48,16 +94,27 @@ struct ChartState {
return (int)full - (int)other.full;
}
+ bool operator<(const ChartState &other) const {
+ return Compare(other) == -1;
+ }
+
void ZeroRemaining() {
left.ZeroRemaining();
right.ZeroRemaining();
}
Left left;
- State right;
bool full;
+ State right;
};
+inline size_t hash_value(const ChartState &state) {
+ size_t hashes[2];
+ hashes[0] = hash_value(state.left);
+ hashes[1] = hash_value(state.right);
+ return util::MurmurHashNative(hashes, sizeof(size_t), state.full);
+}
+
template <class M> class RuleScore {
public:
explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) {
@@ -73,8 +130,7 @@ template <class M> class RuleScore {
void Terminal(WordIndex word) {
State copy(out_.right);
- FullScoreReturn ret = model_.FullScore(copy, word, out_.right);
- ProcessRet(ret);
+ ProcessRet(model_.FullScore(copy, word, out_.right));
if (out_.right.length != copy.length + 1) left_done_ = true;
}
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index ca581d8a..25f1ab7c 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -14,11 +14,6 @@
namespace lm {
namespace ngram {
-
-size_t hash_value(const State &state) {
- return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length);
-}
-
namespace detail {
template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index fe91af2e..c278acd6 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -12,6 +12,8 @@
#include "lm/vocab.hh"
#include "lm/weights.hh"
+#include "util/murmur_hash.hh"
+
#include <algorithm>
#include <vector>
@@ -28,21 +30,18 @@ class State {
public:
bool operator==(const State &other) const {
if (length != other.length) return false;
- const WordIndex *end = words + length;
- for (const WordIndex *first = words, *second = other.words;
- first != end; ++first, ++second) {
- if (*first != *second) return false;
- }
- // If the histories are equal, so are the backoffs.
- return true;
+ return !memcmp(words, other.words, length * sizeof(WordIndex));
}
// Three way comparison function.
int Compare(const State &other) const {
- if (length == other.length) {
- return memcmp(words, other.words, length * sizeof(WordIndex));
- }
- return (length < other.length) ? -1 : 1;
+ if (length != other.length) return length < other.length ? -1 : 1;
+ return memcmp(words, other.words, length * sizeof(WordIndex));
+ }
+
+ bool operator<(const State &other) const {
+ if (length != other.length) return length < other.length;
+ return memcmp(words, other.words, length * sizeof(WordIndex)) < 0;
}
// Call this before using raw memcmp.
@@ -62,7 +61,9 @@ class State {
unsigned char length;
};
-size_t hash_value(const State &state);
+inline size_t hash_value(const State &state) {
+ return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length);
+}
namespace detail {