summaryrefslogtreecommitdiff
path: root/klm/lm/search_hashed.hh
diff options
context:
space:
mode:
authorKenneth Heafield <kenlm@kheafield.com>2011-06-26 18:40:15 -0400
committerKenneth Heafield <kenlm@kheafield.com>2011-06-26 18:40:15 -0400
commit205893513c8343fdc55789e427fab4c8b536dc12 (patch)
tree67fdaa819488e231b5d70b2227527510571f2108 /klm/lm/search_hashed.hh
parent9366fc1ce04385290722bd703933bf0c1c166671 (diff)
Quantization
Diffstat (limited to 'klm/lm/search_hashed.hh')
-rw-r--r--klm/lm/search_hashed.hh122
1 files changed, 55 insertions, 67 deletions
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 6dc11fb3..f3acdefc 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -8,7 +8,6 @@
#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
-#include "util/sorted_uniform.hh"
#include <algorithm>
#include <vector>
@@ -62,73 +61,71 @@ struct HashedSearch {
}
};
-template <class MiddleT, class LongestT> struct TemplateHashedSearch : public HashedSearch {
- typedef MiddleT Middle;
- std::vector<Middle> middle;
+template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch {
+ public:
+ typedef MiddleT Middle;
- typedef LongestT Longest;
- Longest longest;
+ typedef LongestT Longest;
+ Longest longest;
- static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t ret = Unigram::Size(counts[0]);
- for (unsigned char n = 1; n < counts.size() - 1; ++n) {
- ret += Middle::Size(counts[n], config.probing_multiplier);
- }
- return ret + Longest::Size(counts.back(), config.probing_multiplier);
- }
+ // TODO: move probing_multiplier here with next binary file format update.
+ static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
- uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t allocated = Unigram::Size(counts[0]);
- unigram = Unigram(start, allocated);
- start += allocated;
- for (unsigned int n = 2; n < counts.size(); ++n) {
- allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
- middle.push_back(Middle(start, allocated));
- start += allocated;
+ static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
+ std::size_t ret = Unigram::Size(counts[0]);
+ for (unsigned char n = 1; n < counts.size() - 1; ++n) {
+ ret += Middle::Size(counts[n], config.probing_multiplier);
+ }
+ return ret + Longest::Size(counts.back(), config.probing_multiplier);
}
- allocated = Longest::Size(counts.back(), config.probing_multiplier);
- longest = Longest(start, allocated);
- start += allocated;
- return start;
- }
- template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
+ uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
- node = CombineWordHash(node, word);
- typename Middle::ConstIterator found;
- if (!middle.Find(node, found)) return false;
- prob = found->GetValue().prob;
- backoff = found->GetValue().backoff;
- return true;
- }
+ template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
- bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
- node = CombineWordHash(node, word);
- typename Middle::ConstIterator found;
- if (!middle.Find(node, found)) return false;
- backoff = found->GetValue().backoff;
- return true;
- }
+ const Middle *MiddleBegin() const { return &*middle_.begin(); }
+ const Middle *MiddleEnd() const { return &*middle_.end(); }
- bool LookupLongest(WordIndex word, float &prob, Node &node) const {
- node = CombineWordHash(node, word);
- typename Longest::ConstIterator found;
- if (!longest.Find(node, found)) return false;
- prob = found->GetValue().prob;
- return true;
- }
+ bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
+ node = CombineWordHash(node, word);
+ typename Middle::ConstIterator found;
+ if (!middle.Find(node, found)) return false;
+ prob = found->GetValue().prob;
+ backoff = found->GetValue().backoff;
+ return true;
+ }
+
+ void LoadedBinary();
- // Geenrate a node without necessarily checking that it actually exists.
- // Optionally return false if it's know to not exist.
- bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
- assert(begin != end);
- node = static_cast<Node>(*begin);
- for (const WordIndex *i = begin + 1; i < end; ++i) {
- node = CombineWordHash(node, *i);
+ bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
+ node = CombineWordHash(node, word);
+ typename Middle::ConstIterator found;
+ if (!middle.Find(node, found)) return false;
+ backoff = found->GetValue().backoff;
+ return true;
}
- return true;
- }
+
+ bool LookupLongest(WordIndex word, float &prob, Node &node) const {
+ node = CombineWordHash(node, word);
+ typename Longest::ConstIterator found;
+ if (!longest.Find(node, found)) return false;
+ prob = found->GetValue().prob;
+ return true;
+ }
+
+ // Geenrate a node without necessarily checking that it actually exists.
+ // Optionally return false if it's know to not exist.
+ bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
+ assert(begin != end);
+ node = static_cast<Node>(*begin);
+ for (const WordIndex *i = begin + 1; i < end; ++i) {
+ node = CombineWordHash(node, *i);
+ }
+ return true;
+ }
+
+ private:
+ std::vector<Middle> middle_;
};
// std::identity is an SGI extension :-(
@@ -143,15 +140,6 @@ struct ProbingHashedSearch : public TemplateHashedSearch<
static const ModelType kModelType = HASH_PROBING;
};
-struct SortedHashedSearch : public TemplateHashedSearch<
- util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, ProbBackoff> >,
- util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, Prob> > > {
-
- SortedHashedSearch();
-
- static const ModelType kModelType = HASH_SORTED;
-};
-
} // namespace detail
} // namespace ngram
} // namespace lm