diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-07-05 23:19:54 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-07-05 23:19:54 -0400 |
commit | c3e46171f722f6276e2613ea6cb087b07325d794 (patch) | |
tree | 4a289539c4e7a972009dc2f1b680004b959547df /klm/lm/vocab.hh | |
parent | f91319978f6e74e5c4e5701da8fbbacb96a3161e (diff) | |
parent | 59932be2de387ecfcaa81a8387e8f21d5123c050 (diff) |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'klm/lm/vocab.hh')
-rw-r--r-- | klm/lm/vocab.hh | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 546c1649..c92518e4 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -9,6 +9,7 @@ #include "util/sorted_uniform.hh" #include "util/string_piece.hh" +#include <limits> #include <string> #include <vector> @@ -44,22 +45,16 @@ class WriteWordsWrapper : public EnumerateVocab { // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. class SortedVocabulary : public base::Vocabulary { - private: - // Sorted uniform requires a GetKey function. - struct Entry { - uint64_t GetKey() const { return key; } - uint64_t key; - bool operator<(const Entry &other) const { - return key < other.key; - } - }; - public: SortedVocabulary(); WordIndex Index(const StringPiece &str) const { - const Entry *found; - if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) { + const uint64_t *found; + if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>( + util::IdentityAccessor<uint64_t>(), + begin_ - 1, 0, + end_, std::numeric_limits<uint64_t>::max(), + detail::HashForVocab(str), found)) { return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table. } else { return 0; @@ -68,6 +63,10 @@ class SortedVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. + // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases. + WordIndex Bound() const { return bound_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); @@ -83,7 +82,11 @@ class SortedVocabulary : public base::Vocabulary { void LoadedBinary(int fd, EnumerateVocab *to); private: - Entry *begin_, *end_; + uint64_t *begin_, *end_; + + WordIndex bound_; + + WordIndex highest_value_; bool saw_unk_; @@ -105,6 +108,12 @@ class ProbingVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()). + // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. + // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. + // Specifically, the binary file format does not currently indicate whether <unk> is in count or not. + WordIndex Bound() const { return available_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); |