summaryrefslogtreecommitdiff
path: root/klm/lm/vocab.hh
diff options
context:
space:
mode:
authorKenneth Heafield <kenlm@kheafield.com>2011-06-26 18:40:15 -0400
committerKenneth Heafield <kenlm@kheafield.com>2011-06-26 18:40:15 -0400
commit205893513c8343fdc55789e427fab4c8b536dc12 (patch)
tree67fdaa819488e231b5d70b2227527510571f2108 /klm/lm/vocab.hh
parent9366fc1ce04385290722bd703933bf0c1c166671 (diff)
Quantization
Diffstat (limited to 'klm/lm/vocab.hh')
-rw-r--r--klm/lm/vocab.hh35
1 files changed, 22 insertions, 13 deletions
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 546c1649..c92518e4 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -9,6 +9,7 @@
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
+#include <limits>
#include <string>
#include <vector>
@@ -44,22 +45,16 @@ class WriteWordsWrapper : public EnumerateVocab {
// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
class SortedVocabulary : public base::Vocabulary {
- private:
- // Sorted uniform requires a GetKey function.
- struct Entry {
- uint64_t GetKey() const { return key; }
- uint64_t key;
- bool operator<(const Entry &other) const {
- return key < other.key;
- }
- };
-
public:
SortedVocabulary();
WordIndex Index(const StringPiece &str) const {
- const Entry *found;
- if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
+ const uint64_t *found;
+ if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>(
+ util::IdentityAccessor<uint64_t>(),
+ begin_ - 1, 0,
+ end_, std::numeric_limits<uint64_t>::max(),
+ detail::HashForVocab(str), found)) {
return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
} else {
return 0;
@@ -68,6 +63,10 @@ class SortedVocabulary : public base::Vocabulary {
static size_t Size(std::size_t entries, const Config &config);
+ // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
+ // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases.
+ WordIndex Bound() const { return bound_; }
+
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
@@ -83,7 +82,11 @@ class SortedVocabulary : public base::Vocabulary {
void LoadedBinary(int fd, EnumerateVocab *to);
private:
- Entry *begin_, *end_;
+ uint64_t *begin_, *end_;
+
+ WordIndex bound_;
+
+ WordIndex highest_value_;
bool saw_unk_;
@@ -105,6 +108,12 @@ class ProbingVocabulary : public base::Vocabulary {
static size_t Size(std::size_t entries, const Config &config);
+ // Vocab words are [0, Bound()).
+ // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary.
+ // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update.
+ // Specifically, the binary file format does not currently indicate whether <unk> is in count or not.
+ WordIndex Bound() const { return available_; }
+
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);