diff options
Diffstat (limited to 'klm/lm/search_hashed.cc')
| -rw-r--r-- | klm/lm/search_hashed.cc | 22 | 
1 file changed, 12 insertions, 10 deletions
| diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 247832b0..1d6fb5be 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -30,7 +30,7 @@ template <class Middle> class ActivateLowerMiddle {        // TODO: somehow get text of n-gram for this error message.        if (!modify_.UnsafeMutableFind(hash, i))          UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram"); -      SetExtension(i->MutableValue().backoff); +      SetExtension(i->value.backoff);      }    private: @@ -65,7 +65,7 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign      blank.prob -= unigrams[vocab_ids[1]].backoff;      SetExtension(unigrams[vocab_ids[1]].backoff);      // Bigram including a unigram's backoff -    middle[0].Insert(Middle::Packing::Make(keys[0], blank)); +    middle[0].Insert(detail::ProbBackoffEntry::Make(keys[0], blank));      fix = 1;    } else {      for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); @@ -74,22 +74,24 @@ template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsign    for (; fix <= n - 3; ++fix) {      typename Middle::MutableIterator gotit;      if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) { -      float &backoff = gotit->MutableValue().backoff; +      float &backoff = gotit->value.backoff;        SetExtension(backoff);        blank.prob -= backoff;      } -    middle[fix].Insert(Middle::Packing::Make(keys[fix], blank)); +    middle[fix].Insert(detail::ProbBackoffEntry::Make(keys[fix], blank));      backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]);    }  }  template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn 
&warn) { +  assert(n >= 2);    ReadNGramHeader(f, n); -  // vocab ids of words in reverse order +  // Both vocab_ids and keys are non-empty because n >= 2. +  // vocab ids of words in reverse order.    std::vector<WordIndex> vocab_ids(n);    std::vector<uint64_t> keys(n-1); -  typename Store::Packing::Value value; +  typename Store::Entry::Value value;    typename Middle::MutableIterator found;    for (size_t i = 0; i < count; ++i) {      ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn); @@ -100,7 +102,7 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(      }      // Initially the sign bit is on, indicating it does not extend left.  Most already have this but there might +0.0.        util::SetSign(value.prob); -    store.Insert(Store::Packing::Make(keys[n-2], value)); +    store.Insert(Store::Entry::Make(keys[n-2], value));      // Go back and find the longest right-aligned entry, informing it that it extends left.  Normally this will match immediately, but sometimes SRI is dumb.        int lower;      util::FloatEnc fix_prob; @@ -113,9 +115,9 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(        }        if (middle[lower].UnsafeMutableFind(keys[lower], found)) {          // Turn off sign bit to indicate that it extends left.   -        fix_prob.f = found->MutableValue().prob; +        fix_prob.f = found->value.prob;          fix_prob.i &= ~util::kSignBit; -        found->MutableValue().prob = fix_prob.f; +        found->value.prob = fix_prob.f;          // We don't need to recurse further down because this entry already set the bits for lower entries.            
break;        } @@ -147,7 +149,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,  template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {    // TODO: fix sorted. -  SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config); +  SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);    PositiveProbWarn warn(config.positive_log_probability); | 
