Diffstat (limited to 'klm/lm/model.cc')
 klm/lm/model.cc | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index e4c1ec1d..478ebed1 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -46,7 +46,7 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
SetupMemory(start, params.counts, config);
- vocab_.LoadedBinary(fd, config.enumerate_vocab);
+ vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
search_.LoadedBinary();
}
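
The hunk above threads a has_vocabulary flag, read from the binary header's fixed-width parameters, into the vocabulary loader, so a binary built without embedded word strings no longer pretends to enumerate them. A minimal sketch of the dispatch this enables (the flag and the call site are from the diff; the class layout and the ReadWords helper below are assumptions, not KenLM's implementation):

    // Sketch only: plausible shape of the new LoadedBinary overload.
    class EnumerateVocab;  // callback interface, as in config.enumerate_vocab

    class Vocabulary {
     public:
      void LoadedBinary(bool has_vocabulary, int fd, EnumerateVocab *enumerate) {
        if (!has_vocabulary || !enumerate) return;  // no strings, or nobody listening
        ReadWords(fd, enumerate);
      }
     private:
      // Hypothetical helper: stream id -> word string pairs to the callback.
      void ReadWords(int fd, EnumerateVocab *enumerate) { /* omitted */ }
    };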
@@ -82,13 +82,18 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.unigram.Unknown().backoff = 0.0;
search_.unigram.Unknown().prob = config.unknown_missing_logprob;
}
- FinishFile(config, kModelType, kVersion, counts, backing_);
+ FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
} catch (util::Exception &e) {
e << " Byte: " << f.Offset();
throw;
}
}
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
+ util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
+ Search::UpdateConfigFromBinary(fd, counts, config);
+}
+
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) {
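
Two changes land in the hunk above: FinishFile now records the padding reported by vocab_.UnkCountChangePadding(), and the new UpdateConfigFromBinary seeks past the vocabulary section, sized from the unigram count counts[0], before letting the search structure read its own settings from the descriptor. A sketch of the seek-and-throw idiom that util::AdvanceOrThrow names (this body is an assumption about the helper, not copied from util):

    #include <unistd.h>
    #include <stdexcept>

    // Sketch: advance fd by off bytes from the current position, throwing
    // on failure.  Illustrative stand-in for util::AdvanceOrThrow.
    void AdvanceOrThrow(int fd, off_t off) {
      if (lseek(fd, off, SEEK_CUR) == (off_t)-1)
        throw std::runtime_error("lseek failed while skipping a binary section");
    }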
@@ -114,7 +119,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
float backoff;
// i is the order of the backoff we're looking for.
- const Middle *mid_iter = search_.MiddleBegin() + start - 2;
+ typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2;
for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) {
if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break;
ret.prob += backoff;
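
The loop above implements the usual backoff recursion: when the full n-gram is unseen, the model charges the backoff weights of the skipped longer contexts and falls back to a shorter n-gram. Since all values are log10 (as in ARPA files), the multiplications become the additions seen here. A toy map-based sketch of the same rule; the containers and lookups are illustrative, not KenLM's structures:

    #include <map>
    #include <vector>

    typedef unsigned int WordIndex;
    typedef std::map<std::vector<WordIndex>, float> Table;

    // Assumes unigrams (including <unk>) are always present, so the loop
    // terminates once the n-gram has been shortened to a single word.
    float Score(const Table &prob, const Table &backoff,
                std::vector<WordIndex> ngram /* oldest word first */) {
      float charged = 0.0;
      while (true) {
        Table::const_iterator found = prob.find(ngram);
        if (found != prob.end()) return charged + found->second;
        // Unseen: charge the backoff of its context, then drop the oldest word.
        std::vector<WordIndex> context(ngram.begin(), ngram.end() - 1);
        Table::const_iterator hit = backoff.find(context);
        if (hit != backoff.end()) charged += hit->second;  // absent = log10(1) = 0
        ngram.erase(ngram.begin());
      }
    }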
@@ -134,7 +139,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored);
out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
float *backoff_out = out_state.backoff + 1;
- const typename Search::Middle *mid = search_.MiddleBegin();
+ typename Search::MiddleIter mid(search_.MiddleBegin());
for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
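
This hunk and its neighbors all make the same substitution: the raw const Middle * becomes typename Search::MiddleIter, so each search policy exports its own iterator over the middle-order tables and GenericModel stops assuming they sit in one contiguous array. Illustrative stand-ins for the pattern (neither class is KenLM's code):

    #include <cstddef>

    struct Middle { float backoff_base; };  // stand-in for a middle-order table

    struct HashedSearch {
      // Middles live in one contiguous array, so a raw pointer suffices.
      typedef const Middle *MiddleIter;
      MiddleIter MiddleBegin() const { return middle_; }
      MiddleIter MiddleEnd() const { return middle_ + count_; }
      const Middle *middle_;
      std::size_t count_;
    };

    struct TrieSearch {
      // Per-order tables differ in layout, so the iterator is a small class
      // that knows its stride instead of a plain pointer.
      class MiddleIter {
       public:
        MiddleIter(const unsigned char *at, std::size_t stride)
            : at_(at), stride_(stride) {}
        MiddleIter &operator++() { at_ += stride_; return *this; }
       private:
        const unsigned char *at_;
        std::size_t stride_;
      };
      MiddleIter MiddleBegin() const { return MiddleIter(base_, stride_); }
      const unsigned char *base_;
      std::size_t stride_;
    };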
@@ -161,7 +166,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// If this function is called, then it does depend on left words.
ret.independent_left = false;
ret.extend_left = extend_pointer;
- const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1;
+ typename Search::MiddleIter mid_iter(search_.MiddleBegin() + extend_length - 1);
const WordIndex *i = add_rbegin;
for (; ; ++i, ++backoff_out, ++mid_iter) {
if (i == add_rend) {
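
For context on the fields set above: FullScoreReturn reports whether further left context could still change the score and, when it could, a handle for resuming the lookup. The field names match the diff; the layout below is a sketch, not the authoritative definition:

    #include <stdint.h>

    // Sketch of the bookkeeping behind ret.independent_left / ret.extend_left.
    struct FullScoreReturn {
      float prob;                  // log10 probability accumulated so far
      unsigned char ngram_length;  // order of the longest matching n-gram
      bool independent_left;       // true: more left words cannot change prob
      uint64_t extend_left;        // handle for extending the match leftward
    };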
@@ -230,7 +235,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// Ok start by looking up the bigram.
const WordIndex *hist_iter = context_rbegin;
- const typename Search::Middle *mid_iter = search_.MiddleBegin();
+ typename Search::MiddleIter mid_iter(search_.MiddleBegin());
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == context_rend) {
// Ran out of history. Typically no backoff, but this could be a blank.
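
Seen from outside the template, the FullScore being touched throughout this patch is driven by threading State through successive calls. A usage sketch against the public API (class and method names follow KenLM's headers; the model file name and word list are made up):

    #include "lm/model.hh"
    #include <iostream>

    int main() {
      lm::ngram::Model model("file.binary");  // hypothetical binary LM file
      lm::ngram::State state(model.BeginSentenceState()), out_state;
      const char *words[] = {"this", "is", "a", "sentence", "</s>"};
      float total = 0.0;
      for (unsigned i = 0; i < sizeof(words) / sizeof(*words); ++i) {
        const lm::WordIndex w = model.GetVocabulary().Index(words[i]);
        total += model.FullScore(state, w, out_state).prob;  // log10 probability
        state = out_state;
      }
      std::cout << "log10 p(sentence) = " << total << std::endl;
      return 0;
    }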