#ifndef LM_SEARCH_HASHED__ #define LM_SEARCH_HASHED__ #include "lm/model_type.hh" #include "lm/config.hh" #include "lm/read_arpa.hh" #include "lm/return.hh" #include "lm/weights.hh" #include "util/bit_packing.hh" #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include <algorithm> #include <iostream> #include <vector> namespace util { class FilePiece; } namespace lm { namespace ngram { struct Backing; namespace detail { inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL); return ret; } struct HashedSearch { typedef uint64_t Node; class Unigram { public: Unigram() {} Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {} static std::size_t Size(uint64_t count) { return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk> } const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; } ProbBackoff &Unknown() { return unigram_[0]; } void LoadedBinary() {} // For building. ProbBackoff *Raw() { return unigram_; } private: ProbBackoff *unigram_; }; Unigram unigram; void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const { const ProbBackoff &entry = unigram.Lookup(word); util::FloatEnc val; val.f = entry.prob; ret.independent_left = (val.i & util::kSignBit); ret.extend_left = static_cast<uint64_t>(word); val.i |= util::kSignBit; ret.prob = val.f; backoff = entry.backoff; next = static_cast<Node>(word); } }; template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch { public: typedef MiddleT Middle; typedef LongestT Longest; Longest longest; static const unsigned int kVersion = 0; // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { std::size_t ret = Unigram::Size(counts[0]); for (unsigned char n = 1; n < counts.size() - 1; ++n) { ret += Middle::Size(counts[n], config.probing_multiplier); } return ret + Longest::Size(counts.back(), config.probing_multiplier); } uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing); const Middle *MiddleBegin() const { return &*middle_.begin(); } const Middle *MiddleEnd() const { return &*middle_.end(); } Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { util::FloatEnc val; if (extend_length == 1) { val.f = unigram.Lookup(static_cast<uint64_t>(extend_pointer)).prob; } else { typename Middle::ConstIterator found; if (!middle_[extend_length - 2].Find(extend_pointer, found)) { std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; abort(); } val.f = found->GetValue().prob; } val.i |= util::kSignBit; prob = val.f; return extend_pointer; } bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; util::FloatEnc enc; enc.f = found->GetValue().prob; ret.independent_left = (enc.i & util::kSignBit); ret.extend_left = node; enc.i |= util::kSignBit; ret.prob = enc.f; backoff = found->GetValue().backoff; return true; } void LoadedBinary(); bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; backoff = found->GetValue().backoff; return true; } bool LookupLongest(WordIndex word, float &prob, Node &node) const { // Sign bit is always on because longest n-grams do not extend left. node = CombineWordHash(node, word); typename Longest::ConstIterator found; if (!longest.Find(node, found)) return false; prob = found->GetValue().prob; return true; } // Geenrate a node without necessarily checking that it actually exists. // Optionally return false if it's know to not exist. bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { assert(begin != end); node = static_cast<Node>(*begin); for (const WordIndex *i = begin + 1; i < end; ++i) { node = CombineWordHash(node, *i); } return true; } private: std::vector<Middle> middle_; }; // std::identity is an SGI extension :-( struct IdentityHash : public std::unary_function<uint64_t, size_t> { size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); } }; struct ProbingHashedSearch : public TemplateHashedSearch< util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>, util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > { static const ModelType kModelType = HASH_PROBING; }; } // namespace detail } // namespace ngram } // namespace lm #endif // LM_SEARCH_HASHED__