Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r--   klm/lm/model.cc   25
1 file changed, 11 insertions, 14 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index f0579c0c..a1d10b3d 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -44,17 +44,13 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
   begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff;
   State null_context = State();
   null_context.valid_length_ = 0;
-  P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2);
+  P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2);
 }
 
 template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
   SetupMemory(start, params.counts, config);
   vocab_.LoadedBinary(fd, config.enumerate_vocab);
-  search_.unigram.LoadedBinary();
-  for (typename std::vector<Middle>::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) {
-    i->LoadedBinary();
-  }
-  search_.longest.LoadedBinary();
+  search_.LoadedBinary();
 }
 
 template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
@@ -116,8 +112,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
   }
   float backoff;
   // i is the order of the backoff we're looking for.
-  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
-    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
+  const Middle *mid_iter = search_.MiddleBegin() + start - 2;
+  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) {
+    if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break;
     ret.prob += backoff;
   }
   return ret;
@@ -135,7 +132,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
   search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
   out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0;
   float *backoff_out = out_state.backoff_ + 1;
-  const typename Search::Middle *mid = &*search_.middle.begin();
+  const typename Search::Middle *mid = search_.MiddleBegin();
   for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
     if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
       std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
@@ -183,7 +180,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
 
   // Ok now we now that the bigram contains known words.  Start by looking it up.
   const WordIndex *hist_iter = context_rbegin;
-  typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
+  const typename Search::Middle *mid_iter = search_.MiddleBegin();
   for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
     if (hist_iter == context_rend) {
       // Ran out of history.  Typically no backoff, but this could be a blank.
@@ -192,7 +189,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
       return ret;
     }
 
-    if (mid_iter == search_.middle.end()) break;
+    if (mid_iter == search_.MiddleEnd()) break;
 
     float revert = ret.prob;
     if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
@@ -227,9 +224,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
   return ret;
 }
 
-template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;
-template class GenericModel<SortedHashedSearch, SortedVocabulary>;
-template class GenericModel<trie::TrieSearch, SortedVocabulary>;
+template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING
+template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>; // TRIE_SORTED
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>; // TRIE_SORTED_QUANT
 
 } // namespace detail
 } // namespace ngram
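The recurring change across these hunks is interface-level: direct access to search_.middle (a std::vector<Middle>) is replaced by the pointer pair MiddleBegin()/MiddleEnd(), and the per-member LoadedBinary() calls in InitializeFromBinary collapse into a single search_.LoadedBinary(). Below is a minimal sketch of that accessor pattern, not the real KenLM types: Middle, MiddleBegin, MiddleEnd, and LoadedBinary are names taken from the diff, while the Search class body, its constructor, and main() are hypothetical illustration.

#include <cstddef>
#include <iostream>
#include <vector>

struct Middle {
  void LoadedBinary() {}  // no-op in this sketch; real tables do post-load fix-up here
};

class Search {
  public:
    explicit Search(std::size_t middle_count) : middle_(middle_count) {}

    // Expose the middle tables as a contiguous [begin, end) range so callers
    // never depend on how they are stored (a fixed array works just as well).
    const Middle *MiddleBegin() const { return middle_.empty() ? 0 : &*middle_.begin(); }
    const Middle *MiddleEnd() const { return middle_.empty() ? 0 : &*middle_.begin() + middle_.size(); }

    // Single hook that walks every table, standing in for the per-member
    // LoadedBinary() calls the diff removes from InitializeFromBinary.
    void LoadedBinary() {
      for (std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i)
        i->LoadedBinary();
    }

  private:
    std::vector<Middle> middle_;
};

int main() {
  Search search(3);  // e.g. a 5-gram model keeps middle tables for orders 2, 3, 4
  search.LoadedBinary();
  // Same count the patched P::Init computes: number of middle tables + 2.
  std::cout << "P::Init count = " << (search.MiddleEnd() - search.MiddleBegin()) + 2 << '\n';
  return 0;
}

With a range like this, the caller no longer needs to know whether the tables live in a std::vector or a plain array: the patched P::Init call computes the same count the old code obtained from search_.middle.size() + 2, and the lookup loops advance a plain const Middle * alongside the history pointer.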
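A small arithmetic note on the second hunk, assuming the usual layout in which the first middle table holds bigrams: the table for an order-o backoff then sits at offset o - 2, so starting the new iterator at search_.MiddleBegin() + start - 2 while the word pointer starts at context_rbegin + start - 1 reproduces the old index i - context_rbegin - 1 exactly. At the first iteration both expressions evaluate to start - 2, and the two pointers are incremented together thereafter.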