summaryrefslogtreecommitdiff
path: root/klm/lm/search_hashed.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_hashed.cc')
-rw-r--r--klm/lm/search_hashed.cc52
1 files changed, 47 insertions, 5 deletions
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 9200aeb6..00d03f4e 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -14,7 +14,41 @@ namespace ngram {
namespace {
-template <class Voc, class Store, class Middle> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector<Middle> &middle, Store &store) {
+/* These are passed to ReadNGrams so that n-grams with zero backoff that appear as context will still be used in state. */
+template <class Middle> class ActivateLowerMiddle {
+ public:
+ explicit ActivateLowerMiddle(Middle &middle) : modify_(middle) {}
+
+ void operator()(const WordIndex *vocab_ids, const unsigned int n) {
+ uint64_t hash = static_cast<WordIndex>(vocab_ids[1]);
+ for (const WordIndex *i = vocab_ids + 2; i < vocab_ids + n; ++i) {
+ hash = detail::CombineWordHash(hash, *i);
+ }
+ typename Middle::MutableIterator i;
+ // TODO: somehow get text of n-gram for this error message.
+ if (!modify_.UnsafeMutableFind(hash, i))
+ UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram");
+ SetExtension(i->MutableValue().backoff);
+ }
+
+ private:
+ Middle &modify_;
+};
+
+class ActivateUnigram {
+ public:
+ explicit ActivateUnigram(ProbBackoff *unigram) : modify_(unigram) {}
+
+ void operator()(const WordIndex *vocab_ids, const unsigned int /*n*/) {
+ // assert(n == 2);
+ SetExtension(modify_[vocab_ids[1]].backoff);
+ }
+
+ private:
+ ProbBackoff *modify_;
+};
+
+template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector<Middle> &middle, Activate activate, Store &store) {
ReadNGramHeader(f, n);
ProbBackoff blank;
@@ -38,6 +72,7 @@ template <class Voc, class Store, class Middle> void ReadNGrams(util::FilePiece
if (middle[lower].Find(keys[lower], found)) break;
middle[lower].Insert(Middle::Packing::Make(keys[lower], blank));
}
+ activate(vocab_ids, n);
}
store.FinishedInserting();
@@ -53,12 +88,19 @@ template <class MiddleT, class LongestT> template <class Voc> void TemplateHashe
Read1Grams(f, counts[0], vocab, unigram.Raw());
try {
- for (unsigned int n = 2; n < counts.size(); ++n) {
- ReadNGrams(f, n, counts[n-1], vocab, middle, middle[n-2]);
+ if (counts.size() > 2) {
+ ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0]);
+ }
+ for (unsigned int n = 3; n < counts.size(); ++n) {
+ ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle<Middle>(middle[n-3]), middle[n-2]);
+ }
+ if (counts.size() > 2) {
+ ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest);
+ } else {
+ ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle<Middle>(middle.back()), longest);
}
- ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, longest);
} catch (util::ProbingSizeException &e) {
- UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces. ");
+ UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n");
}
}