summaryrefslogtreecommitdiff
path: root/klm/lm/vocab.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/vocab.hh')
-rw-r--r--klm/lm/vocab.hh36
1 files changed, 26 insertions, 10 deletions
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 3c3414fb..06fdefe4 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -4,7 +4,6 @@
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
-#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
@@ -83,7 +82,7 @@ class SortedVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
private:
uint64_t *begin_, *end_;
@@ -100,6 +99,26 @@ class SortedVocabulary : public base::Vocabulary {
std::vector<std::string> strings_to_enumerate_;
};
+#pragma pack(push)
+#pragma pack(4)
+struct ProbingVocabuaryEntry { // Entry type stored in ProbingVocabulary's Lookup hash table; declared under #pragma pack(4).
+ uint64_t key; // Lookup key; Index() searches the table with detail::HashForVocab(str).
+ WordIndex value; // Word index that Index() returns when the key is found.
+
+ typedef uint64_t Key; // Key type exposed to the hash table (presumably required by util::ProbingHashTable -- confirm).
+ uint64_t GetKey() const { // Accessor for the stored key.
+ return key;
+ }
+
+ static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { // Builds an entry from a key/value pair.
+ ProbingVocabuaryEntry ret; // No constructor is declared, so both fields are assigned explicitly below.
+ ret.key = key;
+ ret.value = value;
+ return ret;
+ }
+};
+#pragma pack(pop)
+
// Vocabulary storing a map from uint64_t to WordIndex.
class ProbingVocabulary : public base::Vocabulary {
public:
@@ -107,7 +126,7 @@ class ProbingVocabulary : public base::Vocabulary {
WordIndex Index(const StringPiece &str) const { // Maps a word to its index: hashes str and probes lookup_.
Lookup::ConstIterator i;
- return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0;
+ return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; // 0 is returned when the hash is not in the table.
}
static size_t Size(std::size_t entries, const Config &config);
@@ -124,17 +143,14 @@ class ProbingVocabulary : public base::Vocabulary {
void FinishedLoading(ProbBackoff *reorder_vocab);
+ std::size_t UnkCountChangePadding() const { return 0; } // Always 0 for this vocabulary; NOTE(review): presumably extra bytes to reserve if the <unk> count changes -- confirm against callers.
+
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
private:
- // std::identity is an SGI extension :-(
- struct IdentityHash : public std::unary_function<uint64_t, std::size_t> {
- std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); }
- };
-
- typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup;
+ typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup;
Lookup lookup_;