#ifndef LM_SEARCH_HASHED__ #define LM_SEARCH_HASHED__ #include "lm/model_type.hh" #include "lm/config.hh" #include "lm/read_arpa.hh" #include "lm/return.hh" #include "lm/weights.hh" #include "util/bit_packing.hh" #include "util/probing_hash_table.hh" #include <algorithm> #include <iostream> #include <vector> namespace util { class FilePiece; } namespace lm { namespace ngram { struct Backing; class ProbingVocabulary; namespace detail { inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL); return ret; } #pragma pack(push) #pragma pack(4) struct ProbEntry { uint64_t key; Prob value; typedef uint64_t Key; typedef Prob Value; uint64_t GetKey() const { return key; } }; #pragma pack(pop) class LongestPointer { public: explicit LongestPointer(const float &to) : to_(&to) {} LongestPointer() : to_(NULL) {} bool Found() const { return to_ != NULL; } float Prob() const { return *to_; } private: const float *to_; }; template <class Value> class HashedSearch { public: typedef uint64_t Node; typedef typename Value::ProbingProxy UnigramPointer; typedef typename Value::ProbingProxy MiddlePointer; typedef ::lm::ngram::detail::LongestPointer LongestPointer; static const ModelType kModelType = Value::kProbingModelType; static const bool kDifferentRest = Value::kDifferentRest; static const unsigned int kVersion = 0; // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { uint64_t ret = Unigram::Size(counts[0]); for (unsigned char n = 1; n < counts.size() - 1; ++n) { ret += Middle::Size(counts[n], config.probing_multiplier); } return ret + Longest::Size(counts.back(), config.probing_multiplier); } uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing); void LoadedBinary(); unsigned char Order() const { return middle_.size() + 2; } typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); } UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const { extend_left = static_cast<uint64_t>(word); next = extend_left; UnigramPointer ret(unigram_.Lookup(word)); independent_left = ret.IndependentLeft(); return ret; } MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { node = extend_pointer; return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value); } MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle_[order_minus_2].Find(node, found)) { independent_left = true; return MiddlePointer(); } extend_pointer = node; MiddlePointer ret(found->value); independent_left = ret.IndependentLeft(); return ret; } LongestPointer LookupLongest(WordIndex word, const Node &node) const { // Sign bit is always on because longest n-grams do not extend left. typename Longest::ConstIterator found; if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer(); return LongestPointer(found->value.prob); } // Generate a node without necessarily checking that it actually exists. // Optionally return false if it's know to not exist. bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { assert(begin != end); node = static_cast<Node>(*begin); for (const WordIndex *i = begin + 1; i < end; ++i) { node = CombineWordHash(node, *i); } return true; } private: // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); class Unigram { public: Unigram() {} Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : unigram_(static_cast<typename Value::Weights*>(start)) #ifdef DEBUG , count_(count) #endif {} static uint64_t Size(uint64_t count) { return (count + 1) * sizeof(typename Value::Weights); // +1 for hallucinate <unk> } const typename Value::Weights &Lookup(WordIndex index) const { #ifdef DEBUG assert(index < count_); #endif return unigram_[index]; } typename Value::Weights &Unknown() { return unigram_[0]; } void LoadedBinary() {} // For building. typename Value::Weights *Raw() { return unigram_; } private: typename Value::Weights *unigram_; #ifdef DEBUG uint64_t count_; #endif }; Unigram unigram_; typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; std::vector<Middle> middle_; typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest; Longest longest_; }; } // namespace detail } // namespace ngram } // namespace lm #endif // LM_SEARCH_HASHED__