diff options
| author | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 | 
|---|---|---|
| committer | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 | 
| commit | 3d8d656fa7911524e0e6885647173474524e0784 (patch) | |
| tree | 81b1ee2fcb67980376d03f0aa48e42e53abff222 /klm/lm/model.cc | |
| parent | be7f57fdd484e063775d7abf083b9fa4c403b610 (diff) | |
| parent | 96fedabebafe7a38a6d5928be8fff767e411d705 (diff) | |
fixed conflicts
Diffstat (limited to 'klm/lm/model.cc')
| -rw-r--r-- | klm/lm/model.cc | 54 | 
1 files changed, 31 insertions, 23 deletions
| diff --git a/klm/lm/model.cc b/klm/lm/model.cc index b46333a4..a40fd2fb 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -12,6 +12,7 @@  #include <functional>  #include <numeric>  #include <cmath> +#include <limits>  namespace lm {  namespace ngram { @@ -19,23 +20,24 @@ namespace detail {  template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType; -template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { +template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {    return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);  }  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) { +  size_t goal_size = util::CheckOverflow(Size(counts, config));    uint8_t *start = static_cast<uint8_t*>(base);    size_t allocated = VocabularyT::Size(counts[0], config);    vocab_.SetupMemory(start, allocated, counts[0], config);    start += allocated;    start = search_.SetupMemory(start, counts, config); -  if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config)); +  if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);  }  template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {    LoadLM(file, config, *this); -  // g++ prints warnings unless these are fully initialized.   +  // g++ prints warnings unless these are fully initialized.    State begin_sentence = State();    begin_sentence.length = 1;    begin_sentence.words[0] = vocab_.BeginSentence(); @@ -49,38 +51,43 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge  }  namespace { -void CheckMaxOrder(size_t order) { -  UTIL_THROW_IF(order > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << order << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ".  " << KENLM_ORDER_MESSAGE); +void CheckCounts(const std::vector<uint64_t> &counts) { +  UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ".  " << KENLM_ORDER_MESSAGE); +  if (sizeof(uint64_t) > sizeof(std::size_t)) { +    for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) { +      UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines."); +    } +  }  }  } // namespace  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { -  CheckMaxOrder(params.counts.size()); +  CheckCounts(params.counts);    SetupMemory(start, params.counts, config);    vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);    search_.LoadedBinary();  }  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { -  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any.   -  util::FilePiece f(backing_.file.release(), file, config.messages); +  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any. +  util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());    try {      std::vector<uint64_t> counts;      // File counts do not include pruned trigrams that extend to quadgrams etc.   These will be fixed by search_.      ReadARPACounts(f, counts); -    CheckMaxOrder(counts.size()); +    CheckCounts(counts);      if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");      if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); -    std::size_t vocab_size = VocabularyT::Size(counts[0], config); -    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.   +    std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); +    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.      vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);      if (config.write_mmap) {        WriteWordsWrapper wrap(config.enumerate_vocab);        vocab_.ConfigureEnumerate(&wrap, counts[0]);        search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); -      wrap.Write(backing_.file.get()); +      wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));      } else {        vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);        search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -88,7 +95,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT      if (!vocab_.SawUnk()) {        assert(config.unknown_missing != THROW_UP); -      // Default probabilities for unknown.   +      // Default probabilities for unknown.        search_.UnknownUnigram().backoff = 0.0;        search_.UnknownUnigram().prob = config.unknown_missing_logprob;      } @@ -140,7 +147,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,  }  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { -  // Generate a state from context.   +  // Generate a state from context.    context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);    if (context_rend == context_rbegin) {      out_state.length = 0; @@ -184,7 +191,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,      ret.rest = ptr.Rest();      ret.prob = ptr.Prob();      ret.extend_left = extend_pointer; -    // If this function is called, then it does depend on left words.    +    // If this function is called, then it does depend on left words.      ret.independent_left = false;    }    float subtract_me = ret.rest; @@ -192,7 +199,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,    next_use = extend_length;    ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);    next_use -= extend_length; -  // Charge backoffs.   +  // Charge backoffs.    for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;    ret.prob -= subtract_me;    ret.rest -= subtract_me; @@ -202,7 +209,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,  namespace {  // Do a paraonoid copy of history, assuming new_word has already been copied  // (hence the -1).  out_state.length could be zero so I avoided using -// std::copy.    +// std::copy.  void CopyRemainingHistory(const WordIndex *from, State &out_state) {    WordIndex *out = out_state.words + 1;    const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1; @@ -210,18 +217,19 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) {  }  } // namespace -/* Ugly optimized function.  Produce a score excluding backoff.   - * The search goes in increasing order of ngram length.   +/* Ugly optimized function.  Produce a score excluding backoff. + * The search goes in increasing order of ngram length.   * Context goes backward, so context_begin is the word immediately preceeding - * new_word.   + * new_word.   */  template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(      const WordIndex *const context_rbegin,      const WordIndex *const context_rend,      const WordIndex new_word,      State &out_state) const { +  assert(new_word < vocab_.Bound());    FullScoreReturn ret; -  // ret.ngram_length contains the last known non-blank ngram length.   +  // ret.ngram_length contains the last known non-blank ngram length.    ret.ngram_length = 1;    typename Search::Node node; @@ -230,9 +238,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,    ret.prob = uni.Prob();    ret.rest = uni.Rest(); -  // This is the length of the context that should be used for continuation to the right.   +  // This is the length of the context that should be used for continuation to the right.    out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; -  // We'll write the word anyway since it will probably be used and does no harm being there.   +  // We'll write the word anyway since it will probably be used and does no harm being there.    out_state.words[0] = new_word;    if (context_rbegin == context_rend) return ret; | 
