summaryrefslogtreecommitdiff
path: root/klm/lm/model.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r--klm/lm/model.cc25
1 files changed, 11 insertions, 14 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index f0579c0c..a1d10b3d 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -44,17 +44,13 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff;
State null_context = State();
null_context.valid_length_ = 0;
- P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2);
+ P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2);
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
SetupMemory(start, params.counts, config);
vocab_.LoadedBinary(fd, config.enumerate_vocab);
- search_.unigram.LoadedBinary();
- for (typename std::vector<Middle>::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) {
- i->LoadedBinary();
- }
- search_.longest.LoadedBinary();
+ search_.LoadedBinary();
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
@@ -116,8 +112,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
float backoff;
// i is the order of the backoff we're looking for.
- for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
- if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
+ const Middle *mid_iter = search_.MiddleBegin() + start - 2;
+ for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) {
+ if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break;
ret.prob += backoff;
}
return ret;
@@ -135,7 +132,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0;
float *backoff_out = out_state.backoff_ + 1;
- const typename Search::Middle *mid = &*search_.middle.begin();
+ const typename Search::Middle *mid = search_.MiddleBegin();
for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
@@ -183,7 +180,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// Ok now we now that the bigram contains known words. Start by looking it up.
const WordIndex *hist_iter = context_rbegin;
- typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
+ const typename Search::Middle *mid_iter = search_.MiddleBegin();
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == context_rend) {
// Ran out of history. Typically no backoff, but this could be a blank.
@@ -192,7 +189,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
return ret;
}
- if (mid_iter == search_.middle.end()) break;
+ if (mid_iter == search_.MiddleEnd()) break;
float revert = ret.prob;
if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
@@ -227,9 +224,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
return ret;
}
-template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;
-template class GenericModel<SortedHashedSearch, SortedVocabulary>;
-template class GenericModel<trie::TrieSearch, SortedVocabulary>;
+template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING
+template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>; // TRIE_SORTED
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>; // TRIE_SORTED_QUANT
} // namespace detail
} // namespace ngram