summary | refs | log | tree | commit | diff
path: root/klm/lm/model.hh
diff options
context:
space:
mode:
author Patrick Simianer <p@simianer.de> 2011-10-20 02:31:25 +0200
committer Patrick Simianer <p@simianer.de> 2011-10-20 02:31:25 +0200
commit 92e48b652530d2d2bb4f2694501f95a60d727cb2 (patch)
tree b484bd0c4216525690de8b14fb654c9581a300c2 /klm/lm/model.hh
parent 0e70073cec6cdcafaf60d4fbcbd1adf82ae21c8e (diff)
parent 082b6c77e0703ccd1c85947828c33d4b0eef20f0 (diff)
finalized merge
Diffstat (limited to 'klm/lm/model.hh')
-rw-r--r-- klm/lm/model.hh 62
1 files changed, 42 insertions, 20 deletions
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 21595321..c278acd6 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -12,6 +12,8 @@
#include "lm/vocab.hh"
#include "lm/weights.hh"
+#include "util/murmur_hash.hh"
+
#include <algorithm>
#include <vector>
@@ -27,42 +29,41 @@ namespace ngram {
class State {
public:
bool operator==(const State &other) const {
- if (valid_length_ != other.valid_length_) return false;
- const WordIndex *end = history_ + valid_length_;
- for (const WordIndex *first = history_, *second = other.history_;
- first != end; ++first, ++second) {
- if (*first != *second) return false;
- }
- // If the histories are equal, so are the backoffs.
- return true;
+ if (length != other.length) return false;
+ return !memcmp(words, other.words, length * sizeof(WordIndex));
}
// Three way comparison function.
int Compare(const State &other) const {
- if (valid_length_ == other.valid_length_) {
- return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex));
- }
- return (valid_length_ < other.valid_length_) ? -1 : 1;
+ if (length != other.length) return length < other.length ? -1 : 1;
+ return memcmp(words, other.words, length * sizeof(WordIndex));
+ }
+
+ bool operator<(const State &other) const {
+ if (length != other.length) return length < other.length;
+ return memcmp(words, other.words, length * sizeof(WordIndex)) < 0;
}
// Call this before using raw memcmp.
void ZeroRemaining() {
- for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) {
- history_[i] = 0;
- backoff_[i] = 0.0;
+ for (unsigned char i = length; i < kMaxOrder - 1; ++i) {
+ words[i] = 0;
+ backoff[i] = 0.0;
}
}
- unsigned char ValidLength() const { return valid_length_; }
+ unsigned char Length() const { return length; }
// You shouldn't need to touch anything below this line, but the members are public so State will qualify as a POD.
// This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
- WordIndex history_[kMaxOrder - 1];
- float backoff_[kMaxOrder - 1];
- unsigned char valid_length_;
+ WordIndex words[kMaxOrder - 1];
+ float backoff[kMaxOrder - 1];
+ unsigned char length;
};
-size_t hash_value(const State &state);
+inline size_t hash_value(const State &state) {
+ return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length);
+}
namespace detail {
@@ -75,6 +76,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
// This is the model type returned by RecognizeBinary.
static const ModelType kModelType;
+ static const unsigned int kVersion = Search::kVersion;
+
/* Get the size of memory that will be mapped given ngram counts. This
* does not include small non-mapped control structures, such as this class
* itself.
@@ -114,6 +117,25 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
*/
void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;
+ /* More efficient version of FullScore where a partial n-gram has already
+ * been scored.
+ * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if
+ * the n-gram does not end up extending further left, then 0 is returned.
+ */
+ FullScoreReturn ExtendLeft(
+ // Additional context in reverse order. (The original comment breaks off here after "This will update add_rend to"; note that add_rend is passed by value.)
+ const WordIndex *add_rbegin, const WordIndex *add_rend,
+ // Backoff weights to use.
+ const float *backoff_in,
+ // extend_left returned by a previous query.
+ uint64_t extend_pointer,
+ // Length of n-gram that the pointer corresponds to.
+ unsigned char extend_length,
+ // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)]
+ float *backoff_out,
+ // Amount of additional context that should be considered by the next call.
+ unsigned char &next_use) const;
+
private:
friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);