diff options
Diffstat (limited to 'klm/lm/vocab.hh')
-rw-r--r-- | klm/lm/vocab.hh | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 3c3414fb..06fdefe4 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,7 +4,6 @@ #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" -#include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" #include "util/string_piece.hh" @@ -83,7 +82,7 @@ class SortedVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } - void LoadedBinary(int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); private: uint64_t *begin_, *end_; @@ -100,6 +99,26 @@ class SortedVocabulary : public base::Vocabulary { std::vector<std::string> strings_to_enumerate_; }; +#pragma pack(push) +#pragma pack(4) +struct ProbingVocabuaryEntry { + uint64_t key; + WordIndex value; + + typedef uint64_t Key; + uint64_t GetKey() const { + return key; + } + + static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { + ProbingVocabuaryEntry ret; + ret.key = key; + ret.value = value; + return ret; + } +}; +#pragma pack(pop) + // Vocabulary storing a map from uint64_t to WordIndex. class ProbingVocabulary : public base::Vocabulary { public: @@ -107,7 +126,7 @@ class ProbingVocabulary : public base::Vocabulary { WordIndex Index(const StringPiece &str) const { Lookup::ConstIterator i; - return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0; + return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } static size_t Size(std::size_t entries, const Config &config); @@ -124,17 +143,14 @@ class ProbingVocabulary : public base::Vocabulary { void FinishedLoading(ProbBackoff *reorder_vocab); + std::size_t UnkCountChangePadding() const { return 0; } + bool SawUnk() const { return saw_unk_; } - void LoadedBinary(int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); private: - // std::identity is an SGI extension :-( - struct IdentityHash : public std::unary_function<uint64_t, std::size_t> { - std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); } - }; - - typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup; + typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup; Lookup lookup_; |