summary refs log tree commit diff
path: root/klm/lm/search_hashed.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_hashed.cc')
-rw-r--r-- klm/lm/search_hashed.cc 29
1 files changed, 14 insertions, 15 deletions
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 2d6f15b2..62275d27 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -54,7 +54,7 @@ template <class Weights> class ActivateUnigram {
Weights *modify_;
};
-// Find the lower order entry, inserting blanks along the way as necessary.
+// Find the lower order entry, inserting blanks along the way as necessary.
template <class Value> void FindLower(
const std::vector<uint64_t> &keys,
typename Value::Weights &unigram,
@@ -64,7 +64,7 @@ template <class Value> void FindLower(
typename Value::ProbingEntry entry;
// Backoff will always be 0.0. We'll get the probability and rest in another pass.
entry.value.backoff = kNoExtensionBackoff;
- // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
+ // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
for (int lower = keys.size() - 2; ; --lower) {
if (lower == -1) {
between.push_back(&unigram);
@@ -77,11 +77,11 @@ template <class Value> void FindLower(
}
}
-// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
+// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
template <class Added, class Build> void AdjustLower(
const Added &added,
const Build &build,
- std::vector<typename Build::Value::Weights *> &between,
+ std::vector<typename Build::Value::Weights *> &between,
const unsigned int n,
const std::vector<WordIndex> &vocab_ids,
typename Build::Value::Weights *unigrams,
@@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower(
}
typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
float prob = -fabs(between.back()->prob);
- // Order of the n-gram on which probabilities are based.
+ // Order of the n-gram on which probabilities are based.
unsigned char basis = n - between.size();
assert(basis != 0);
typename Build::Value::Weights **change = &between.back();
// Skip the basis.
--change;
if (basis == 1) {
- // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
+ // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
float &backoff = unigrams[vocab_ids[1]].backoff;
SetExtension(backoff);
prob += backoff;
@@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower(
typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());
build.MarkExtends(**i, added);
const typename Value::Weights *longer = *i;
- // Everything has probability but is not marked as extending.
+ // Everything has probability but is not marked as extending.
for (++i; i != between.end(); ++i) {
build.MarkExtends(**i, *longer);
longer = *i;
}
}
-// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds.
+// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds.
template <class Build> void MarkLower(
const std::vector<uint64_t> &keys,
const Build &build,
@@ -144,15 +144,15 @@ template <class Build> void MarkLower(
int start_order,
const typename Build::Value::Weights &longer) {
if (start_order == 0) return;
- typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
- // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
+ // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) {
if (even_lower == -1) {
build.MarkExtends(unigram, longer);
return;
}
- middle[even_lower].UnsafeMutableFind(keys[even_lower], iter);
- if (!build.MarkExtends(iter->value, longer)) return;
+ if (!build.MarkExtends(
+ middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value,
+ longer)) return;
}
}
@@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams(
Store &store,
PositiveProbWarn &warn) {
typedef typename Build::Value Value;
- typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
assert(n >= 2);
ReadNGramHeader(f, n);
@@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
for (unsigned int h = 1; h < n - 1; ++h) {
keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
}
- // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
+ // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
util::SetSign(entry.value.prob);
entry.key = keys[n-2];
@@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
} // namespace
namespace detail {
-
+
template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
std::size_t allocated = Unigram::Size(counts[0]);
unigram_ = Unigram(start, counts[0], allocated);