diff options
author | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-22 16:01:23 -0500 |
---|---|---|
committer | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-22 16:01:23 -0500 |
commit | 597d89c11db53e91bc011eab70fd613bbe6453e8 (patch) | |
tree | 83c87c07d1ff6d3ee4e3b1626f7eddd49c61095b /klm/lm | |
parent | 65e958ff2678a41c22be7171456a63f002ef370b (diff) | |
parent | 201af2acd394415a05072fbd53d42584875aa4b4 (diff) |
Merge branch 'master' of git://github.com/redpony/cdec
Diffstat (limited to 'klm/lm')
-rw-r--r-- | klm/lm/binary_format.cc | 21 | ||||
-rw-r--r-- | klm/lm/config.cc | 1 | ||||
-rw-r--r-- | klm/lm/config.hh | 59 | ||||
-rw-r--r-- | klm/lm/left.hh | 66 | ||||
-rw-r--r-- | klm/lm/max_order.hh | 5 | ||||
-rw-r--r-- | klm/lm/model.cc | 33 | ||||
-rw-r--r-- | klm/lm/search_hashed.cc | 8 | ||||
-rw-r--r-- | klm/lm/search_hashed.hh | 2 | ||||
-rw-r--r-- | klm/lm/search_trie.cc | 47 | ||||
-rw-r--r-- | klm/lm/vocab.cc | 7 | ||||
-rw-r--r-- | klm/lm/vocab.hh | 5 |
11 files changed, 134 insertions, 120 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index efa67056..39c4a9b6 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -16,11 +16,11 @@ namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; -// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). +// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; const long int kMagicVersion = 5; -// Old binary files built on 32-bit machines have this header. +// Old binary files built on 32-bit machines have this header. // TODO: eliminate with next binary release. struct OldSanity { char magic[sizeof(kMagicBytes)]; @@ -39,7 +39,7 @@ struct OldSanity { }; -// Test values aligned to 8 bytes. +// Test values aligned to 8 bytes. struct Sanity { char magic[ALIGN8(sizeof(kMagicBytes))]; float zero_f, one_f, minus_half_f; @@ -101,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { - // Grow the file to accomodate the search, using zeros. + // Grow the file to accomodate the search, using zeros. try { util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); } catch (util::ErrnoException &e) { @@ -114,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t return reinterpret_cast<uint8_t*>(backing.search.get()); } // mmap it now. - // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. + // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. std::size_t page_size = util::SizePage(); std::size_t alignment_cruft = adjusted_vocab % page_size; backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); @@ -122,7 +122,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t } else { util::MapAnonymous(memory_size, backing.search); return reinterpret_cast<uint8_t*>(backing.search.get()); - } + } } void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { @@ -140,7 +140,7 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_ util::FSyncOrThrow(backing.file.get()); break; } - // header and vocab share the same mmap. The header is written here because we know the counts. + // header and vocab share the same mmap. The header is written here because we know the counts. Parameters params = Parameters(); params.counts = counts; params.fixed.order = counts.size(); @@ -160,7 +160,7 @@ namespace detail { bool IsBinaryFormat(int fd) { const uint64_t size = util::SizeFile(fd); if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; - // Try reading the header. + // Try reading the header. util::scoped_memory memory; try { util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory); @@ -214,7 +214,7 @@ void SeekPastHeader(int fd, const Parameters ¶ms) { uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { const uint64_t file_size = util::SizeFile(backing.file.get()); - // The header is smaller than a page, so we have to map the whole header as well. + // The header is smaller than a page, so we have to map the whole header as well. std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); @@ -233,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) { if (config.write_mmap || !config.messages) return; if (config.arpa_complain == Config::ALL) { *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; - } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { + } else if (config.arpa_complain == Config::EXPENSIVE && + (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; } } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index f9d988ca..9520c41c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -6,6 +6,7 @@ namespace lm { namespace ngram { Config::Config() : + show_progress(true), messages(&std::cerr), enumerate_vocab(NULL), unknown_missing(COMPLAIN), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 739cee9c..0de7b7c6 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -11,46 +11,52 @@ /* Configuration for ngram model. Separate header to reduce pollution. */ namespace lm { - + class EnumerateVocab; namespace ngram { struct Config { - // EFFECTIVE FOR BOTH ARPA AND BINARY READS + // EFFECTIVE FOR BOTH ARPA AND BINARY READS + + // (default true) print progress bar to messages + bool show_progress; // Where to log messages including the progress bar. Set to NULL for // silence. std::ostream *messages; + std::ostream *ProgressMessages() const { + return show_progress ? messages : 0; + } + // This will be called with every string in the vocabulary. See // enumerate_vocab.hh for more detail. Config does not take ownership; you - // are still responsible for deleting it (or stack allocating). + // are still responsible for deleting it (or stack allocating). EnumerateVocab *enumerate_vocab; - // ONLY EFFECTIVE WHEN READING ARPA - // What to do when <unk> isn't in the provided model. + // What to do when <unk> isn't in the provided model. WarningAction unknown_missing; - // What to do when <s> or </s> is missing from the model. - // If THROW_UP, the exception will be of type util::SpecialWordMissingException. + // What to do when <s> or </s> is missing from the model. + // If THROW_UP, the exception will be of type util::SpecialWordMissingException. WarningAction sentence_marker_missing; // What to do with a positive log probability. For COMPLAIN and SILENT, map - // to 0. + // to 0. WarningAction positive_log_probability; - // The probability to substitute for <unk> if it's missing from the model. + // The probability to substitute for <unk> if it's missing from the model. // No effect if the model has <unk> or unknown_missing == THROW_UP. float unknown_missing_logprob; // Size multiplier for probing hash table. Must be > 1. Space is linear in // this. Time is probing_multiplier / (probing_multiplier - 1). No effect - // for sorted variant. + // for sorted variant. // If you find yourself setting this to a low number, consider using the - // TrieModel which has lower memory consumption. + // TrieModel which has lower memory consumption. float probing_multiplier; // Amount of memory to use for building. The actual memory usage will be @@ -58,10 +64,10 @@ struct Config { // models. std::size_t building_memory; - // Template for temporary directory appropriate for passing to mkdtemp. + // Template for temporary directory appropriate for passing to mkdtemp. // The characters XXXXXX are appended before passing to mkdtemp. Only // applies to trie. If NULL, defaults to write_mmap. If that's NULL, - // defaults to input file name. + // defaults to input file name. const char *temporary_directory_prefix; // Level of complaining to do when loading from ARPA instead of binary format. @@ -69,49 +75,46 @@ struct Config { ARPALoadComplain arpa_complain; // While loading an ARPA file, also write out this binary format file. Set - // to NULL to disable. + // to NULL to disable. const char *write_mmap; enum WriteMethod { - WRITE_MMAP, // Map the file directly. - WRITE_AFTER // Write after we're done. + WRITE_MMAP, // Map the file directly. + WRITE_AFTER // Write after we're done. }; WriteMethod write_method; - // Include the vocab in the binary file? Only effective if write_mmap != NULL. + // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; - // Left rest options. Only used when the model includes rest costs. + // Left rest options. Only used when the model includes rest costs. enum RestFunction { REST_MAX, // Maximum of any score to the left - REST_LOWER, // Use lower-order files given below. + REST_LOWER, // Use lower-order files given below. }; RestFunction rest_function; - // Only used for REST_LOWER. + // Only used for REST_LOWER. std::vector<std::string> rest_lower_files; - // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used - // to quantize (and one of the remaining backoffs will be 0). + // to quantize (and one of the remaining backoffs will be 0). uint8_t prob_bits, backoff_bits; // Bhiksha compression (simple form). Only works with trie. uint8_t pointer_bhiksha_bits; - - + // ONLY EFFECTIVE WHEN READING BINARY - + // How to get the giant array into memory: lazy mmap, populate, read etc. - // See util/mmap.hh for details of MapMethod. + // See util/mmap.hh for details of MapMethod. util::LoadMethod load_method; - - // Set defaults. + // Set defaults. Config(); }; diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 8c27232e..85c1ea37 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -51,36 +51,36 @@ namespace ngram { template <class M> class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } void BeginSentence() { - out_.right = model_.BeginSentenceState(); - // out_.left is empty. + out_->right = model_.BeginSentenceState(); + // out_->left is empty. left_done_ = true; } void Terminal(WordIndex word) { - State copy(out_.right); - FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + State copy(out_->right); + FullScoreReturn ret(model_.FullScore(copy, word, out_->right)); if (left_done_) { prob_ += ret.prob; return; } if (ret.independent_left) { prob_ += ret.prob; left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; - if (out_.right.length != copy.length + 1) + if (out_->right.length != copy.length + 1) left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob = 0.0) { prob_ = prob; - out_ = in; + *out_ = in; left_done_ = in.left.full; } @@ -89,23 +89,23 @@ template <class M> class RuleScore { if (!in.left.length) { if (in.left.full) { - for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; } return; } - if (!out_.right.length) { - out_.right = in.right; + if (!out_->right.length) { + out_->right = in.right; if (left_done_) { prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1); return; } - if (out_.left.length) { + if (out_->left.length) { left_done_ = true; } else { - out_.left = in.left; + out_->left = in.left; left_done_ = in.left.full; } return; @@ -113,10 +113,10 @@ template <class M> class RuleScore { float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1]; float *back = backoffs, *back2 = backoffs2; - unsigned char next_use = out_.right.length; + unsigned char next_use = out_->right.length; // First word - if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; + if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return; // Words after the first, so extending a bigram to begin with for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { @@ -127,54 +127,58 @@ template <class M> class RuleScore { if (in.left.full) { for (const float *i = back; i != back + next_use; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; return; } // Right state was minimized, so it's already independent of the new words to the left. if (in.right.length < in.left.length) { - out_.right = in.right; + out_->right = in.right; return; } // Shift exisiting words down. - for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) { *(i + in.right.length) = *i; } // Add words from in.right. - std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + std::copy(in.right.words, in.right.words + in.right.length, out_->right.words); // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. - std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); - std::copy(back, back + next_use, out_.right.backoff + in.right.length); - out_.right.length = in.right.length + next_use; + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff); + std::copy(back, back + next_use, out_->right.backoff + in.right.length); + out_->right.length = in.right.length + next_use; } float Finish() { // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. - out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1); + out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1); return prob_; } void Reset() { prob_ = 0.0; left_done_ = false; - out_.left.length = 0; - out_.right.length = 0; + out_->left.length = 0; + out_->right.length = 0; + } + void Reset(ChartState &replacement) { + out_ = &replacement; + Reset(); } private: bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { ProcessRet(model_.ExtendLeft( - out_.right.words, out_.right.words + next_use, // Words to extend into + out_->right.words, out_->right.words + next_use, // Words to extend into back_in, // Backoffs to use in.left.pointers[extend_length - 1], extend_length, // Words to be extended back_out, // Backoffs for the next score next_use)); // Length of n-gram to use in next scoring. - if (next_use != out_.right.length) { + if (next_use != out_->right.length) { left_done_ = true; if (!next_use) { // Early exit. - out_.right = in.right; + out_->right = in.right; prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1); return true; } @@ -193,13 +197,13 @@ template <class M> class RuleScore { left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; } const M &model_; - ChartState &out_; + ChartState *out_; bool left_done_; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 989f8324..3eb97ccd 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -4,9 +4,6 @@ * (kMaxOrder - 1) * sizeof(float) bytes instead of * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead */ -#ifndef KENLM_MAX_ORDER -#define KENLM_MAX_ORDER 6 -#endif #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 2fd20481..a40fd2fb 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -37,7 +37,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { LoadLM(file, config, *this); - // g++ prints warnings unless these are fully initialized. + // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); begin_sentence.length = 1; begin_sentence.words[0] = vocab_.BeginSentence(); @@ -69,8 +69,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { - // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. - util::FilePiece f(backing_.file.release(), file, config.messages); + // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. + util::FilePiece f(backing_.file.release(), file, config.ProgressMessages()); try { std::vector<uint64_t> counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. @@ -80,14 +80,14 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); if (config.write_mmap) { WriteWordsWrapper wrap(config.enumerate_vocab); vocab_.ConfigureEnumerate(&wrap, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size()); + wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config)); } else { vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -95,7 +95,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (!vocab_.SawUnk()) { assert(config.unknown_missing != THROW_UP); - // Default probabilities for unknown. + // Default probabilities for unknown. search_.UnknownUnigram().backoff = 0.0; search_.UnknownUnigram().prob = config.unknown_missing_logprob; } @@ -147,7 +147,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { - // Generate a state from context. + // Generate a state from context. context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); if (context_rend == context_rbegin) { out_state.length = 0; @@ -191,7 +191,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, ret.rest = ptr.Rest(); ret.prob = ptr.Prob(); ret.extend_left = extend_pointer; - // If this function is called, then it does depend on left words. + // If this function is called, then it does depend on left words. ret.independent_left = false; } float subtract_me = ret.rest; @@ -199,7 +199,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, next_use = extend_length; ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret); next_use -= extend_length; - // Charge backoffs. + // Charge backoffs. for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; ret.prob -= subtract_me; ret.rest -= subtract_me; @@ -209,7 +209,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, namespace { // Do a paraonoid copy of history, assuming new_word has already been copied // (hence the -1). out_state.length could be zero so I avoided using -// std::copy. +// std::copy. void CopyRemainingHistory(const WordIndex *from, State &out_state) { WordIndex *out = out_state.words + 1; const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1; @@ -217,18 +217,19 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) { } } // namespace -/* Ugly optimized function. Produce a score excluding backoff. - * The search goes in increasing order of ngram length. +/* Ugly optimized function. Produce a score excluding backoff. + * The search goes in increasing order of ngram length. * Context goes backward, so context_begin is the word immediately preceeding - * new_word. + * new_word. */ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff( const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const { + assert(new_word < vocab_.Bound()); FullScoreReturn ret; - // ret.ngram_length contains the last known non-blank ngram length. + // ret.ngram_length contains the last known non-blank ngram length. ret.ngram_length = 1; typename Search::Node node; @@ -237,9 +238,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, ret.prob = uni.Prob(); ret.rest = uni.Rest(); - // This is the length of the context that should be used for continuation to the right. + // This is the length of the context that should be used for continuation to the right. out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; - // We'll write the word anyway since it will probably be used and does no harm being there. + // We'll write the word anyway since it will probably be used and does no harm being there. out_state.words[0] = new_word; if (context_rbegin == context_rend) return ret; diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index a1623834..2d6f15b2 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -231,7 +231,7 @@ template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { NoRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { @@ -239,19 +239,19 @@ template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, cons case Config::REST_MAX: { MaxRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; case Config::REST_LOWER: { LowerRestBuild<ProbingModel> build(config, counts.size(), vocab); - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; } } -template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { +template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { for (WordIndex i = 0; i < counts[0]; ++i) { build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]); } diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index a52f107b..00595796 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -147,7 +147,7 @@ template <class Value> class HashedSearch { // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); - template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); + template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); class Unigram { public: diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index debcfd07..1b0d9b26 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -55,7 +55,7 @@ struct ProbPointer { uint64_t index; }; -// Array of n-grams and float indices. +// Array of n-grams and float indices. class BackoffMessages { public: void Init(std::size_t entry_size) { @@ -100,7 +100,7 @@ class BackoffMessages { void Apply(float *const *const base, RecordReader &reader) { FinishedAdding(); if (current_ == allocated_) return; - // We'll also use the same buffer to record messages to blanks that they extend. + // We'll also use the same buffer to record messages to blanks that they extend. WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_); const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex); for (reader.Rewind(); reader && (current_ != allocated_); ) { @@ -109,7 +109,7 @@ class BackoffMessages { ++reader; break; case 1: - // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. + // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w; current_ += entry_size_; break; @@ -126,7 +126,7 @@ class BackoffMessages { break; } } - // Now this is a list of blanks that extend right. + // Now this is a list of blanks that extend right. entry_size_ = sizeof(WordIndex) * order; Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get())); current_ = (uint8_t*)backing_.get(); @@ -153,7 +153,7 @@ class BackoffMessages { private: void FinishedAdding() { Resize(current_ - (uint8_t*)backing_.get()); - // Sort requests in same order as files. + // Sort requests in same order as files. std::sort( util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)), util::SizedIterator(util::SizedProxy(current_, entry_size_)), @@ -220,7 +220,7 @@ class SRISucks { } private: - // This used to be one array. Then I needed to separate it by order for quantization to work. + // This used to be one array. Then I needed to separate it by order for quantization to work. std::vector<float> values_[KENLM_MAX_ORDER - 1]; BackoffMessages messages_[KENLM_MAX_ORDER - 1]; @@ -253,7 +253,7 @@ class FindBlanks { ++counts_.back(); } - // Unigrams wrote one past. + // Unigrams wrote one past. void Cleanup() { --counts_[0]; } @@ -270,15 +270,15 @@ class FindBlanks { SRISucks &sri_; }; -// Phase to actually write n-grams to the trie. +// Phase to actually write n-grams to the trie. template <class Quant, class Bhiksha> class WriteEntries { public: - WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : + WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), quant_(quant), unigrams_(unigrams), middle_(middle), - longest_(longest), + longest_(longest), bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)), order_(order), sri_(sri) {} @@ -328,7 +328,7 @@ struct Gram { const WordIndex *begin, *end; - // For queue, this is the direction we want. + // For queue, this is the direction we want. bool operator<(const Gram &other) const { return std::lexicographical_compare(other.begin, other.end, begin, end); } @@ -353,7 +353,7 @@ template <class Doing> class BlankManager { been_length_ = length; return; } - // There are blanks to insert starting with order blank. + // There are blanks to insert starting with order blank. unsigned char blank = cur - to + 1; UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); const float *lower_basis; @@ -363,7 +363,7 @@ template <class Doing> class BlankManager { assert(*lower_basis != kBadProb); doing_.MiddleBlank(blank, to, based_on, *lower_basis); *pre = *cur; - // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. + // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. basis_[blank - 1] = kBadProb; } *pre = *cur; @@ -377,7 +377,7 @@ template <class Doing> class BlankManager { unsigned char been_length_; float basis_[KENLM_MAX_ORDER]; - + Doing &doing_; }; @@ -451,7 +451,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re } void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { - // Fill unigram probabilities. + // Fill unigram probabilities. try { rewind(file); for (WordIndex i = 0; i < unigram_count; ++i) { @@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve util::scoped_memory unigrams; MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder); fixed_counts = finder.Counts(); } unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); @@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve inputs[i-2].Rewind(); } if (Quant::kTrain) { - util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); + util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), + config.ProgressMessages(), "Quantizing"); for (unsigned char i = 2; i < counts.size(); ++i) { TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } @@ -519,13 +520,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); } - // Fill entries except unigram probabilities. + // Fill entries except unigram probabilities. { WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); } - // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. + // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. for (unsigned char order = 2; order <= counts.size(); ++order) { const RecordReader &context = contexts[order - 2]; if (context) { @@ -541,13 +542,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve } /* Set ending offsets so the last entry will be sized properly */ - // Last entry for unigrams was already set. + // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { i->FinishedLoading((i+1)->InsertIndex(), config); } (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); - } + } } template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { @@ -595,7 +596,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ } else { temporary_prefix = file; } - // At least 1MB sorting memory. + // At least 1MB sorting memory. SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab); BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 11c27518..fd7f96dc 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { } *end_ = hashed; if (enumerate_) { - strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); + void *copied = string_backing_.Allocate(str.size()); + memcpy(copied, str.data(), str.size()); + strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size()); } ++end_; // This is 1 + the offset where it was inserted to make room for unk. @@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { if (enumerate_) { if (!strings_to_enumerate_.empty()) { - util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); util::JointSort(begin_, end_, values); } for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { @@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { enumerate_->Add(i + 1, strings_to_enumerate_[i]); } strings_to_enumerate_.clear(); + string_backing_.FreeAll(); } else { util::JointSort(begin_, end_, reorder_vocab + 1); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index de54eb06..3902f117 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,6 +4,7 @@ #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" +#include "util/pool.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" #include "util/string_piece.hh" @@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary { EnumerateVocab *enumerate_; // Actual strings. Used only when loading from ARPA and enumerate_ != NULL - std::vector<std::string> strings_to_enumerate_; + util::Pool string_backing_; + + std::vector<StringPiece> strings_to_enumerate_; }; #pragma pack(push) |