From 3204a77dfd5bff0b5c6d12a272ec939a882c7697 Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 10 Nov 2010 22:45:13 +0000 Subject: forgotten files git-svn-id: https://ws10smt.googlecode.com/svn/trunk@710 ec762483-ff6d-05da-a07a-a48fb63a330f --- klm/lm/search_hashed.hh | 156 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 klm/lm/search_hashed.hh (limited to 'klm/lm/search_hashed.hh') diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh new file mode 100644 index 00000000..1ee2b9e9 --- /dev/null +++ b/klm/lm/search_hashed.hh @@ -0,0 +1,156 @@ +#ifndef LM_SEARCH_HASHED__ +#define LM_SEARCH_HASHED__ + +#include "lm/binary_format.hh" +#include "lm/config.hh" +#include "lm/read_arpa.hh" +#include "lm/weights.hh" + +#include "util/key_value_packing.hh" +#include "util/probing_hash_table.hh" +#include "util/sorted_uniform.hh" + +#include +#include + +namespace util { class FilePiece; } + +namespace lm { +namespace ngram { +namespace detail { + +inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { + uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast(next) * 17894857484156487943ULL); + return ret; +} + +struct HashedSearch { + typedef uint64_t Node; + + class Unigram { + public: + Unigram() {} + + Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast(start)) {} + + static std::size_t Size(uint64_t count) { + return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate + } + + const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; } + + ProbBackoff &Unknown() { return unigram_[0]; } + + void LoadedBinary() {} + + // For building. + ProbBackoff *Raw() { return unigram_; } + + private: + ProbBackoff *unigram_; + }; + + Unigram unigram; + + bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + const ProbBackoff &entry = unigram.Lookup(word); + prob = entry.prob; + backoff = entry.backoff; + next = static_cast(word); + return true; + } +}; + +template struct TemplateHashedSearch : public HashedSearch { + typedef MiddleT Middle; + std::vector middle; + + typedef LongestT Longest; + Longest longest; + + static std::size_t Size(const std::vector &counts, const Config &config) { + std::size_t ret = Unigram::Size(counts[0]); + for (unsigned char n = 1; n < counts.size() - 1; ++n) { + ret += Middle::Size(counts[n], config.probing_multiplier); + } + return ret + Longest::Size(counts.back(), config.probing_multiplier); + } + + uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { + std::size_t allocated = Unigram::Size(counts[0]); + unigram = Unigram(start, allocated); + start += allocated; + for (unsigned int n = 2; n < counts.size(); ++n) { + allocated = Middle::Size(counts[n - 1], config.probing_multiplier); + middle.push_back(Middle(start, allocated)); + start += allocated; + } + allocated = Longest::Size(counts.back(), config.probing_multiplier); + longest = Longest(start, allocated); + start += allocated; + return start; + } + + template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab); + + bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + prob = found->GetValue().prob; + backoff = found->GetValue().backoff; + return true; + } + + bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + backoff = found->GetValue().backoff; + return true; + } + + bool LookupLongest(WordIndex word, float &prob, Node &node) const { + node = CombineWordHash(node, word); + typename Longest::ConstIterator found; + if (!longest.Find(node, found)) return false; + prob = found->GetValue().prob; + return true; + } + + // Geenrate a node without necessarily checking that it actually exists. + // Optionally return false if it's know to not exist. + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + assert(begin != end); + node = static_cast(*begin); + for (const WordIndex *i = begin + 1; i < end; ++i) { + node = CombineWordHash(node, *i); + } + return true; + } +}; + +// std::identity is an SGI extension :-( +struct IdentityHash : public std::unary_function { + size_t operator()(uint64_t arg) const { return static_cast(arg); } +}; + +struct ProbingHashedSearch : public TemplateHashedSearch< + util::ProbingHashTable, IdentityHash>, + util::ProbingHashTable, IdentityHash> > { + + static const ModelType kModelType = HASH_PROBING; +}; + +struct SortedHashedSearch : public TemplateHashedSearch< + util::SortedUniformMap >, + util::SortedUniformMap > > { + + static const ModelType kModelType = HASH_SORTED; +}; + +} // namespace detail +} // namespace ngram +} // namespace lm + +#endif // LM_SEARCH_HASHED__ -- cgit v1.2.3