Diffstat (limited to 'klm/lm/search_hashed.cc')
-rw-r--r--  klm/lm/search_hashed.cc  |  79
1 file changed, 64 insertions(+), 15 deletions(-)
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 82c53ec8..334adf12 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -1,10 +1,12 @@
#include "lm/search_hashed.hh"
+#include "lm/binary_format.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/read_arpa.hh"
#include "lm/vocab.hh"
+#include "util/bit_packing.hh"
#include "util/file_piece.hh"
#include <string>
@@ -48,30 +50,77 @@ class ActivateUnigram {
ProbBackoff *modify_;
};
-template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) {
-
- ReadNGramHeader(f, n);
+template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsigned int n, const uint64_t *keys, const WordIndex *vocab_ids, ProbBackoff *unigrams, std::vector<Middle> &middle) {
ProbBackoff blank;
- blank.prob = kBlankProb;
- blank.backoff = kBlankBackoff;
+ blank.backoff = kNoExtensionBackoff;
+ // Fix SRI's stupidity.
+ // Note that negative_lower_prob is the negative of the probability (so it's currently >= 0). We still want the sign bit off to indicate left extension, so I just do -= on the backoffs.
+ blank.prob = negative_lower_prob;
+ // An entry was found at lower (order lower + 2).
+ // We need to insert blanks starting at lower + 1 (order lower + 3).
+ unsigned int fix = static_cast<unsigned int>(lower + 1);
+ uint64_t backoff_hash = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids[1]), vocab_ids[2]);
+ if (fix == 0) {
+ // Insert a missing bigram.
+ blank.prob -= unigrams[vocab_ids[1]].backoff;
+ SetExtension(unigrams[vocab_ids[1]].backoff);
+ // Bigram including a unigram's backoff
+ middle[0].Insert(Middle::Packing::Make(keys[0], blank));
+ fix = 1;
+ } else {
+ for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]);
+ }
+ // fix >= 1. Insert trigrams and above.
+ for (; fix <= n - 3; ++fix) {
+ typename Middle::MutableIterator gotit;
+ if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) {
+ float &backoff = gotit->MutableValue().backoff;
+ SetExtension(backoff);
+ blank.prob -= backoff;
+ }
+ middle[fix].Insert(Middle::Packing::Make(keys[fix], blank));
+ backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]);
+ }
+}
+
+template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) {
+ ReadNGramHeader(f, n);
// vocab ids of words in reverse order
WordIndex vocab_ids[n];
uint64_t keys[n - 1];
typename Store::Packing::Value value;
- typename Middle::ConstIterator found;
+ typename Middle::MutableIterator found;
for (size_t i = 0; i < count; ++i) {
ReadNGram(f, n, vocab, vocab_ids, value, warn);
+
keys[0] = detail::CombineWordHash(static_cast<uint64_t>(*vocab_ids), vocab_ids[1]);
for (unsigned int h = 1; h < n - 1; ++h) {
keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
}
+ // Initially the sign bit is on, indicating it does not extend left. Most already have this, but there might be +0.0.
+ util::SetSign(value.prob);
store.Insert(Store::Packing::Make(keys[n-2], value));
- // Go back and insert blanks.
- for (int lower = n - 3; lower >= 0; --lower) {
- if (middle[lower].Find(keys[lower], found)) break;
- middle[lower].Insert(Middle::Packing::Make(keys[lower], blank));
+ // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
+ int lower;
+ util::FloatEnc fix_prob;
+ for (lower = n - 3; ; --lower) {
+ if (lower == -1) {
+ fix_prob.f = unigrams[vocab_ids[0]].prob;
+ fix_prob.i &= ~util::kSignBit;
+ unigrams[vocab_ids[0]].prob = fix_prob.f;
+ break;
+ }
+ if (middle[lower].UnsafeMutableFind(keys[lower], found)) {
+ // Turn off sign bit to indicate that it extends left.
+ fix_prob.f = found->MutableValue().prob;
+ fix_prob.i &= ~util::kSignBit;
+ found->MutableValue().prob = fix_prob.f;
+ // We don't need to recurse further down because this entry already set the bits for lower entries.
+ break;
+ }
}
+ if (lower != static_cast<int>(n) - 3) FixSRI(lower, fix_prob.f, n, keys, vocab_ids, unigrams, middle);
activate(vocab_ids, n);
}
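
The arithmetic in FixSRI above is easy to misread, so here is a hedged numeric sketch (the trigram, the pruned bigram, and all values are made up for illustration; this is not part of the patch):

    #include <cstdio>

    int main() {
      // Suppose the trigram "a b c" survives pruning but SRI dropped the
      // bigram "b c".  The longest right-aligned entry found is the unigram
      // "c", with log10 probability -2.0; the loop above clears its sign bit,
      // so FixSRI receives negative_lower_prob = +2.0.
      float negative_lower_prob = 2.0f;
      float blank_prob = negative_lower_prob;

      // The inserted blank bigram should carry p(c | b) = backoff(b) + p(c)
      // = -0.5 + -2.0 = -2.5 in log10.  Because blank_prob holds that value
      // with the sign flipped (sign bit off marks "extends left"), the code
      // subtracts the backoff instead of adding it.
      float backoff_b = -0.5f;
      blank_prob -= backoff_b;                               // 2.0 - (-0.5) = 2.5
      std::printf("stored blank prob: %+g\n", blank_prob);   // +2.5
      std::printf("actual blank prob: %+g\n", -blank_prob);  // -2.5
      return 0;
    }

With the sign bit left off, the stored +2.5 simultaneously encodes the probability and the fact that the blank entry extends left, matching how ReadNGrams flags real entries.
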
@@ -107,15 +156,15 @@ template <class MiddleT, class LongestT> template <class Voc> void TemplateHashe
try {
if (counts.size() > 2) {
- ReadNGrams(f, 2, counts[1], vocab, middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn);
+ ReadNGrams(f, 2, counts[1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn);
}
for (unsigned int n = 3; n < counts.size(); ++n) {
- ReadNGrams(f, n, counts[n-1], vocab, middle_, ActivateLowerMiddle<Middle>(middle_[n-3]), middle_[n-2], warn);
+ ReadNGrams(f, n, counts[n-1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_[n-3]), middle_[n-2], warn);
}
if (counts.size() > 2) {
- ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateLowerMiddle<Middle>(middle_.back()), longest, warn);
+ ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_.back()), longest, warn);
} else {
- ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateUnigram(unigram.Raw()), longest, warn);
+ ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), longest, warn);
}
} catch (util::ProbingSizeException &e) {
UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n");
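
For context on the unigram.Raw() argument now threaded through these calls: ReadNGrams needs write access to the unigram array so it can clear the sign bit on a unigram probability, and FixSRI needs to read unigram backoffs. Below is a minimal standalone sketch of that sign-bit flag, with a local union and constant mirroring util::FloatEnc and util::kSignBit from util/bit_packing.hh (newly included by this patch); it is an illustration, not KenLM code:

    #include <cstdint>
    #include <cstdio>

    // Stand-ins mirroring util::FloatEnc and util::kSignBit.
    union FloatEnc { float f; uint32_t i; };
    static const uint32_t kSignBit = 0x80000000;

    int main() {
      // Log10 probabilities are <= 0, so the sign bit is normally set.
      FloatEnc prob;
      prob.f = -0.30103f;                            // log10(0.5)

      // Clearing the sign bit marks the entry as "extends left"...
      prob.i &= ~kSignBit;
      std::printf("flagged:  %+g\n", prob.f);        // +0.30103

      // ...and forcing it back on (as util::SetSign does) recovers the
      // real probability at query time.
      prob.i |= kSignBit;
      std::printf("restored: %+g\n", prob.f);        // -0.30103
      return 0;
    }

Since log10 probabilities are never positive, the flag costs no extra storage.
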
@@ -133,7 +182,7 @@ template <class MiddleT, class LongestT> void TemplateHashedSearch<MiddleT, Long
template class TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>;
-template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab, Backing &backing);
+template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab, Backing &backing);
} // namespace detail
} // namespace ngram
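
One more hedged sketch, on how the keys[] chain built in ReadNGrams lines up with the middle_ tables probed in the "go back" loop. Combine() below is a made-up stand-in for detail::CombineWordHash (the real mixing constants are not reproduced); only the chaining structure is meant to match:

    #include <cstdint>
    #include <cstdio>

    typedef unsigned int WordIndex;

    // Made-up mixer standing in for detail::CombineWordHash; only the
    // chaining structure matters here, not the hash itself.
    static uint64_t Combine(uint64_t current, WordIndex next) {
      return current * 1099511628211ULL ^ (static_cast<uint64_t>(next) + 1);
    }

    int main() {
      // ReadNGram fills vocab_ids with the words of "a b c d" in reverse order.
      const unsigned int n = 4;
      WordIndex vocab_ids[n] = {4 /*d*/, 3 /*c*/, 2 /*b*/, 1 /*a*/};

      uint64_t keys[n - 1];
      keys[0] = Combine(vocab_ids[0], vocab_ids[1]);       // hashes "c d"
      for (unsigned int h = 1; h < n - 1; ++h)
        keys[h] = Combine(keys[h - 1], vocab_ids[h + 1]);  // "b c d", then "a b c d"

      // middle[lower] holds (lower + 2)-grams, so keys[lower] is exactly the
      // right-aligned sub-gram the "go back" loop looks up before calling FixSRI.
      for (unsigned int h = 0; h < n - 1; ++h)
        std::printf("keys[%u] (last %u words) = %llu\n", h, h + 2,
                    (unsigned long long)keys[h]);
      return 0;
    }

backoff_hash in FixSRI is chained the same way but starting from vocab_ids[1], i.e. over the context with the last word dropped, which is why it can locate the backoff entries whose values get subtracted into blank.prob.
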