summaryrefslogtreecommitdiff
path: root/klm/lm/search_hashed.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_hashed.hh')
-rw-r--r--klm/lm/search_hashed.hh63
1 files changed, 50 insertions, 13 deletions
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index e289fd11..4352c72d 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -8,7 +8,6 @@
#include "lm/weights.hh"
#include "util/bit_packing.hh"
-#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include <algorithm>
@@ -92,8 +91,10 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
- const Middle *MiddleBegin() const { return &*middle_.begin(); }
- const Middle *MiddleEnd() const { return &*middle_.end(); }
+ typedef typename std::vector<Middle>::const_iterator MiddleIter;
+
+ MiddleIter MiddleBegin() const { return middle_.begin(); }
+ MiddleIter MiddleEnd() const { return middle_.end(); }
Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
util::FloatEnc val;
@@ -105,7 +106,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl;
abort();
}
- val.f = found->GetValue().prob;
+ val.f = found->value.prob;
}
val.i |= util::kSignBit;
prob = val.f;
@@ -117,12 +118,12 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
typename Middle::ConstIterator found;
if (!middle.Find(node, found)) return false;
util::FloatEnc enc;
- enc.f = found->GetValue().prob;
+ enc.f = found->value.prob;
ret.independent_left = (enc.i & util::kSignBit);
ret.extend_left = node;
enc.i |= util::kSignBit;
ret.prob = enc.f;
- backoff = found->GetValue().backoff;
+ backoff = found->value.backoff;
return true;
}
@@ -132,7 +133,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
node = CombineWordHash(node, word);
typename Middle::ConstIterator found;
if (!middle.Find(node, found)) return false;
- backoff = found->GetValue().backoff;
+ backoff = found->value.backoff;
return true;
}
@@ -141,7 +142,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
node = CombineWordHash(node, word);
typename Longest::ConstIterator found;
if (!longest.Find(node, found)) return false;
- prob = found->GetValue().prob;
+ prob = found->value.prob;
return true;
}
@@ -160,14 +161,50 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
std::vector<Middle> middle_;
};
-// std::identity is an SGI extension :-(
-struct IdentityHash : public std::unary_function<uint64_t, size_t> {
- size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
+/* These look like perfect candidates for a template, right? Ancient gcc (4.1
+ * on RedHat stale linux) doesn't pack templates correctly. ProbBackoffEntry
+ * is a multiple of 8 bytes anyway. ProbEntry is 12 bytes so it's set to pack.
+ */
+struct ProbBackoffEntry {
+ uint64_t key;
+ ProbBackoff value;
+ typedef uint64_t Key;
+ typedef ProbBackoff Value;
+ uint64_t GetKey() const {
+ return key;
+ }
+ static ProbBackoffEntry Make(uint64_t key, ProbBackoff value) {
+ ProbBackoffEntry ret;
+ ret.key = key;
+ ret.value = value;
+ return ret;
+ }
};
+#pragma pack(push)
+#pragma pack(4)
+struct ProbEntry {
+ uint64_t key;
+ Prob value;
+ typedef uint64_t Key;
+ typedef Prob Value;
+ uint64_t GetKey() const {
+ return key;
+ }
+ static ProbEntry Make(uint64_t key, Prob value) {
+ ProbEntry ret;
+ ret.key = key;
+ ret.value = value;
+ return ret;
+ }
+};
+
+#pragma pack(pop)
+
+
struct ProbingHashedSearch : public TemplateHashedSearch<
- util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
- util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {
+ util::ProbingHashTable<ProbBackoffEntry, util::IdentityHash>,
+ util::ProbingHashTable<ProbEntry, util::IdentityHash> > {
static const ModelType kModelType = HASH_PROBING;
};