Diffstat (limited to 'klm/lm')
-rw-r--r--  klm/lm/binary_format.cc | 21
-rw-r--r--  klm/lm/config.cc        |  1
-rw-r--r--  klm/lm/config.hh        | 59
-rw-r--r--  klm/lm/left.hh          | 66
-rw-r--r--  klm/lm/max_order.hh     |  5
-rw-r--r--  klm/lm/model.cc         | 33
-rw-r--r--  klm/lm/search_hashed.cc |  8
-rw-r--r--  klm/lm/search_hashed.hh |  2
-rw-r--r--  klm/lm/search_trie.cc   | 47
-rw-r--r--  klm/lm/vocab.cc         |  7
-rw-r--r--  klm/lm/vocab.hh         |  5
11 files changed, 134 insertions(+), 120 deletions(-)
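
The main API-visible change in this patch is the new Config::show_progress flag and the ProgressMessages() accessor (see the config.hh and config.cc hunks below): when show_progress is false, loaders receive a NULL stream for the progress bar while warnings still go to config.messages. A minimal sketch of suppressing the bar during ARPA loading; ProbingModel and the file name are illustrative placeholders, not part of the patch:

    #include <iostream>
    #include "lm/model.hh"

    int main() {
      lm::ngram::Config config;
      config.show_progress = false;   // new flag: no progress bar
      // config.messages still defaults to &std::cerr, so warnings and ARPA
      // complaints are printed; ProgressMessages() now returns NULL.
      lm::ngram::ProbingModel model("example.arpa", config);  // placeholder path
      std::cout << "Loaded order-" << static_cast<unsigned>(model.Order())
                << " model" << std::endl;
      return 0;
    }
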
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index efa67056..39c4a9b6 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -16,11 +16,11 @@ namespace ngram {  namespace {  const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";  const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; -// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).  +// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).  const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";  const long int kMagicVersion = 5; -// Old binary files built on 32-bit machines have this header.   +// Old binary files built on 32-bit machines have this header.  // TODO: eliminate with next binary release.  struct OldSanity {    char magic[sizeof(kMagicBytes)]; @@ -39,7 +39,7 @@ struct OldSanity {  }; -// Test values aligned to 8 bytes.     +// Test values aligned to 8 bytes.  struct Sanity {    char magic[ALIGN8(sizeof(kMagicBytes))];    float zero_f, one_f, minus_half_f; @@ -101,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_  uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {    std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;    if (config.write_mmap) { -    // Grow the file to accomodate the search, using zeros.   +    // Grow the file to accomodate the search, using zeros.      try {        util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);      } catch (util::ErrnoException &e) { @@ -114,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t        return reinterpret_cast<uint8_t*>(backing.search.get());      }      // mmap it now. -    // We're skipping over the header and vocab for the search space mmap.  mmap likes page aligned offsets, so some arithmetic to round the offset down.   +    // We're skipping over the header and vocab for the search space mmap.  mmap likes page aligned offsets, so some arithmetic to round the offset down.      std::size_t page_size = util::SizePage();      std::size_t alignment_cruft = adjusted_vocab % page_size;      backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); @@ -122,7 +122,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t    } else {      util::MapAnonymous(memory_size, backing.search);      return reinterpret_cast<uint8_t*>(backing.search.get()); -  }  +  }  }  void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { @@ -140,7 +140,7 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_        util::FSyncOrThrow(backing.file.get());        break;    } -  // header and vocab share the same mmap.  The header is written here because we know the counts.   +  // header and vocab share the same mmap.  The header is written here because we know the counts.    
Parameters params = Parameters();    params.counts = counts;    params.fixed.order = counts.size(); @@ -160,7 +160,7 @@ namespace detail {  bool IsBinaryFormat(int fd) {    const uint64_t size = util::SizeFile(fd);    if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; -  // Try reading the header.   +  // Try reading the header.    util::scoped_memory memory;    try {      util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory); @@ -214,7 +214,7 @@ void SeekPastHeader(int fd, const Parameters ¶ms) {  uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) {    const uint64_t file_size = util::SizeFile(backing.file.get()); -  // The header is smaller than a page, so we have to map the whole header as well.   +  // The header is smaller than a page, so we have to map the whole header as well.    std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);    if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)      UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); @@ -233,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) {    if (config.write_mmap || !config.messages) return;    if (config.arpa_complain == Config::ALL) {      *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; -  } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { +  } else if (config.arpa_complain == Config::EXPENSIVE && +             (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {      *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive.  Save time by building a binary format." << std::endl;    }  } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index f9d988ca..9520c41c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -6,6 +6,7 @@ namespace lm {  namespace ngram {  Config::Config() : +  show_progress(true),    messages(&std::cerr),    enumerate_vocab(NULL),    unknown_missing(COMPLAIN), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 739cee9c..0de7b7c6 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -11,46 +11,52 @@  /* Configuration for ngram model.  Separate header to reduce pollution. */  namespace lm { -   +  class EnumerateVocab;  namespace ngram {  struct Config { -  // EFFECTIVE FOR BOTH ARPA AND BINARY READS  +  // EFFECTIVE FOR BOTH ARPA AND BINARY READS + +  // (default true) print progress bar to messages +  bool show_progress;    // Where to log messages including the progress bar.  Set to NULL for    // silence.    std::ostream *messages; +  std::ostream *ProgressMessages() const { +    return show_progress ? messages : 0; +  } +    // This will be called with every string in the vocabulary.  See    // enumerate_vocab.hh for more detail.  Config does not take ownership; you -  // are still responsible for deleting it (or stack allocating).   +  // are still responsible for deleting it (or stack allocating).    EnumerateVocab *enumerate_vocab; -    // ONLY EFFECTIVE WHEN READING ARPA -  // What to do when <unk> isn't in the provided model.  +  // What to do when <unk> isn't in the provided model.    WarningAction unknown_missing; -  // What to do when <s> or </s> is missing from the model.  
-  // If THROW_UP, the exception will be of type util::SpecialWordMissingException.   +  // What to do when <s> or </s> is missing from the model. +  // If THROW_UP, the exception will be of type util::SpecialWordMissingException.    WarningAction sentence_marker_missing;    // What to do with a positive log probability.  For COMPLAIN and SILENT, map -  // to 0.   +  // to 0.    WarningAction positive_log_probability; -  // The probability to substitute for <unk> if it's missing from the model.   +  // The probability to substitute for <unk> if it's missing from the model.    // No effect if the model has <unk> or unknown_missing == THROW_UP.    float unknown_missing_logprob;    // Size multiplier for probing hash table.  Must be > 1.  Space is linear in    // this.  Time is probing_multiplier / (probing_multiplier - 1).  No effect -  // for sorted variant.   +  // for sorted variant.    // If you find yourself setting this to a low number, consider using the -  // TrieModel which has lower memory consumption.   +  // TrieModel which has lower memory consumption.    float probing_multiplier;    // Amount of memory to use for building.  The actual memory usage will be @@ -58,10 +64,10 @@ struct Config {    // models.    std::size_t building_memory; -  // Template for temporary directory appropriate for passing to mkdtemp.   +  // Template for temporary directory appropriate for passing to mkdtemp.    // The characters XXXXXX are appended before passing to mkdtemp.  Only    // applies to trie.  If NULL, defaults to write_mmap.  If that's NULL, -  // defaults to input file name.   +  // defaults to input file name.    const char *temporary_directory_prefix;    // Level of complaining to do when loading from ARPA instead of binary format. @@ -69,49 +75,46 @@ struct Config {    ARPALoadComplain arpa_complain;    // While loading an ARPA file, also write out this binary format file.  Set -  // to NULL to disable.   +  // to NULL to disable.    const char *write_mmap;    enum WriteMethod { -    WRITE_MMAP, // Map the file directly.   -    WRITE_AFTER // Write after we're done.   +    WRITE_MMAP, // Map the file directly. +    WRITE_AFTER // Write after we're done.    };    WriteMethod write_method; -  // Include the vocab in the binary file?  Only effective if write_mmap != NULL.   +  // Include the vocab in the binary file?  Only effective if write_mmap != NULL.    bool include_vocab; -  // Left rest options.  Only used when the model includes rest costs.   +  // Left rest options.  Only used when the model includes rest costs.    enum RestFunction {      REST_MAX,   // Maximum of any score to the left -    REST_LOWER, // Use lower-order files given below.   +    REST_LOWER, // Use lower-order files given below.    };    RestFunction rest_function; -  // Only used for REST_LOWER.   +  // Only used for REST_LOWER.    std::vector<std::string> rest_lower_files; -    // Quantization options.  Only effective for QuantTrieModel.  One value is    // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used -  // to quantize (and one of the remaining backoffs will be 0).   +  // to quantize (and one of the remaining backoffs will be 0).    uint8_t prob_bits, backoff_bits;    // Bhiksha compression (simple form).  Only works with trie.    uint8_t pointer_bhiksha_bits; -   -   +    // ONLY EFFECTIVE WHEN READING BINARY -   +    // How to get the giant array into memory: lazy mmap, populate, read etc. -  // See util/mmap.hh for details of MapMethod.   
+  // See util/mmap.hh for details of MapMethod.    util::LoadMethod load_method; - -  // Set defaults.  +  // Set defaults.    Config();  }; diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 8c27232e..85c1ea37 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -51,36 +51,36 @@ namespace ngram {  template <class M> class RuleScore {    public: -    explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { +    explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {        out.left.length = 0;        out.right.length = 0;      }      void BeginSentence() { -      out_.right = model_.BeginSentenceState(); -      // out_.left is empty. +      out_->right = model_.BeginSentenceState(); +      // out_->left is empty.        left_done_ = true;      }      void Terminal(WordIndex word) { -      State copy(out_.right); -      FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); +      State copy(out_->right); +      FullScoreReturn ret(model_.FullScore(copy, word, out_->right));        if (left_done_) { prob_ += ret.prob; return; }        if (ret.independent_left) {          prob_ += ret.prob;          left_done_ = true;          return;        } -      out_.left.pointers[out_.left.length++] = ret.extend_left; +      out_->left.pointers[out_->left.length++] = ret.extend_left;        prob_ += ret.rest; -      if (out_.right.length != copy.length + 1) +      if (out_->right.length != copy.length + 1)          left_done_ = true;      }      // Faster version of NonTerminal for the case where the rule begins with a non-terminal.        void BeginNonTerminal(const ChartState &in, float prob = 0.0) {        prob_ = prob; -      out_ = in; +      *out_ = in;        left_done_ = in.left.full;      } @@ -89,23 +89,23 @@ template <class M> class RuleScore {        if (!in.left.length) {          if (in.left.full) { -          for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; +          for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;            left_done_ = true; -          out_.right = in.right; +          out_->right = in.right;          }          return;        } -      if (!out_.right.length) { -        out_.right = in.right; +      if (!out_->right.length) { +        out_->right = in.right;          if (left_done_) {            prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);            return;          } -        if (out_.left.length) { +        if (out_->left.length) {            left_done_ = true;          } else { -          out_.left = in.left; +          out_->left = in.left;            left_done_ = in.left.full;          }          return; @@ -113,10 +113,10 @@ template <class M> class RuleScore {        float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];        float *back = backoffs, *back2 = backoffs2; -      unsigned char next_use = out_.right.length; +      unsigned char next_use = out_->right.length;        // First word -      if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; +      if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;        // Words after the first, so extending a bigram to begin with        for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { @@ -127,54 +127,58 @@ template <class M> class RuleScore {        if 
(in.left.full) {          for (const float *i = back; i != back + next_use; ++i) prob_ += *i;          left_done_ = true; -        out_.right = in.right; +        out_->right = in.right;          return;        }        // Right state was minimized, so it's already independent of the new words to the left.          if (in.right.length < in.left.length) { -        out_.right = in.right; +        out_->right = in.right;          return;        }        // Shift exisiting words down.   -      for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { +      for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {          *(i + in.right.length) = *i;        }        // Add words from in.right.   -      std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); +      std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);        // Assemble backoff composed on the existing state's backoff followed by the new state's backoff.   -      std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); -      std::copy(back, back + next_use, out_.right.backoff + in.right.length); -      out_.right.length = in.right.length + next_use; +      std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff); +      std::copy(back, back + next_use, out_->right.backoff + in.right.length); +      out_->right.length = in.right.length + next_use;      }      float Finish() {        // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.   -      out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1); +      out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);        return prob_;      }      void Reset() {        prob_ = 0.0;        left_done_ = false; -      out_.left.length = 0; -      out_.right.length = 0; +      out_->left.length = 0; +      out_->right.length = 0; +    } +    void Reset(ChartState &replacement) { +      out_ = &replacement; +      Reset();      }    private:      bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {        ProcessRet(model_.ExtendLeft( -            out_.right.words, out_.right.words + next_use, // Words to extend into +            out_->right.words, out_->right.words + next_use, // Words to extend into              back_in, // Backoffs to use              in.left.pointers[extend_length - 1], extend_length, // Words to be extended              back_out, // Backoffs for the next score              next_use)); // Length of n-gram to use in next scoring.   -      if (next_use != out_.right.length) { +      if (next_use != out_->right.length) {          left_done_ = true;          if (!next_use) {            // Early exit.   
-          out_.right = in.right; +          out_->right = in.right;            prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);            return true;          } @@ -193,13 +197,13 @@ template <class M> class RuleScore {          left_done_ = true;          return;        } -      out_.left.pointers[out_.left.length++] = ret.extend_left; +      out_->left.pointers[out_->left.length++] = ret.extend_left;        prob_ += ret.rest;      }      const M &model_; -    ChartState &out_; +    ChartState *out_;      bool left_done_; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 989f8324..3eb97ccd 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -4,9 +4,6 @@   * (kMaxOrder - 1) * sizeof(float) bytes instead of   * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead   */ -#ifndef KENLM_MAX_ORDER -#define KENLM_MAX_ORDER 6 -#endif  #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile.  In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'.  Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile.  In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'.  Otherwise, edit lm/max_order.hh."  #endif diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 2fd20481..a40fd2fb 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -37,7 +37,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT  template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {    LoadLM(file, config, *this); -  // g++ prints warnings unless these are fully initialized.   +  // g++ prints warnings unless these are fully initialized.    State begin_sentence = State();    begin_sentence.length = 1;    begin_sentence.words[0] = vocab_.BeginSentence(); @@ -69,8 +69,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT  }  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { -  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any.   -  util::FilePiece f(backing_.file.release(), file, config.messages); +  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any. +  util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());    try {      std::vector<uint64_t> counts;      // File counts do not include pruned trigrams that extend to quadgrams etc.   These will be fixed by search_. @@ -80,14 +80,14 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT      if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");      std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); -    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.   +    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.      
vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);      if (config.write_mmap) {        WriteWordsWrapper wrap(config.enumerate_vocab);        vocab_.ConfigureEnumerate(&wrap, counts[0]);        search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); -      wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size()); +      wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));      } else {        vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);        search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -95,7 +95,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT      if (!vocab_.SawUnk()) {        assert(config.unknown_missing != THROW_UP); -      // Default probabilities for unknown.   +      // Default probabilities for unknown.        search_.UnknownUnigram().backoff = 0.0;        search_.UnknownUnigram().prob = config.unknown_missing_logprob;      } @@ -147,7 +147,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,  }  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { -  // Generate a state from context.   +  // Generate a state from context.    context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);    if (context_rend == context_rbegin) {      out_state.length = 0; @@ -191,7 +191,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,      ret.rest = ptr.Rest();      ret.prob = ptr.Prob();      ret.extend_left = extend_pointer; -    // If this function is called, then it does depend on left words.    +    // If this function is called, then it does depend on left words.      ret.independent_left = false;    }    float subtract_me = ret.rest; @@ -199,7 +199,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,    next_use = extend_length;    ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);    next_use -= extend_length; -  // Charge backoffs.   +  // Charge backoffs.    for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;    ret.prob -= subtract_me;    ret.rest -= subtract_me; @@ -209,7 +209,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,  namespace {  // Do a paraonoid copy of history, assuming new_word has already been copied  // (hence the -1).  out_state.length could be zero so I avoided using -// std::copy.    +// std::copy.  void CopyRemainingHistory(const WordIndex *from, State &out_state) {    WordIndex *out = out_state.words + 1;    const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1; @@ -217,18 +217,19 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) {  }  } // namespace -/* Ugly optimized function.  Produce a score excluding backoff.   - * The search goes in increasing order of ngram length.   +/* Ugly optimized function.  Produce a score excluding backoff. + * The search goes in increasing order of ngram length.   * Context goes backward, so context_begin is the word immediately preceeding - * new_word.   + * new_word.   
*/  template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(      const WordIndex *const context_rbegin,      const WordIndex *const context_rend,      const WordIndex new_word,      State &out_state) const { +  assert(new_word < vocab_.Bound());    FullScoreReturn ret; -  // ret.ngram_length contains the last known non-blank ngram length.   +  // ret.ngram_length contains the last known non-blank ngram length.    ret.ngram_length = 1;    typename Search::Node node; @@ -237,9 +238,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,    ret.prob = uni.Prob();    ret.rest = uni.Rest(); -  // This is the length of the context that should be used for continuation to the right.   +  // This is the length of the context that should be used for continuation to the right.    out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; -  // We'll write the word anyway since it will probably be used and does no harm being there.   +  // We'll write the word anyway since it will probably be used and does no harm being there.    out_state.words[0] = new_word;    if (context_rbegin == context_rend) return ret; diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index a1623834..2d6f15b2 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -231,7 +231,7 @@ template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char *  template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {    NoRestBuild build; -  ApplyBuild(f, counts, config, vocab, warn, build); +  ApplyBuild(f, counts, vocab, warn, build);  }  template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { @@ -239,19 +239,19 @@ template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, cons      case Config::REST_MAX:        {          MaxRestBuild build; -        ApplyBuild(f, counts, config, vocab, warn, build); +        ApplyBuild(f, counts, vocab, warn, build);        }        break;      case Config::REST_LOWER:        {          LowerRestBuild<ProbingModel> build(config, counts.size(), vocab); -        ApplyBuild(f, counts, config, vocab, warn, build); +        ApplyBuild(f, counts, vocab, warn, build);        }        break;    }  } -template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { +template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {    for (WordIndex i = 0; i < counts[0]; ++i) {      build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]);    } diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index a52f107b..00595796 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -147,7 +147,7 @@ template <class Value> class HashedSearch {      // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.        
void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); -    template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); +    template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);      class Unigram {        public: diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index debcfd07..1b0d9b26 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -55,7 +55,7 @@ struct ProbPointer {    uint64_t index;  }; -// Array of n-grams and float indices.   +// Array of n-grams and float indices.  class BackoffMessages {    public:      void Init(std::size_t entry_size) { @@ -100,7 +100,7 @@ class BackoffMessages {      void Apply(float *const *const base, RecordReader &reader) {        FinishedAdding();        if (current_ == allocated_) return; -      // We'll also use the same buffer to record messages to blanks that they extend.   +      // We'll also use the same buffer to record messages to blanks that they extend.        WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);        const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);        for (reader.Rewind(); reader && (current_ != allocated_); ) { @@ -109,7 +109,7 @@ class BackoffMessages {              ++reader;              break;            case 1: -            // Message but nobody to receive it.  Write it down at the beginning of the buffer so we can inform this blank that it extends.   +            // Message but nobody to receive it.  Write it down at the beginning of the buffer so we can inform this blank that it extends.              for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;              current_ += entry_size_;              break; @@ -126,7 +126,7 @@ class BackoffMessages {              break;          }        } -      // Now this is a list of blanks that extend right.   +      // Now this is a list of blanks that extend right.        entry_size_ = sizeof(WordIndex) * order;        Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));        current_ = (uint8_t*)backing_.get(); @@ -153,7 +153,7 @@ class BackoffMessages {    private:      void FinishedAdding() {        Resize(current_ - (uint8_t*)backing_.get()); -      // Sort requests in same order as files.   +      // Sort requests in same order as files.        std::sort(            util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)),            util::SizedIterator(util::SizedProxy(current_, entry_size_)), @@ -220,7 +220,7 @@ class SRISucks {      }    private: -    // This used to be one array.  Then I needed to separate it by order for quantization to work.   +    // This used to be one array.  Then I needed to separate it by order for quantization to work.      std::vector<float> values_[KENLM_MAX_ORDER - 1];      BackoffMessages messages_[KENLM_MAX_ORDER - 1]; @@ -253,7 +253,7 @@ class FindBlanks {        ++counts_.back();      } -    // Unigrams wrote one past.   +    // Unigrams wrote one past.      
void Cleanup() {        --counts_[0];      } @@ -270,15 +270,15 @@ class FindBlanks {      SRISucks &sri_;  }; -// Phase to actually write n-grams to the trie.   +// Phase to actually write n-grams to the trie.  template <class Quant, class Bhiksha> class WriteEntries {    public: -    WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :  +    WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :        contexts_(contexts),        quant_(quant),        unigrams_(unigrams),        middle_(middle), -      longest_(longest),  +      longest_(longest),        bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),        order_(order),        sri_(sri) {} @@ -328,7 +328,7 @@ struct Gram {    const WordIndex *begin, *end; -  // For queue, this is the direction we want.   +  // For queue, this is the direction we want.    bool operator<(const Gram &other) const {      return std::lexicographical_compare(other.begin, other.end, begin, end);    } @@ -353,7 +353,7 @@ template <class Doing> class BlankManager {          been_length_ = length;          return;        } -      // There are blanks to insert starting with order blank.   +      // There are blanks to insert starting with order blank.        unsigned char blank = cur - to + 1;        UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");        const float *lower_basis; @@ -363,7 +363,7 @@ template <class Doing> class BlankManager {          assert(*lower_basis != kBadProb);          doing_.MiddleBlank(blank, to, based_on, *lower_basis);          *pre = *cur; -        // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.   +        // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.          basis_[blank - 1] = kBadProb;        }        *pre = *cur; @@ -377,7 +377,7 @@ template <class Doing> class BlankManager {      unsigned char been_length_;      float basis_[KENLM_MAX_ORDER]; -     +      Doing &doing_;  }; @@ -451,7 +451,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re  }  void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { -  // Fill unigram probabilities.   +  // Fill unigram probabilities.    
try {      rewind(file);      for (WordIndex i = 0; i < unigram_count; ++i) { @@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve      util::scoped_memory unigrams;      MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);      FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); -    RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); +    RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);      fixed_counts = finder.Counts();    }    unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); @@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve      inputs[i-2].Rewind();    }    if (Quant::kTrain) { -    util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); +    util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), +                                  config.ProgressMessages(), "Quantizing");      for (unsigned char i = 2; i < counts.size(); ++i) {        TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);      } @@ -519,13 +520,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve    for (unsigned char i = 2; i <= counts.size(); ++i) {      inputs[i-2].Rewind();    } -  // Fill entries except unigram probabilities.   +  // Fill entries except unigram probabilities.    {      WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); -    RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); +    RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);    } -  // Do not disable this error message or else too little state will be returned.  Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.    +  // Do not disable this error message or else too little state will be returned.  Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.    for (unsigned char order = 2; order <= counts.size(); ++order) {      const RecordReader &context = contexts[order - 2];      if (context) { @@ -541,13 +542,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve    }    /* Set ending offsets so the last entry will be sized properly */ -  // Last entry for unigrams was already set.   +  // Last entry for unigrams was already set.    if (out.middle_begin_ != out.middle_end_) {      for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {        i->FinishedLoading((i+1)->InsertIndex(), config);      }      (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); -  }   +  }  }  template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { @@ -595,7 +596,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ    } else {      temporary_prefix = file;    } -  // At least 1MB sorting memory.   +  // At least 1MB sorting memory.    
SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);    BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 11c27518..fd7f96dc 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {    }    *end_ = hashed;    if (enumerate_) { -    strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); +    void *copied = string_backing_.Allocate(str.size()); +    memcpy(copied, str.data(), str.size()); +    strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size());    }    ++end_;    // This is 1 + the offset where it was inserted to make room for unk.   @@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {  void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {    if (enumerate_) {      if (!strings_to_enumerate_.empty()) { -      util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); +      util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());        util::JointSort(begin_, end_, values);      }      for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { @@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {        enumerate_->Add(i + 1, strings_to_enumerate_[i]);      }      strings_to_enumerate_.clear(); +    string_backing_.FreeAll();    } else {      util::JointSort(begin_, end_, reorder_vocab + 1);    } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index de54eb06..3902f117 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,6 +4,7 @@  #include "lm/enumerate_vocab.hh"  #include "lm/lm_exception.hh"  #include "lm/virtual_interface.hh" +#include "util/pool.hh"  #include "util/probing_hash_table.hh"  #include "util/sorted_uniform.hh"  #include "util/string_piece.hh" @@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary {      EnumerateVocab *enumerate_;      // Actual strings.  Used only when loading from ARPA and enumerate_ != NULL  -    std::vector<std::string> strings_to_enumerate_; +    util::Pool string_backing_; + +    std::vector<StringPiece> strings_to_enumerate_;  };  #pragma pack(push)  | 
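
The left.hh hunks above switch RuleScore from holding a ChartState reference to a pointer and add a Reset(ChartState &) overload, so one scorer can be retargeted at a fresh output state instead of being reconstructed for every rule. A minimal sketch under those assumptions; ProbingModel is illustrative, and the WordIndex arguments are assumed to be already mapped through the model's vocabulary (the new assert in ScoreExceptBackoff checks this):

    #include "lm/left.hh"
    #include "lm/model.hh"

    // Score two one-terminal rules with a single RuleScore, retargeting the
    // output ChartState between them via the new Reset(ChartState &) overload.
    void ScoreTwo(const lm::ngram::ProbingModel &model,
                  lm::WordIndex a, lm::WordIndex b,
                  lm::ngram::ChartState &first, lm::ngram::ChartState &second) {
      lm::ngram::RuleScore<lm::ngram::ProbingModel> scorer(model, first);
      scorer.Terminal(a);
      float score_a = scorer.Finish();   // finalizes 'first'

      scorer.Reset(second);              // point the same scorer at 'second'
      scorer.Terminal(b);
      float score_b = scorer.Finish();   // finalizes 'second'
      (void)score_a; (void)score_b;
    }
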

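The vocab.cc/vocab.hh hunks replace SortedVocabulary's std::vector<std::string> of enumerated words with StringPiece views whose bytes live in a util::Pool (string_backing_), freed in one call once enumeration finishes, avoiding one heap allocation per vocabulary entry. A standalone sketch of the same pool-plus-views pattern, using only the Allocate/FreeAll calls the patch itself relies on; TokenStore is a hypothetical helper, not code from the patch:

    #include <cstring>
    #include <vector>
    #include "util/pool.hh"
    #include "util/string_piece.hh"

    // Copy incoming tokens into one growing pool and keep cheap StringPiece
    // views, mirroring what SortedVocabulary::Insert now does.
    class TokenStore {
     public:
      StringPiece Add(const StringPiece &str) {
        void *copied = backing_.Allocate(str.size());
        std::memcpy(copied, str.data(), str.size());
        views_.push_back(StringPiece(static_cast<const char*>(copied), str.size()));
        return views_.back();
      }
      void Clear() {
        views_.clear();
        backing_.FreeAll();   // release every copied string at once
      }
     private:
      util::Pool backing_;
      std::vector<StringPiece> views_;
    };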