summary | refs | log | tree | commit | diff
path: root/klm/lm/vocab.hh
diff options
context:
space:
mode:
author Chris Dyer <cdyer@cs.cmu.edu> 2011-07-05 23:19:54 -0400
committer Chris Dyer <cdyer@cs.cmu.edu> 2011-07-05 23:19:54 -0400
commit c3e46171f722f6276e2613ea6cb087b07325d794 (patch)
tree 4a289539c4e7a972009dc2f1b680004b959547df /klm/lm/vocab.hh
parent f91319978f6e74e5c4e5701da8fbbacb96a3161e (diff)
parent 59932be2de387ecfcaa81a8387e8f21d5123c050 (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'klm/lm/vocab.hh')
-rw-r--r-- klm/lm/vocab.hh 35
1 files changed, 22 insertions, 13 deletions
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 546c1649..c92518e4 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -9,6 +9,7 @@
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
+#include <limits>
#include <string>
#include <vector>
@@ -44,22 +45,16 @@ class WriteWordsWrapper : public EnumerateVocab {
// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
class SortedVocabulary : public base::Vocabulary {
- private:
- // Sorted uniform requires a GetKey function.
- struct Entry {
- uint64_t GetKey() const { return key; }
- uint64_t key;
- bool operator<(const Entry &other) const {
- return key < other.key;
- }
- };
-
public:
SortedVocabulary();
WordIndex Index(const StringPiece &str) const {
- const Entry *found;
- if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
+ const uint64_t *found;
+ if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>(
+ util::IdentityAccessor<uint64_t>(),
+ begin_ - 1, 0,
+ end_, std::numeric_limits<uint64_t>::max(),
+ detail::HashForVocab(str), found)) {
return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
} else {
return 0;
@@ -68,6 +63,10 @@ class SortedVocabulary : public base::Vocabulary {
static size_t Size(std::size_t entries, const Config &config);
+ // Vocab words are [0, Bound()). Only valid after FinishedLoading/LoadedBinary.
+ // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases.
+ WordIndex Bound() const { return bound_; }
+
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
@@ -83,7 +82,11 @@ class SortedVocabulary : public base::Vocabulary {
void LoadedBinary(int fd, EnumerateVocab *to);
private:
- Entry *begin_, *end_;
+ uint64_t *begin_, *end_;
+
+ WordIndex bound_;
+
+ WordIndex highest_value_;
bool saw_unk_;
@@ -105,6 +108,12 @@ class ProbingVocabulary : public base::Vocabulary {
static size_t Size(std::size_t entries, const Config &config);
+ // Vocab words are [0, Bound()).
+ // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary.
+ // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update.
+ // Specifically, the binary file format does not currently indicate whether <unk> is in count or not.
+ WordIndex Bound() const { return available_; }
+
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);