summary | refs | log | tree | commit | diff
path: root/klm/lm/search_hashed.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_hashed.hh')
-rw-r--r-- klm/lm/search_hashed.hh 43
1 files changed, 38 insertions, 5 deletions
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index c62985e4..e289fd11 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -1,15 +1,18 @@
#ifndef LM_SEARCH_HASHED__
#define LM_SEARCH_HASHED__
-#include "lm/binary_format.hh"
+#include "lm/model_type.hh"
#include "lm/config.hh"
#include "lm/read_arpa.hh"
+#include "lm/return.hh"
#include "lm/weights.hh"
+#include "util/bit_packing.hh"
#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include <algorithm>
+#include <iostream>
#include <vector>
namespace util { class FilePiece; }
@@ -52,9 +55,14 @@ struct HashedSearch {
Unigram unigram;
- void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
+ void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const { // Looks up word's unigram entry; writes ret.prob, ret.independent_left, ret.extend_left, backoff and next.
const ProbBackoff &entry = unigram.Lookup(word);
- prob = entry.prob;
+ util::FloatEnc val; // exposes one value as float (.f) and integer (.i) -- presumably a union; confirm in util/bit_packing.hh
+ val.f = entry.prob;
+ ret.independent_left = (val.i & util::kSignBit); // taken from the kSignBit bit of the stored prob's bit pattern
+ ret.extend_left = static_cast<uint64_t>(word); // the word index itself is recorded as extend_left
+ val.i |= util::kSignBit; // force the kSignBit bit on before the value is handed back
+ ret.prob = val.f;
backoff = entry.backoff;
next = static_cast<Node>(word); // the resulting node is the word index converted to Node
}
@@ -67,6 +75,8 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
typedef LongestT Longest;
Longest longest;
+ static const unsigned int kVersion = 0; // NOTE(review): presumably a format version for this search type -- its consumer is not visible here; confirm
+
// TODO: move probing_multiplier here with next binary file format update.
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} // no-op: all three parameters are unnamed and the body is empty
@@ -85,11 +95,33 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
const Middle *MiddleBegin() const { return &*middle_.begin(); } // address of the element middle_.begin() refers to
const Middle *MiddleEnd() const { return &*middle_.end(); } // NOTE(review): forms the end pointer by dereferencing end() -- confirm this is valid for middle_'s container type
- bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
+ Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { // Reads the prob stored for extend_pointer at n-gram length extend_length into prob; returns extend_pointer as the Node.
+ util::FloatEnc val; // exposes one value as float (.f) and integer (.i) -- presumably a union; confirm in util/bit_packing.hh
+ if (extend_length == 1) { // length 1 is served from the unigram table
+ val.f = unigram.Lookup(static_cast<uint64_t>(extend_pointer)).prob;
+ } else { // other lengths index middle_ at extend_length - 2, so length 2 maps to middle_[0]; NOTE(review): extend_length == 0 is not handled -- confirm callers never pass it
+ typename Middle::ConstIterator found;
+ if (!middle_[extend_length - 2].Find(extend_pointer, found)) { // a miss is treated as fatal
+ std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl;
+ abort(); // NOTE(review): abort() is declared in <cstdlib>, which is not among the visible includes -- confirm it arrives transitively
+ }
+ val.f = found->GetValue().prob;
+ }
+ val.i |= util::kSignBit; // the kSignBit bit is set unconditionally before the value is returned
+ prob = val.f;
+ return extend_pointer; // the pointer is returned unchanged, converted to Node
+ }
+
+ bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { // Extends node by word and looks it up in middle; returns false on a miss, otherwise fills ret and backoff and returns true.
node = CombineWordHash(node, word); // node is overwritten before the lookup, so it changes even when false is returned
typename Middle::ConstIterator found;
if (!middle.Find(node, found)) return false; // ret and backoff are left untouched on a miss
- prob = found->GetValue().prob;
+ util::FloatEnc enc; // exposes one value as float (.f) and integer (.i) -- presumably a union; confirm in util/bit_packing.hh
+ enc.f = found->GetValue().prob;
+ ret.independent_left = (enc.i & util::kSignBit); // taken from the kSignBit bit of the stored prob's bit pattern
+ ret.extend_left = node; // the combined node value is recorded as extend_left
+ enc.i |= util::kSignBit; // force the kSignBit bit on before the value is handed back
+ ret.prob = enc.f;
backoff = found->GetValue().backoff;
return true;
}
@@ -105,6 +137,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
}
bool LookupLongest(WordIndex word, float &prob, Node &node) const {
+ // Sign bit is always on because longest n-grams do not extend left.
node = CombineWordHash(node, word);
typename Longest::ConstIterator found;
if (!longest.Find(node, found)) return false;