diff options
Diffstat (limited to 'klm/lm/search_hashed.hh')
-rw-r--r-- | klm/lm/search_hashed.hh | 229 |
1 files changed, 107 insertions, 122 deletions
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 4352c72d..7e8c1220 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -19,6 +19,7 @@ namespace util { class FilePiece; } namespace lm { namespace ngram { struct Backing; +class ProbingVocabulary; namespace detail { inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { @@ -26,54 +27,48 @@ inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { return ret; } -struct HashedSearch { - typedef uint64_t Node; - - class Unigram { - public: - Unigram() {} - - Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {} - - static std::size_t Size(uint64_t count) { - return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk> - } - - const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; } +#pragma pack(push) +#pragma pack(4) +struct ProbEntry { + uint64_t key; + Prob value; + typedef uint64_t Key; + typedef Prob Value; + uint64_t GetKey() const { + return key; + } +}; - ProbBackoff &Unknown() { return unigram_[0]; } +#pragma pack(pop) - void LoadedBinary() {} +class LongestPointer { + public: + explicit LongestPointer(const float &to) : to_(&to) {} - // For building. - ProbBackoff *Raw() { return unigram_; } + LongestPointer() : to_(NULL) {} - private: - ProbBackoff *unigram_; - }; + bool Found() const { + return to_ != NULL; + } - Unigram unigram; + float Prob() const { + return *to_; + } - void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const { - const ProbBackoff &entry = unigram.Lookup(word); - util::FloatEnc val; - val.f = entry.prob; - ret.independent_left = (val.i & util::kSignBit); - ret.extend_left = static_cast<uint64_t>(word); - val.i |= util::kSignBit; - ret.prob = val.f; - backoff = entry.backoff; - next = static_cast<Node>(word); - } + private: + const float *to_; }; -template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch { +template <class Value> class HashedSearch { public: - typedef MiddleT Middle; + typedef uint64_t Node; - typedef LongestT Longest; - Longest longest; + typedef typename Value::ProbingProxy UnigramPointer; + typedef typename Value::ProbingProxy MiddlePointer; + typedef ::lm::ngram::detail::LongestPointer LongestPointer; + static const ModelType kModelType = Value::kProbingModelType; + static const bool kDifferentRest = Value::kDifferentRest; static const unsigned int kVersion = 0; // TODO: move probing_multiplier here with next binary file format update. @@ -89,64 +84,55 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); - template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing); + void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing); - typedef typename std::vector<Middle>::const_iterator MiddleIter; + void LoadedBinary(); - MiddleIter MiddleBegin() const { return middle_.begin(); } - MiddleIter MiddleEnd() const { return middle_.end(); } + unsigned char Order() const { + return middle_.size() + 2; + } - Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { - util::FloatEnc val; - if (extend_length == 1) { - val.f = unigram.Lookup(static_cast<uint64_t>(extend_pointer)).prob; - } else { - typename Middle::ConstIterator found; - if (!middle_[extend_length - 2].Find(extend_pointer, found)) { - std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; - abort(); - } - val.f = found->value.prob; - } - val.i |= util::kSignBit; - prob = val.f; - return extend_pointer; + typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); } + + UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const { + extend_left = static_cast<uint64_t>(word); + next = extend_left; + UnigramPointer ret(unigram_.Lookup(word)); + independent_left = ret.IndependentLeft(); + return ret; } - bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { - node = CombineWordHash(node, word); +#pragma GCC diagnostic ignored "-Wuninitialized" + MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { + node = extend_pointer; typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - util::FloatEnc enc; - enc.f = found->value.prob; - ret.independent_left = (enc.i & util::kSignBit); - ret.extend_left = node; - enc.i |= util::kSignBit; - ret.prob = enc.f; - backoff = found->value.backoff; - return true; + bool got = middle_[extend_length - 2].Find(extend_pointer, found); + assert(got); + (void)got; + return MiddlePointer(found->value); } - void LoadedBinary(); - - bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { + MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - backoff = found->value.backoff; - return true; + if (!middle_[order_minus_2].Find(node, found)) { + independent_left = true; + return MiddlePointer(); + } + extend_pointer = node; + MiddlePointer ret(found->value); + independent_left = ret.IndependentLeft(); + return ret; } - bool LookupLongest(WordIndex word, float &prob, Node &node) const { + LongestPointer LookupLongest(WordIndex word, const Node &node) const { // Sign bit is always on because longest n-grams do not extend left. - node = CombineWordHash(node, word); typename Longest::ConstIterator found; - if (!longest.Find(node, found)) return false; - prob = found->value.prob; - return true; + if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer(); + return LongestPointer(found->value.prob); } - // Geenrate a node without necessarily checking that it actually exists. + // Generate a node without necessarily checking that it actually exists. // Optionally return false if it's know to not exist. bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { assert(begin != end); @@ -158,55 +144,54 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has } private: - std::vector<Middle> middle_; -}; + // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. + void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); -/* These look like perfect candidates for a template, right? Ancient gcc (4.1 - * on RedHat stale linux) doesn't pack templates correctly. ProbBackoffEntry - * is a multiple of 8 bytes anyway. ProbEntry is 12 bytes so it's set to pack. - */ -struct ProbBackoffEntry { - uint64_t key; - ProbBackoff value; - typedef uint64_t Key; - typedef ProbBackoff Value; - uint64_t GetKey() const { - return key; - } - static ProbBackoffEntry Make(uint64_t key, ProbBackoff value) { - ProbBackoffEntry ret; - ret.key = key; - ret.value = value; - return ret; - } -}; + template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); -#pragma pack(push) -#pragma pack(4) -struct ProbEntry { - uint64_t key; - Prob value; - typedef uint64_t Key; - typedef Prob Value; - uint64_t GetKey() const { - return key; - } - static ProbEntry Make(uint64_t key, Prob value) { - ProbEntry ret; - ret.key = key; - ret.value = value; - return ret; - } -}; + class Unigram { + public: + Unigram() {} -#pragma pack(pop) + Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : + unigram_(static_cast<typename Value::Weights*>(start)) +#ifdef DEBUG + , count_(count) +#endif + {} + + static std::size_t Size(uint64_t count) { + return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk> + } + + const typename Value::Weights &Lookup(WordIndex index) const { +#ifdef DEBUG + assert(index < count_); +#endif + return unigram_[index]; + } + + typename Value::Weights &Unknown() { return unigram_[0]; } + void LoadedBinary() {} -struct ProbingHashedSearch : public TemplateHashedSearch< - util::ProbingHashTable<ProbBackoffEntry, util::IdentityHash>, - util::ProbingHashTable<ProbEntry, util::IdentityHash> > { + // For building. + typename Value::Weights *Raw() { return unigram_; } + + private: + typename Value::Weights *unigram_; +#ifdef DEBUG + uint64_t count_; +#endif + }; + + Unigram unigram_; + + typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; + std::vector<Middle> middle_; - static const ModelType kModelType = HASH_PROBING; + typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest; + Longest longest_; }; } // namespace detail |