diff options
Diffstat (limited to 'klm/lm/search_hashed.cc')
-rw-r--r-- | klm/lm/search_hashed.cc | 29 |
1 files changed, 14 insertions, 15 deletions
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 2d6f15b2..62275d27 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -54,7 +54,7 @@ template <class Weights> class ActivateUnigram { Weights *modify_; }; -// Find the lower order entry, inserting blanks along the way as necessary. +// Find the lower order entry, inserting blanks along the way as necessary. template <class Value> void FindLower( const std::vector<uint64_t> &keys, typename Value::Weights &unigram, @@ -64,7 +64,7 @@ template <class Value> void FindLower( typename Value::ProbingEntry entry; // Backoff will always be 0.0. We'll get the probability and rest in another pass. entry.value.backoff = kNoExtensionBackoff; - // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. + // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. for (int lower = keys.size() - 2; ; --lower) { if (lower == -1) { between.push_back(&unigram); @@ -77,11 +77,11 @@ template <class Value> void FindLower( } } -// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. +// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. template <class Added, class Build> void AdjustLower( const Added &added, const Build &build, - std::vector<typename Build::Value::Weights *> &between, + std::vector<typename Build::Value::Weights *> &between, const unsigned int n, const std::vector<WordIndex> &vocab_ids, typename Build::Value::Weights *unigrams, @@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower( } typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; float prob = -fabs(between.back()->prob); - // Order of the n-gram on which probabilities are based. + // Order of the n-gram on which probabilities are based. unsigned char basis = n - between.size(); assert(basis != 0); typename Build::Value::Weights **change = &between.back(); // Skip the basis. --change; if (basis == 1) { - // Hallucinate a bigram based on a unigram's backoff and a unigram probability. + // Hallucinate a bigram based on a unigram's backoff and a unigram probability. float &backoff = unigrams[vocab_ids[1]].backoff; SetExtension(backoff); prob += backoff; @@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower( typename std::vector<typename Value::Weights *>::const_iterator i(between.begin()); build.MarkExtends(**i, added); const typename Value::Weights *longer = *i; - // Everything has probability but is not marked as extending. + // Everything has probability but is not marked as extending. for (++i; i != between.end(); ++i) { build.MarkExtends(**i, *longer); longer = *i; } } -// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds. +// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds. template <class Build> void MarkLower( const std::vector<uint64_t> &keys, const Build &build, @@ -144,15 +144,15 @@ template <class Build> void MarkLower( int start_order, const typename Build::Value::Weights &longer) { if (start_order == 0) return; - typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter; - // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code. + // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code. for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) { if (even_lower == -1) { build.MarkExtends(unigram, longer); return; } - middle[even_lower].UnsafeMutableFind(keys[even_lower], iter); - if (!build.MarkExtends(iter->value, longer)) return; + if (!build.MarkExtends( + middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value, + longer)) return; } } @@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams( Store &store, PositiveProbWarn &warn) { typedef typename Build::Value Value; - typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; assert(n >= 2); ReadNGramHeader(f, n); @@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams( for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } - // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. + // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. util::SetSign(entry.value.prob); entry.key = keys[n-2]; @@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams( } // namespace namespace detail { - + template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { std::size_t allocated = Unigram::Size(counts[0]); unigram_ = Unigram(start, counts[0], allocated); |