diff options
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r-- | klm/lm/model.cc | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc index e4c1ec1d..478ebed1 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -46,7 +46,7 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) { SetupMemory(start, params.counts, config); - vocab_.LoadedBinary(fd, config.enumerate_vocab); + vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); search_.LoadedBinary(); } @@ -82,13 +82,18 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.unigram.Unknown().backoff = 0.0; search_.unigram.Unknown().prob = config.unknown_missing_logprob; } - FinishFile(config, kModelType, kVersion, counts, backing_); + FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_); } catch (util::Exception &e) { e << " Byte: " << f.Offset(); throw; } } +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { + util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); + Search::UpdateConfigFromBinary(fd, counts, config); +} + template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) { @@ -114,7 +119,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, } float backoff; // i is the order of the backoff we're looking for. 
- const Middle *mid_iter = search_.MiddleBegin() + start - 2; + typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2; for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) { if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break; ret.prob += backoff; @@ -134,7 +139,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored); out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; float *backoff_out = out_state.backoff + 1; - const typename Search::Middle *mid = search_.MiddleBegin(); + typename Search::MiddleIter mid(search_.MiddleBegin()); for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) { if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) { std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words); @@ -161,7 +166,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, // If this function is called, then it does depend on left words. ret.independent_left = false; ret.extend_left = extend_pointer; - const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1; + typename Search::MiddleIter mid_iter(search_.MiddleBegin() + extend_length - 1); const WordIndex *i = add_rbegin; for (; ; ++i, ++backoff_out, ++mid_iter) { if (i == add_rend) { @@ -230,7 +235,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, // Ok start by looking up the bigram. const WordIndex *hist_iter = context_rbegin; - const typename Search::Middle *mid_iter = search_.MiddleBegin(); + typename Search::MiddleIter mid_iter(search_.MiddleBegin()); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { if (hist_iter == context_rend) { // Ran out of history. Typically no backoff, but this could be a blank. |