summaryrefslogtreecommitdiff
path: root/klm/lm/search_hashed.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_hashed.hh')
-rw-r--r--klm/lm/search_hashed.hh229
1 files changed, 107 insertions, 122 deletions
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 4352c72d..7e8c1220 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -19,6 +19,7 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
struct Backing;
+class ProbingVocabulary;
namespace detail {
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
@@ -26,54 +27,48 @@ inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
return ret;
}
-struct HashedSearch {
- typedef uint64_t Node;
-
- class Unigram {
- public:
- Unigram() {}
-
- Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}
-
- static std::size_t Size(uint64_t count) {
- return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
- }
-
- const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }
+#pragma pack(push)
+#pragma pack(4)
+struct ProbEntry {
+ uint64_t key;
+ Prob value;
+ typedef uint64_t Key;
+ typedef Prob Value;
+ uint64_t GetKey() const {
+ return key;
+ }
+};
- ProbBackoff &Unknown() { return unigram_[0]; }
+#pragma pack(pop)
- void LoadedBinary() {}
+class LongestPointer {
+ public:
+ explicit LongestPointer(const float &to) : to_(&to) {}
- // For building.
- ProbBackoff *Raw() { return unigram_; }
+ LongestPointer() : to_(NULL) {}
- private:
- ProbBackoff *unigram_;
- };
+ bool Found() const {
+ return to_ != NULL;
+ }
- Unigram unigram;
+ float Prob() const {
+ return *to_;
+ }
- void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const {
- const ProbBackoff &entry = unigram.Lookup(word);
- util::FloatEnc val;
- val.f = entry.prob;
- ret.independent_left = (val.i & util::kSignBit);
- ret.extend_left = static_cast<uint64_t>(word);
- val.i |= util::kSignBit;
- ret.prob = val.f;
- backoff = entry.backoff;
- next = static_cast<Node>(word);
- }
+ private:
+ const float *to_;
};
-template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch {
+template <class Value> class HashedSearch {
public:
- typedef MiddleT Middle;
+ typedef uint64_t Node;
- typedef LongestT Longest;
- Longest longest;
+ typedef typename Value::ProbingProxy UnigramPointer;
+ typedef typename Value::ProbingProxy MiddlePointer;
+ typedef ::lm::ngram::detail::LongestPointer LongestPointer;
+ static const ModelType kModelType = Value::kProbingModelType;
+ static const bool kDifferentRest = Value::kDifferentRest;
static const unsigned int kVersion = 0;
// TODO: move probing_multiplier here with next binary file format update.
@@ -89,64 +84,55 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
+ void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing);
- typedef typename std::vector<Middle>::const_iterator MiddleIter;
+ void LoadedBinary();
- MiddleIter MiddleBegin() const { return middle_.begin(); }
- MiddleIter MiddleEnd() const { return middle_.end(); }
+ unsigned char Order() const {
+ return middle_.size() + 2;
+ }
- Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
- util::FloatEnc val;
- if (extend_length == 1) {
- val.f = unigram.Lookup(static_cast<uint64_t>(extend_pointer)).prob;
- } else {
- typename Middle::ConstIterator found;
- if (!middle_[extend_length - 2].Find(extend_pointer, found)) {
- std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl;
- abort();
- }
- val.f = found->value.prob;
- }
- val.i |= util::kSignBit;
- prob = val.f;
- return extend_pointer;
+ typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); }
+
+ UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
+ extend_left = static_cast<uint64_t>(word);
+ next = extend_left;
+ UnigramPointer ret(unigram_.Lookup(word));
+ independent_left = ret.IndependentLeft();
+ return ret;
}
- bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
- node = CombineWordHash(node, word);
+#pragma GCC diagnostic ignored "-Wuninitialized"
+ MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
+ node = extend_pointer;
typename Middle::ConstIterator found;
- if (!middle.Find(node, found)) return false;
- util::FloatEnc enc;
- enc.f = found->value.prob;
- ret.independent_left = (enc.i & util::kSignBit);
- ret.extend_left = node;
- enc.i |= util::kSignBit;
- ret.prob = enc.f;
- backoff = found->value.backoff;
- return true;
+ bool got = middle_[extend_length - 2].Find(extend_pointer, found);
+ assert(got);
+ (void)got;
+ return MiddlePointer(found->value);
}
- void LoadedBinary();
-
- bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
+ MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
node = CombineWordHash(node, word);
typename Middle::ConstIterator found;
- if (!middle.Find(node, found)) return false;
- backoff = found->value.backoff;
- return true;
+ if (!middle_[order_minus_2].Find(node, found)) {
+ independent_left = true;
+ return MiddlePointer();
+ }
+ extend_pointer = node;
+ MiddlePointer ret(found->value);
+ independent_left = ret.IndependentLeft();
+ return ret;
}
- bool LookupLongest(WordIndex word, float &prob, Node &node) const {
+ LongestPointer LookupLongest(WordIndex word, const Node &node) const {
// Sign bit is always on because longest n-grams do not extend left.
- node = CombineWordHash(node, word);
typename Longest::ConstIterator found;
- if (!longest.Find(node, found)) return false;
- prob = found->value.prob;
- return true;
+ if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
+ return LongestPointer(found->value.prob);
}
- // Geenrate a node without necessarily checking that it actually exists.
+ // Generate a node without necessarily checking that it actually exists.
// Optionally return false if it's know to not exist.
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
assert(begin != end);
@@ -158,55 +144,54 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
}
private:
- std::vector<Middle> middle_;
-};
+ // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
+ void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
-/* These look like perfect candidates for a template, right? Ancient gcc (4.1
- * on RedHat stale linux) doesn't pack templates correctly. ProbBackoffEntry
- * is a multiple of 8 bytes anyway. ProbEntry is 12 bytes so it's set to pack.
- */
-struct ProbBackoffEntry {
- uint64_t key;
- ProbBackoff value;
- typedef uint64_t Key;
- typedef ProbBackoff Value;
- uint64_t GetKey() const {
- return key;
- }
- static ProbBackoffEntry Make(uint64_t key, ProbBackoff value) {
- ProbBackoffEntry ret;
- ret.key = key;
- ret.value = value;
- return ret;
- }
-};
+ template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
-#pragma pack(push)
-#pragma pack(4)
-struct ProbEntry {
- uint64_t key;
- Prob value;
- typedef uint64_t Key;
- typedef Prob Value;
- uint64_t GetKey() const {
- return key;
- }
- static ProbEntry Make(uint64_t key, Prob value) {
- ProbEntry ret;
- ret.key = key;
- ret.value = value;
- return ret;
- }
-};
+ class Unigram {
+ public:
+ Unigram() {}
-#pragma pack(pop)
+ Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
+ unigram_(static_cast<typename Value::Weights*>(start))
+#ifdef DEBUG
+ , count_(count)
+#endif
+ {}
+
+ static std::size_t Size(uint64_t count) {
+ return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
+ }
+
+ const typename Value::Weights &Lookup(WordIndex index) const {
+#ifdef DEBUG
+ assert(index < count_);
+#endif
+ return unigram_[index];
+ }
+
+ typename Value::Weights &Unknown() { return unigram_[0]; }
+ void LoadedBinary() {}
-struct ProbingHashedSearch : public TemplateHashedSearch<
- util::ProbingHashTable<ProbBackoffEntry, util::IdentityHash>,
- util::ProbingHashTable<ProbEntry, util::IdentityHash> > {
+ // For building.
+ typename Value::Weights *Raw() { return unigram_; }
+
+ private:
+ typename Value::Weights *unigram_;
+#ifdef DEBUG
+ uint64_t count_;
+#endif
+ };
+
+ Unigram unigram_;
+
+ typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
+ std::vector<Middle> middle_;
- static const ModelType kModelType = HASH_PROBING;
+ typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest;
+ Longest longest_;
};
} // namespace detail