diff options
Diffstat (limited to 'klm/lm/search_hashed.cc')
-rw-r--r-- | klm/lm/search_hashed.cc | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 247832b0..1d6fb5be 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -30,7 +30,7 @@ template <class Middle> class ActivateLowerMiddle { // TODO: somehow get text of n-gram for this error message. if (!modify_.UnsafeMutableFind(hash, i)) UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram"); - SetExtension(i->MutableValue().backoff); + SetExtension(i->value.backoff); } private: @@ -65,7 +65,7 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign blank.prob -= unigrams[vocab_ids[1]].backoff; SetExtension(unigrams[vocab_ids[1]].backoff); // Bigram including a unigram's backoff - middle[0].Insert(Middle::Packing::Make(keys[0], blank)); + middle[0].Insert(detail::ProbBackoffEntry::Make(keys[0], blank)); fix = 1; } else { for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); @@ -74,22 +74,24 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign for (; fix <= n - 3; ++fix) { typename Middle::MutableIterator gotit; if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) { - float &backoff = gotit->MutableValue().backoff; + float &backoff = gotit->value.backoff; SetExtension(backoff); blank.prob -= backoff; } - middle[fix].Insert(Middle::Packing::Make(keys[fix], blank)); + middle[fix].Insert(detail::ProbBackoffEntry::Make(keys[fix], blank)); backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]); } } template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) { + assert(n >= 2); ReadNGramHeader(f, n); - // vocab ids of words in reverse order + // Both vocab_ids and keys are non-empty because n >= 2. + // vocab ids of words in reverse order. std::vector<WordIndex> vocab_ids(n); std::vector<uint64_t> keys(n-1); - typename Store::Packing::Value value; + typename Store::Entry::Value value; typename Middle::MutableIterator found; for (size_t i = 0; i < count; ++i) { ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn); @@ -100,7 +102,7 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams( } // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. util::SetSign(value.prob); - store.Insert(Store::Packing::Make(keys[n-2], value)); + store.Insert(Store::Entry::Make(keys[n-2], value)); // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. int lower; util::FloatEnc fix_prob; @@ -113,9 +115,9 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams( } if (middle[lower].UnsafeMutableFind(keys[lower], found)) { // Turn off sign bit to indicate that it extends left. - fix_prob.f = found->MutableValue().prob; + fix_prob.f = found->value.prob; fix_prob.i &= ~util::kSignBit; - found->MutableValue().prob = fix_prob.f; + found->value.prob = fix_prob.f; // We don't need to recurse further down because this entry already set the bits for lower entries. break; } @@ -147,7 +149,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT, template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. - SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config); + SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config); PositiveProbWarn warn(config.positive_log_probability); |