diff options
Diffstat (limited to 'klm/lm/search_hashed.hh')
-rw-r--r-- | klm/lm/search_hashed.hh | 43 |
1 files changed, 38 insertions, 5 deletions
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index c62985e4..e289fd11 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -1,15 +1,18 @@ #ifndef LM_SEARCH_HASHED__ #define LM_SEARCH_HASHED__ -#include "lm/binary_format.hh" +#include "lm/model_type.hh" #include "lm/config.hh" #include "lm/read_arpa.hh" +#include "lm/return.hh" #include "lm/weights.hh" +#include "util/bit_packing.hh" #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include <algorithm> +#include <iostream> #include <vector> namespace util { class FilePiece; } @@ -52,9 +55,14 @@ struct HashedSearch { Unigram unigram; - void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const { const ProbBackoff &entry = unigram.Lookup(word); - prob = entry.prob; + util::FloatEnc val; + val.f = entry.prob; + ret.independent_left = (val.i & util::kSignBit); + ret.extend_left = static_cast<uint64_t>(word); + val.i |= util::kSignBit; + ret.prob = val.f; backoff = entry.backoff; next = static_cast<Node>(word); } @@ -67,6 +75,8 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has typedef LongestT Longest; Longest longest; + static const unsigned int kVersion = 0; + // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} @@ -85,11 +95,33 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has const Middle *MiddleBegin() const { return &*middle_.begin(); } const Middle *MiddleEnd() const { return &*middle_.end(); } - bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { + util::FloatEnc val; + if (extend_length == 1) { + val.f = unigram.Lookup(static_cast<uint64_t>(extend_pointer)).prob; + } else { + typename Middle::ConstIterator found; + if (!middle_[extend_length - 2].Find(extend_pointer, found)) { + std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; + abort(); + } + val.f = found->GetValue().prob; + } + val.i |= util::kSignBit; + prob = val.f; + return extend_pointer; + } + + bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; - prob = found->GetValue().prob; + util::FloatEnc enc; + enc.f = found->GetValue().prob; + ret.independent_left = (enc.i & util::kSignBit); + ret.extend_left = node; + enc.i |= util::kSignBit; + ret.prob = enc.f; backoff = found->GetValue().backoff; return true; } @@ -105,6 +137,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has } bool LookupLongest(WordIndex word, float &prob, Node &node) const { + // Sign bit is always on because longest n-grams do not extend left. node = CombineWordHash(node, word); typename Longest::ConstIterator found; if (!longest.Find(node, found)) return false; |