summaryrefslogtreecommitdiff
path: root/klm/lm/search_hashed.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_hashed.cc')
-rw-r--r--klm/lm/search_hashed.cc22
1 files changed, 12 insertions, 10 deletions
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 247832b0..1d6fb5be 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -30,7 +30,7 @@ template <class Middle> class ActivateLowerMiddle {
// TODO: somehow get text of n-gram for this error message.
if (!modify_.UnsafeMutableFind(hash, i))
UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram");
- SetExtension(i->MutableValue().backoff);
+ SetExtension(i->value.backoff);
}
private:
@@ -65,7 +65,7 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign
blank.prob -= unigrams[vocab_ids[1]].backoff;
SetExtension(unigrams[vocab_ids[1]].backoff);
// Bigram including a unigram's backoff
- middle[0].Insert(Middle::Packing::Make(keys[0], blank));
+ middle[0].Insert(detail::ProbBackoffEntry::Make(keys[0], blank));
fix = 1;
} else {
for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]);
@@ -74,22 +74,24 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign
for (; fix <= n - 3; ++fix) {
typename Middle::MutableIterator gotit;
if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) {
- float &backoff = gotit->MutableValue().backoff;
+ float &backoff = gotit->value.backoff;
SetExtension(backoff);
blank.prob -= backoff;
}
- middle[fix].Insert(Middle::Packing::Make(keys[fix], blank));
+ middle[fix].Insert(detail::ProbBackoffEntry::Make(keys[fix], blank));
backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]);
}
}
template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) {
+ assert(n >= 2);
ReadNGramHeader(f, n);
- // vocab ids of words in reverse order
+ // Both vocab_ids and keys are non-empty because n >= 2.
+ // vocab ids of words in reverse order.
std::vector<WordIndex> vocab_ids(n);
std::vector<uint64_t> keys(n-1);
- typename Store::Packing::Value value;
+ typename Store::Entry::Value value;
typename Middle::MutableIterator found;
for (size_t i = 0; i < count; ++i) {
ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn);
@@ -100,7 +102,7 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
}
// Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
util::SetSign(value.prob);
- store.Insert(Store::Packing::Make(keys[n-2], value));
+ store.Insert(Store::Entry::Make(keys[n-2], value));
// Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
int lower;
util::FloatEnc fix_prob;
@@ -113,9 +115,9 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
}
if (middle[lower].UnsafeMutableFind(keys[lower], found)) {
// Turn off sign bit to indicate that it extends left.
- fix_prob.f = found->MutableValue().prob;
+ fix_prob.f = found->value.prob;
fix_prob.i &= ~util::kSignBit;
- found->MutableValue().prob = fix_prob.f;
+ found->value.prob = fix_prob.f;
// We don't need to recurse further down because this entry already set the bits for lower entries.
break;
}
@@ -147,7 +149,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
// TODO: fix sorted.
- SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config);
+ SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);
PositiveProbWarn warn(config.positive_log_probability);