summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--klm/lm/search_trie.cc49
1 files changed, 25 insertions, 24 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 832cc9f7..1b0d9b26 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -55,7 +55,7 @@ struct ProbPointer {
uint64_t index;
};
-// Array of n-grams and float indices.
+// Array of n-grams and float indices.
class BackoffMessages {
public:
void Init(std::size_t entry_size) {
@@ -89,7 +89,7 @@ class BackoffMessages {
if (!HasExtension(weights.backoff)) {
weights.backoff = kExtensionBackoff;
UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
- WriteOrThrow(unigrams, &weights, sizeof(weights));
+ util::WriteOrThrow(unigrams, &weights, sizeof(weights));
}
const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
base[write_to.array][write_to.index] += weights.backoff;
@@ -100,7 +100,7 @@ class BackoffMessages {
void Apply(float *const *const base, RecordReader &reader) {
FinishedAdding();
if (current_ == allocated_) return;
- // We'll also use the same buffer to record messages to blanks that they extend.
+ // We'll also use the same buffer to record messages to blanks that they extend.
WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);
const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);
for (reader.Rewind(); reader && (current_ != allocated_); ) {
@@ -109,7 +109,7 @@ class BackoffMessages {
++reader;
break;
case 1:
- // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
+ // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;
current_ += entry_size_;
break;
@@ -126,7 +126,7 @@ class BackoffMessages {
break;
}
}
- // Now this is a list of blanks that extend right.
+ // Now this is a list of blanks that extend right.
entry_size_ = sizeof(WordIndex) * order;
Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));
current_ = (uint8_t*)backing_.get();
@@ -153,7 +153,7 @@ class BackoffMessages {
private:
void FinishedAdding() {
Resize(current_ - (uint8_t*)backing_.get());
- // Sort requests in same order as files.
+ // Sort requests in same order as files.
std::sort(
util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)),
util::SizedIterator(util::SizedProxy(current_, entry_size_)),
@@ -220,7 +220,7 @@ class SRISucks {
}
private:
- // This used to be one array. Then I needed to separate it by order for quantization to work.
+ // This used to be one array. Then I needed to separate it by order for quantization to work.
std::vector<float> values_[KENLM_MAX_ORDER - 1];
BackoffMessages messages_[KENLM_MAX_ORDER - 1];
@@ -253,7 +253,7 @@ class FindBlanks {
++counts_.back();
}
- // Unigrams wrote one past.
+ // Unigrams wrote one past.
void Cleanup() {
--counts_[0];
}
@@ -270,15 +270,15 @@ class FindBlanks {
SRISucks &sri_;
};
-// Phase to actually write n-grams to the trie.
+// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
+ WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
contexts_(contexts),
quant_(quant),
unigrams_(unigrams),
middle_(middle),
- longest_(longest),
+ longest_(longest),
bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),
order_(order),
sri_(sri) {}
@@ -328,7 +328,7 @@ struct Gram {
const WordIndex *begin, *end;
- // For queue, this is the direction we want.
+ // For queue, this is the direction we want.
bool operator<(const Gram &other) const {
return std::lexicographical_compare(other.begin, other.end, begin, end);
}
@@ -353,7 +353,7 @@ template <class Doing> class BlankManager {
been_length_ = length;
return;
}
- // There are blanks to insert starting with order blank.
+ // There are blanks to insert starting with order blank.
unsigned char blank = cur - to + 1;
UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");
const float *lower_basis;
@@ -363,7 +363,7 @@ template <class Doing> class BlankManager {
assert(*lower_basis != kBadProb);
doing_.MiddleBlank(blank, to, based_on, *lower_basis);
*pre = *cur;
- // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
+ // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
basis_[blank - 1] = kBadProb;
}
*pre = *cur;
@@ -377,7 +377,7 @@ template <class Doing> class BlankManager {
unsigned char been_length_;
float basis_[KENLM_MAX_ORDER];
-
+
Doing &doing_;
};
@@ -451,7 +451,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re
}
void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
- // Fill unigram probabilities.
+ // Fill unigram probabilities.
try {
rewind(file);
for (WordIndex i = 0; i < unigram_count; ++i) {
@@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
util::scoped_memory unigrams;
MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
- RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);
fixed_counts = finder.Counts();
}
unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
@@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
inputs[i-2].Rewind();
}
if (Quant::kTrain) {
- util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing");
+ util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0),
+ config.ProgressMessages(), "Quantizing");
for (unsigned char i = 2; i < counts.size(); ++i) {
TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
}
@@ -519,13 +520,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
}
- // Fill entries except unigram probabilities.
+ // Fill entries except unigram probabilities.
{
WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
- RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
}
- // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
+ // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
for (unsigned char order = 2; order <= counts.size(); ++order) {
const RecordReader &context = contexts[order - 2];
if (context) {
@@ -541,13 +542,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
}
/* Set ending offsets so the last entry will be sized properly */
- // Last entry for unigrams was already set.
+ // Last entry for unigrams was already set.
if (out.middle_begin_ != out.middle_end_) {
for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
i->FinishedLoading((i+1)->InsertIndex(), config);
}
(out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config);
- }
+ }
}
template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
@@ -595,7 +596,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ
} else {
temporary_prefix = file;
}
- // At least 1MB sorting memory.
+ // At least 1MB sorting memory.
SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);
BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);