summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/binary_format.cc21
-rw-r--r--klm/lm/config.cc1
-rw-r--r--klm/lm/config.hh59
-rw-r--r--klm/lm/left.hh66
-rw-r--r--klm/lm/max_order.hh5
-rw-r--r--klm/lm/model.cc33
-rw-r--r--klm/lm/search_hashed.cc8
-rw-r--r--klm/lm/search_hashed.hh2
-rw-r--r--klm/lm/search_trie.cc47
-rw-r--r--klm/lm/vocab.cc7
-rw-r--r--klm/lm/vocab.hh5
11 files changed, 134 insertions, 120 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index efa67056..39c4a9b6 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -16,11 +16,11 @@ namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
-// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
+// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 5;
-// Old binary files built on 32-bit machines have this header.
+// Old binary files built on 32-bit machines have this header.
// TODO: eliminate with next binary release.
struct OldSanity {
char magic[sizeof(kMagicBytes)];
@@ -39,7 +39,7 @@ struct OldSanity {
};
-// Test values aligned to 8 bytes.
+// Test values aligned to 8 bytes.
struct Sanity {
char magic[ALIGN8(sizeof(kMagicBytes))];
float zero_f, one_f, minus_half_f;
@@ -101,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
- // Grow the file to accomodate the search, using zeros.
+ // Grow the file to accomodate the search, using zeros.
try {
util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
} catch (util::ErrnoException &e) {
@@ -114,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
return reinterpret_cast<uint8_t*>(backing.search.get());
}
// mmap it now.
- // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
+ // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
std::size_t page_size = util::SizePage();
std::size_t alignment_cruft = adjusted_vocab % page_size;
backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
@@ -122,7 +122,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
} else {
util::MapAnonymous(memory_size, backing.search);
return reinterpret_cast<uint8_t*>(backing.search.get());
- }
+ }
}
void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
@@ -140,7 +140,7 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_
util::FSyncOrThrow(backing.file.get());
break;
}
- // header and vocab share the same mmap. The header is written here because we know the counts.
+ // header and vocab share the same mmap. The header is written here because we know the counts.
Parameters params = Parameters();
params.counts = counts;
params.fixed.order = counts.size();
@@ -160,7 +160,7 @@ namespace detail {
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
- // Try reading the header.
+ // Try reading the header.
util::scoped_memory memory;
try {
util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
@@ -214,7 +214,7 @@ void SeekPastHeader(int fd, const Parameters &params) {
uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
const uint64_t file_size = util::SizeFile(backing.file.get());
- // The header is smaller than a page, so we have to map the whole header as well.
+ // The header is smaller than a page, so we have to map the whole header as well.
std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
@@ -233,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) {
if (config.write_mmap || !config.messages) return;
if (config.arpa_complain == Config::ALL) {
*config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
- } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) {
+ } else if (config.arpa_complain == Config::EXPENSIVE &&
+ (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
*config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
}
}
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index f9d988ca..9520c41c 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -6,6 +6,7 @@ namespace lm {
namespace ngram {
Config::Config() :
+ show_progress(true),
messages(&std::cerr),
enumerate_vocab(NULL),
unknown_missing(COMPLAIN),
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index 739cee9c..0de7b7c6 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -11,46 +11,52 @@
/* Configuration for ngram model. Separate header to reduce pollution. */
namespace lm {
-
+
class EnumerateVocab;
namespace ngram {
struct Config {
- // EFFECTIVE FOR BOTH ARPA AND BINARY READS
+ // EFFECTIVE FOR BOTH ARPA AND BINARY READS
+
+ // (default true) print progress bar to messages
+ bool show_progress;
// Where to log messages including the progress bar. Set to NULL for
// silence.
std::ostream *messages;
+ std::ostream *ProgressMessages() const {
+ return show_progress ? messages : 0;
+ }
+
// This will be called with every string in the vocabulary. See
// enumerate_vocab.hh for more detail. Config does not take ownership; you
- // are still responsible for deleting it (or stack allocating).
+ // are still responsible for deleting it (or stack allocating).
EnumerateVocab *enumerate_vocab;
-
// ONLY EFFECTIVE WHEN READING ARPA
- // What to do when <unk> isn't in the provided model.
+ // What to do when <unk> isn't in the provided model.
WarningAction unknown_missing;
- // What to do when <s> or </s> is missing from the model.
- // If THROW_UP, the exception will be of type util::SpecialWordMissingException.
+ // What to do when <s> or </s> is missing from the model.
+ // If THROW_UP, the exception will be of type util::SpecialWordMissingException.
WarningAction sentence_marker_missing;
// What to do with a positive log probability. For COMPLAIN and SILENT, map
- // to 0.
+ // to 0.
WarningAction positive_log_probability;
- // The probability to substitute for <unk> if it's missing from the model.
+ // The probability to substitute for <unk> if it's missing from the model.
// No effect if the model has <unk> or unknown_missing == THROW_UP.
float unknown_missing_logprob;
// Size multiplier for probing hash table. Must be > 1. Space is linear in
// this. Time is probing_multiplier / (probing_multiplier - 1). No effect
- // for sorted variant.
+ // for sorted variant.
// If you find yourself setting this to a low number, consider using the
- // TrieModel which has lower memory consumption.
+ // TrieModel which has lower memory consumption.
float probing_multiplier;
// Amount of memory to use for building. The actual memory usage will be
@@ -58,10 +64,10 @@ struct Config {
// models.
std::size_t building_memory;
- // Template for temporary directory appropriate for passing to mkdtemp.
+ // Template for temporary directory appropriate for passing to mkdtemp.
// The characters XXXXXX are appended before passing to mkdtemp. Only
// applies to trie. If NULL, defaults to write_mmap. If that's NULL,
- // defaults to input file name.
+ // defaults to input file name.
const char *temporary_directory_prefix;
// Level of complaining to do when loading from ARPA instead of binary format.
@@ -69,49 +75,46 @@ struct Config {
ARPALoadComplain arpa_complain;
// While loading an ARPA file, also write out this binary format file. Set
- // to NULL to disable.
+ // to NULL to disable.
const char *write_mmap;
enum WriteMethod {
- WRITE_MMAP, // Map the file directly.
- WRITE_AFTER // Write after we're done.
+ WRITE_MMAP, // Map the file directly.
+ WRITE_AFTER // Write after we're done.
};
WriteMethod write_method;
- // Include the vocab in the binary file? Only effective if write_mmap != NULL.
+ // Include the vocab in the binary file? Only effective if write_mmap != NULL.
bool include_vocab;
- // Left rest options. Only used when the model includes rest costs.
+ // Left rest options. Only used when the model includes rest costs.
enum RestFunction {
REST_MAX, // Maximum of any score to the left
- REST_LOWER, // Use lower-order files given below.
+ REST_LOWER, // Use lower-order files given below.
};
RestFunction rest_function;
- // Only used for REST_LOWER.
+ // Only used for REST_LOWER.
std::vector<std::string> rest_lower_files;
-
// Quantization options. Only effective for QuantTrieModel. One value is
// reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
- // to quantize (and one of the remaining backoffs will be 0).
+ // to quantize (and one of the remaining backoffs will be 0).
uint8_t prob_bits, backoff_bits;
// Bhiksha compression (simple form). Only works with trie.
uint8_t pointer_bhiksha_bits;
-
-
+
// ONLY EFFECTIVE WHEN READING BINARY
-
+
// How to get the giant array into memory: lazy mmap, populate, read etc.
- // See util/mmap.hh for details of MapMethod.
+ // See util/mmap.hh for details of MapMethod.
util::LoadMethod load_method;
-
- // Set defaults.
+ // Set defaults.
Config();
};
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index 8c27232e..85c1ea37 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -51,36 +51,36 @@ namespace ngram {
template <class M> class RuleScore {
public:
- explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) {
+ explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {
out.left.length = 0;
out.right.length = 0;
}
void BeginSentence() {
- out_.right = model_.BeginSentenceState();
- // out_.left is empty.
+ out_->right = model_.BeginSentenceState();
+ // out_->left is empty.
left_done_ = true;
}
void Terminal(WordIndex word) {
- State copy(out_.right);
- FullScoreReturn ret(model_.FullScore(copy, word, out_.right));
+ State copy(out_->right);
+ FullScoreReturn ret(model_.FullScore(copy, word, out_->right));
if (left_done_) { prob_ += ret.prob; return; }
if (ret.independent_left) {
prob_ += ret.prob;
left_done_ = true;
return;
}
- out_.left.pointers[out_.left.length++] = ret.extend_left;
+ out_->left.pointers[out_->left.length++] = ret.extend_left;
prob_ += ret.rest;
- if (out_.right.length != copy.length + 1)
+ if (out_->right.length != copy.length + 1)
left_done_ = true;
}
// Faster version of NonTerminal for the case where the rule begins with a non-terminal.
void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
prob_ = prob;
- out_ = in;
+ *out_ = in;
left_done_ = in.left.full;
}
@@ -89,23 +89,23 @@ template <class M> class RuleScore {
if (!in.left.length) {
if (in.left.full) {
- for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i;
+ for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;
left_done_ = true;
- out_.right = in.right;
+ out_->right = in.right;
}
return;
}
- if (!out_.right.length) {
- out_.right = in.right;
+ if (!out_->right.length) {
+ out_->right = in.right;
if (left_done_) {
prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
return;
}
- if (out_.left.length) {
+ if (out_->left.length) {
left_done_ = true;
} else {
- out_.left = in.left;
+ out_->left = in.left;
left_done_ = in.left.full;
}
return;
@@ -113,10 +113,10 @@ template <class M> class RuleScore {
float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
float *back = backoffs, *back2 = backoffs2;
- unsigned char next_use = out_.right.length;
+ unsigned char next_use = out_->right.length;
// First word
- if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return;
+ if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;
// Words after the first, so extending a bigram to begin with
for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
@@ -127,54 +127,58 @@ template <class M> class RuleScore {
if (in.left.full) {
for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
left_done_ = true;
- out_.right = in.right;
+ out_->right = in.right;
return;
}
// Right state was minimized, so it's already independent of the new words to the left.
if (in.right.length < in.left.length) {
- out_.right = in.right;
+ out_->right = in.right;
return;
}
// Shift exisiting words down.
- for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) {
+ for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {
*(i + in.right.length) = *i;
}
// Add words from in.right.
- std::copy(in.right.words, in.right.words + in.right.length, out_.right.words);
+ std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);
// Assemble backoff composed on the existing state's backoff followed by the new state's backoff.
- std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff);
- std::copy(back, back + next_use, out_.right.backoff + in.right.length);
- out_.right.length = in.right.length + next_use;
+ std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff);
+ std::copy(back, back + next_use, out_->right.backoff + in.right.length);
+ out_->right.length = in.right.length + next_use;
}
float Finish() {
// A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
- out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1);
+ out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);
return prob_;
}
void Reset() {
prob_ = 0.0;
left_done_ = false;
- out_.left.length = 0;
- out_.right.length = 0;
+ out_->left.length = 0;
+ out_->right.length = 0;
+ }
+ void Reset(ChartState &replacement) {
+ out_ = &replacement;
+ Reset();
}
private:
bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
ProcessRet(model_.ExtendLeft(
- out_.right.words, out_.right.words + next_use, // Words to extend into
+ out_->right.words, out_->right.words + next_use, // Words to extend into
back_in, // Backoffs to use
in.left.pointers[extend_length - 1], extend_length, // Words to be extended
back_out, // Backoffs for the next score
next_use)); // Length of n-gram to use in next scoring.
- if (next_use != out_.right.length) {
+ if (next_use != out_->right.length) {
left_done_ = true;
if (!next_use) {
// Early exit.
- out_.right = in.right;
+ out_->right = in.right;
prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
return true;
}
@@ -193,13 +197,13 @@ template <class M> class RuleScore {
left_done_ = true;
return;
}
- out_.left.pointers[out_.left.length++] = ret.extend_left;
+ out_->left.pointers[out_->left.length++] = ret.extend_left;
prob_ += ret.rest;
}
const M &model_;
- ChartState &out_;
+ ChartState *out_;
bool left_done_;
diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh
index 989f8324..3eb97ccd 100644
--- a/klm/lm/max_order.hh
+++ b/klm/lm/max_order.hh
@@ -4,9 +4,6 @@
* (kMaxOrder - 1) * sizeof(float) bytes instead of
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
*/
-#ifndef KENLM_MAX_ORDER
-#define KENLM_MAX_ORDER 6
-#endif
#ifndef KENLM_ORDER_MESSAGE
-#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh."
+#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
#endif
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 2fd20481..a40fd2fb 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -37,7 +37,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
LoadLM(file, config, *this);
- // g++ prints warnings unless these are fully initialized.
+ // g++ prints warnings unless these are fully initialized.
State begin_sentence = State();
begin_sentence.length = 1;
begin_sentence.words[0] = vocab_.BeginSentence();
@@ -69,8 +69,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
- // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
- util::FilePiece f(backing_.file.release(), file, config.messages);
+ // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
+ util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
try {
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
@@ -80,14 +80,14 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
- // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
+ // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
if (config.write_mmap) {
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size());
+ wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
@@ -95,7 +95,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (!vocab_.SawUnk()) {
assert(config.unknown_missing != THROW_UP);
- // Default probabilities for unknown.
+ // Default probabilities for unknown.
search_.UnknownUnigram().backoff = 0.0;
search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
@@ -147,7 +147,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
- // Generate a state from context.
+ // Generate a state from context.
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
if (context_rend == context_rbegin) {
out_state.length = 0;
@@ -191,7 +191,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
ret.rest = ptr.Rest();
ret.prob = ptr.Prob();
ret.extend_left = extend_pointer;
- // If this function is called, then it does depend on left words.
+ // If this function is called, then it does depend on left words.
ret.independent_left = false;
}
float subtract_me = ret.rest;
@@ -199,7 +199,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
next_use = extend_length;
ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);
next_use -= extend_length;
- // Charge backoffs.
+ // Charge backoffs.
for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
ret.prob -= subtract_me;
ret.rest -= subtract_me;
@@ -209,7 +209,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied
// (hence the -1). out_state.length could be zero so I avoided using
-// std::copy.
+// std::copy.
void CopyRemainingHistory(const WordIndex *from, State &out_state) {
WordIndex *out = out_state.words + 1;
const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1;
@@ -217,18 +217,19 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) {
}
} // namespace
-/* Ugly optimized function. Produce a score excluding backoff.
- * The search goes in increasing order of ngram length.
+/* Ugly optimized function. Produce a score excluding backoff.
+ * The search goes in increasing order of ngram length.
* Context goes backward, so context_begin is the word immediately preceeding
- * new_word.
+ * new_word.
*/
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
const WordIndex *const context_rbegin,
const WordIndex *const context_rend,
const WordIndex new_word,
State &out_state) const {
+ assert(new_word < vocab_.Bound());
FullScoreReturn ret;
- // ret.ngram_length contains the last known non-blank ngram length.
+ // ret.ngram_length contains the last known non-blank ngram length.
ret.ngram_length = 1;
typename Search::Node node;
@@ -237,9 +238,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
ret.prob = uni.Prob();
ret.rest = uni.Rest();
- // This is the length of the context that should be used for continuation to the right.
+ // This is the length of the context that should be used for continuation to the right.
out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
- // We'll write the word anyway since it will probably be used and does no harm being there.
+ // We'll write the word anyway since it will probably be used and does no harm being there.
out_state.words[0] = new_word;
if (context_rbegin == context_rend) return ret;
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index a1623834..2d6f15b2 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -231,7 +231,7 @@ template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char *
template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
NoRestBuild build;
- ApplyBuild(f, counts, config, vocab, warn, build);
+ ApplyBuild(f, counts, vocab, warn, build);
}
template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
@@ -239,19 +239,19 @@ template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, cons
case Config::REST_MAX:
{
MaxRestBuild build;
- ApplyBuild(f, counts, config, vocab, warn, build);
+ ApplyBuild(f, counts, vocab, warn, build);
}
break;
case Config::REST_LOWER:
{
LowerRestBuild<ProbingModel> build(config, counts.size(), vocab);
- ApplyBuild(f, counts, config, vocab, warn, build);
+ ApplyBuild(f, counts, vocab, warn, build);
}
break;
}
}
-template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {
+template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {
for (WordIndex i = 0; i < counts[0]; ++i) {
build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]);
}
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index a52f107b..00595796 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -147,7 +147,7 @@ template <class Value> class HashedSearch {
// Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
- template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
+ template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
class Unigram {
public:
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index debcfd07..1b0d9b26 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -55,7 +55,7 @@ struct ProbPointer {
uint64_t index;
};
-// Array of n-grams and float indices.
+// Array of n-grams and float indices.
class BackoffMessages {
public:
void Init(std::size_t entry_size) {
@@ -100,7 +100,7 @@ class BackoffMessages {
void Apply(float *const *const base, RecordReader &reader) {
FinishedAdding();
if (current_ == allocated_) return;
- // We'll also use the same buffer to record messages to blanks that they extend.
+ // We'll also use the same buffer to record messages to blanks that they extend.
WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);
const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);
for (reader.Rewind(); reader && (current_ != allocated_); ) {
@@ -109,7 +109,7 @@ class BackoffMessages {
++reader;
break;
case 1:
- // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
+ // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;
current_ += entry_size_;
break;
@@ -126,7 +126,7 @@ class BackoffMessages {
break;
}
}
- // Now this is a list of blanks that extend right.
+ // Now this is a list of blanks that extend right.
entry_size_ = sizeof(WordIndex) * order;
Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));
current_ = (uint8_t*)backing_.get();
@@ -153,7 +153,7 @@ class BackoffMessages {
private:
void FinishedAdding() {
Resize(current_ - (uint8_t*)backing_.get());
- // Sort requests in same order as files.
+ // Sort requests in same order as files.
std::sort(
util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)),
util::SizedIterator(util::SizedProxy(current_, entry_size_)),
@@ -220,7 +220,7 @@ class SRISucks {
}
private:
- // This used to be one array. Then I needed to separate it by order for quantization to work.
+ // This used to be one array. Then I needed to separate it by order for quantization to work.
std::vector<float> values_[KENLM_MAX_ORDER - 1];
BackoffMessages messages_[KENLM_MAX_ORDER - 1];
@@ -253,7 +253,7 @@ class FindBlanks {
++counts_.back();
}
- // Unigrams wrote one past.
+ // Unigrams wrote one past.
void Cleanup() {
--counts_[0];
}
@@ -270,15 +270,15 @@ class FindBlanks {
SRISucks &sri_;
};
-// Phase to actually write n-grams to the trie.
+// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
+ WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
contexts_(contexts),
quant_(quant),
unigrams_(unigrams),
middle_(middle),
- longest_(longest),
+ longest_(longest),
bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),
order_(order),
sri_(sri) {}
@@ -328,7 +328,7 @@ struct Gram {
const WordIndex *begin, *end;
- // For queue, this is the direction we want.
+ // For queue, this is the direction we want.
bool operator<(const Gram &other) const {
return std::lexicographical_compare(other.begin, other.end, begin, end);
}
@@ -353,7 +353,7 @@ template <class Doing> class BlankManager {
been_length_ = length;
return;
}
- // There are blanks to insert starting with order blank.
+ // There are blanks to insert starting with order blank.
unsigned char blank = cur - to + 1;
UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");
const float *lower_basis;
@@ -363,7 +363,7 @@ template <class Doing> class BlankManager {
assert(*lower_basis != kBadProb);
doing_.MiddleBlank(blank, to, based_on, *lower_basis);
*pre = *cur;
- // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
+ // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
basis_[blank - 1] = kBadProb;
}
*pre = *cur;
@@ -377,7 +377,7 @@ template <class Doing> class BlankManager {
unsigned char been_length_;
float basis_[KENLM_MAX_ORDER];
-
+
Doing &doing_;
};
@@ -451,7 +451,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re
}
void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
- // Fill unigram probabilities.
+ // Fill unigram probabilities.
try {
rewind(file);
for (WordIndex i = 0; i < unigram_count; ++i) {
@@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
util::scoped_memory unigrams;
MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
- RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);
fixed_counts = finder.Counts();
}
unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
@@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
inputs[i-2].Rewind();
}
if (Quant::kTrain) {
- util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing");
+ util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0),
+ config.ProgressMessages(), "Quantizing");
for (unsigned char i = 2; i < counts.size(); ++i) {
TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
}
@@ -519,13 +520,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
}
- // Fill entries except unigram probabilities.
+ // Fill entries except unigram probabilities.
{
WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
- RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
}
- // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
+ // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
for (unsigned char order = 2; order <= counts.size(); ++order) {
const RecordReader &context = contexts[order - 2];
if (context) {
@@ -541,13 +542,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
}
/* Set ending offsets so the last entry will be sized properly */
- // Last entry for unigrams was already set.
+ // Last entry for unigrams was already set.
if (out.middle_begin_ != out.middle_end_) {
for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
i->FinishedLoading((i+1)->InsertIndex(), config);
}
(out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config);
- }
+ }
}
template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
@@ -595,7 +596,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ
} else {
temporary_prefix = file;
}
- // At least 1MB sorting memory.
+ // At least 1MB sorting memory.
SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);
BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 11c27518..fd7f96dc 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
}
*end_ = hashed;
if (enumerate_) {
- strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size());
+ void *copied = string_backing_.Allocate(str.size());
+ memcpy(copied, str.data(), str.size());
+ strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size());
}
++end_;
// This is 1 + the offset where it was inserted to make room for unk.
@@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
if (enumerate_) {
if (!strings_to_enumerate_.empty()) {
- util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
util::JointSort(begin_, end_, values);
}
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
@@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
enumerate_->Add(i + 1, strings_to_enumerate_[i]);
}
strings_to_enumerate_.clear();
+ string_backing_.FreeAll();
} else {
util::JointSort(begin_, end_, reorder_vocab + 1);
}
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index de54eb06..3902f117 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -4,6 +4,7 @@
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
+#include "util/pool.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
@@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary {
EnumerateVocab *enumerate_;
// Actual strings. Used only when loading from ARPA and enumerate_ != NULL
- std::vector<std::string> strings_to_enumerate_;
+ util::Pool string_backing_;
+
+ std::vector<StringPiece> strings_to_enumerate_;
};
#pragma pack(push)