diff options
Diffstat (limited to 'klm/lm/search_hashed.hh')
-rw-r--r-- | klm/lm/search_hashed.hh | 156 |
1 file changed, 156 insertions, 0 deletions
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh new file mode 100644 index 00000000..1ee2b9e9 --- /dev/null +++ b/klm/lm/search_hashed.hh @@ -0,0 +1,156 @@ +#ifndef LM_SEARCH_HASHED__ +#define LM_SEARCH_HASHED__ + +#include "lm/binary_format.hh" +#include "lm/config.hh" +#include "lm/read_arpa.hh" +#include "lm/weights.hh" + +#include "util/key_value_packing.hh" +#include "util/probing_hash_table.hh" +#include "util/sorted_uniform.hh" + +#include <algorithm> +#include <vector> + +namespace util { class FilePiece; } + +namespace lm { +namespace ngram { +namespace detail { + +inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { + uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL); + return ret; +} + +struct HashedSearch { + typedef uint64_t Node; + + class Unigram { + public: + Unigram() {} + + Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {} + + static std::size_t Size(uint64_t count) { + return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk> + } + + const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; } + + ProbBackoff &Unknown() { return unigram_[0]; } + + void LoadedBinary() {} + + // For building. 
+ ProbBackoff *Raw() { return unigram_; } + + private: + ProbBackoff *unigram_; + }; + + Unigram unigram; + + bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + const ProbBackoff &entry = unigram.Lookup(word); + prob = entry.prob; + backoff = entry.backoff; + next = static_cast<Node>(word); + return true; + } +}; + +template <class MiddleT, class LongestT> struct TemplateHashedSearch : public HashedSearch { + typedef MiddleT Middle; + std::vector<Middle> middle; + + typedef LongestT Longest; + Longest longest; + + static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { + std::size_t ret = Unigram::Size(counts[0]); + for (unsigned char n = 1; n < counts.size() - 1; ++n) { + ret += Middle::Size(counts[n], config.probing_multiplier); + } + return ret + Longest::Size(counts.back(), config.probing_multiplier); + } + + uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { + std::size_t allocated = Unigram::Size(counts[0]); + unigram = Unigram(start, allocated); + start += allocated; + for (unsigned int n = 2; n < counts.size(); ++n) { + allocated = Middle::Size(counts[n - 1], config.probing_multiplier); + middle.push_back(Middle(start, allocated)); + start += allocated; + } + allocated = Longest::Size(counts.back(), config.probing_multiplier); + longest = Longest(start, allocated); + start += allocated; + return start; + } + + template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab); + + bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + prob = found->GetValue().prob; + backoff = found->GetValue().backoff; + return true; + } + + bool LookupMiddleNoProb(const Middle &middle, WordIndex word, 
float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + backoff = found->GetValue().backoff; + return true; + } + + bool LookupLongest(WordIndex word, float &prob, Node &node) const { + node = CombineWordHash(node, word); + typename Longest::ConstIterator found; + if (!longest.Find(node, found)) return false; + prob = found->GetValue().prob; + return true; + } + + // Geenrate a node without necessarily checking that it actually exists. + // Optionally return false if it's know to not exist. + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + assert(begin != end); + node = static_cast<Node>(*begin); + for (const WordIndex *i = begin + 1; i < end; ++i) { + node = CombineWordHash(node, *i); + } + return true; + } +}; + +// std::identity is an SGI extension :-( +struct IdentityHash : public std::unary_function<uint64_t, size_t> { + size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); } +}; + +struct ProbingHashedSearch : public TemplateHashedSearch< + util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>, + util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > { + + static const ModelType kModelType = HASH_PROBING; +}; + +struct SortedHashedSearch : public TemplateHashedSearch< + util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, ProbBackoff> >, + util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, Prob> > > { + + static const ModelType kModelType = HASH_SORTED; +}; + +} // namespace detail +} // namespace ngram +} // namespace lm + +#endif // LM_SEARCH_HASHED__ |