summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/binary_format.cc4
-rw-r--r--klm/lm/binary_format.hh2
-rw-r--r--klm/lm/model_test.cc8
-rw-r--r--klm/lm/search_trie.cc123
-rw-r--r--klm/lm/trie.cc10
5 files changed, 147 insertions, 0 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 27cada13..eac8aa85 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -182,6 +182,10 @@ void SeekPastHeader(int fd, const Parameters &params) {
SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
}
+void SeekPastHeader(int fd, const Parameters &params) {
+ SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
+}
+
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
const off_t file_size = util::SizeFile(backing.file.get());
// The header is smaller than a page, so we have to map the whole header as well.
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index e9df0892..a83f6b89 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -76,6 +76,8 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
void SeekPastHeader(int fd, const Parameters &params);
+void SeekPastHeader(int fd, const Parameters &params);
+
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing);
void ComplainAboutARPA(const Config &config, ModelType model_type);
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 2654071f..3585d34b 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -264,6 +264,14 @@ template <class M> void NoUnkCheck(const M &model) {
BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001);
}
+template <class M> void NoUnkCheck(const M &model) {
+ WordIndex unk_index = 0;
+ State state;
+
+ FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state);
+ BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001);
+}
+
template <class M> void Everything(const M &m) {
Starters(m);
Continuation(m);
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 5d8c70db..1bcfe27d 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -234,8 +234,19 @@ class FindBlanks {
return unigrams_[index].prob;
}
+<<<<<<< HEAD
+// Phase to count n-grams, including blanks inserted because they were pruned but have extensions
+class JustCount {
+ public:
+ template <class Middle, class Longest> JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order)
+ : counts_(counts), longest_counts_(counts + order - 1) {}
+
+ void Unigrams(WordIndex begin, WordIndex end) {
+ counts_[0] += end - begin;
+=======
void Unigram(WordIndex /*index*/) {
++counts_[0];
+>>>>>>> upstream/master
}
void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) {
@@ -267,7 +278,11 @@ class FindBlanks {
// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
+<<<<<<< HEAD
+ WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
+=======
WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order, SRISucks &sri) :
+>>>>>>> upstream/master
contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
@@ -315,8 +330,16 @@ template <class Quant, class Bhiksha> class WriteEntries {
SRISucks &sri_;
};
+<<<<<<< HEAD
+template <class Doing> class RecursiveInsert {
+ public:
+ template <class MiddleT, class LongestT> RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) :
+ doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) {
+ }
+=======
struct Gram {
Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {}
+>>>>>>> upstream/master
const WordIndex *begin, *end;
@@ -417,6 +440,29 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
}
}
+<<<<<<< HEAD
+bool IsDirectory(const char *path) {
+ struct stat info;
+ if (0 != stat(path, &info)) return false;
+ return S_ISDIR(info.st_mode);
+}
+
+template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) {
+ ProbBackoff weights;
+ std::vector<float> probs, backoffs;
+ probs.reserve(count);
+ backoffs.reserve(count);
+ for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) {
+ uint64_t entries = reader.ReadCount();
+ for (uint64_t c = 0; c < entries; ++c) {
+ reader.ReadWord();
+ reader.ReadWeights(weights);
+ // kBlankProb isn't added yet.
+ probs.push_back(weights.prob);
+ if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
+ ++progress;
+ }
+=======
template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const std::vector<float> &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
std::vector<float> probs(additional), backoffs;
probs.reserve(count + additional.size());
@@ -426,10 +472,26 @@ template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const
probs.push_back(weights.prob);
if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
++progress;
+>>>>>>> upstream/master
}
quant.Train(order, probs, backoffs);
}
+<<<<<<< HEAD
+template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) {
+ Prob weights;
+ std::vector<float> probs, backoffs;
+ probs.reserve(count);
+ for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) {
+ uint64_t entries = reader.ReadCount();
+ for (uint64_t c = 0; c < entries; ++c) {
+ reader.ReadWord();
+ reader.ReadWeights(weights);
+ // kBlankProb isn't added yet.
+ probs.push_back(weights.prob);
+ ++progress;
+ }
+=======
template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
std::vector<float> probs, backoffs;
probs.reserve(count);
@@ -437,10 +499,18 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re
const Prob &weights = *reinterpret_cast<const Prob*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
probs.push_back(weights.prob);
++progress;
+>>>>>>> upstream/master
}
quant.TrainProb(order, probs);
}
+<<<<<<< HEAD
+} // namespace
+
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
+ std::vector<SortedFileReader> inputs(counts.size() - 1);
+ std::vector<ContextReader> contexts(counts.size() - 1);
+=======
void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
// Fill unigram probabilities.
try {
@@ -463,6 +533,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c
template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
RecordReader inputs[kMaxOrder - 1];
RecordReader contexts[kMaxOrder - 1];
+>>>>>>> upstream/master
for (unsigned char i = 2; i <= counts.size(); ++i) {
std::stringstream assembled;
@@ -477,12 +548,17 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
SRISucks sri;
std::vector<uint64_t> fixed_counts(counts.size());
{
+<<<<<<< HEAD
+ RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
+ counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
+=======
std::string temp(file_prefix); temp += "unigrams";
util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str()));
util::scoped_memory unigrams;
MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
+>>>>>>> upstream/master
}
for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) {
if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
@@ -490,6 +566,18 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
+<<<<<<< HEAD
+ out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+
+ if (Quant::kTrain) {
+ util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
+ for (unsigned char i = 2; i < counts.size(); ++i) {
+ TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant);
+ }
+ TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant);
+ quant.FinishedLoading(config);
+ }
+=======
util::scoped_FILE unigram_file;
{
std::string name(file_prefix + "unigrams");
@@ -499,6 +587,7 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+>>>>>>> upstream/master
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
@@ -521,8 +610,30 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
}
// Fill entries except unigram probabilities.
{
+<<<<<<< HEAD
+ RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
+ inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
+ }
+
+ // Fill unigram probabilities.
+ try {
+ std::string name(file_prefix + "unigrams");
+ util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));
+ for (WordIndex i = 0; i < counts[0]; ++i) {
+ ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
+ if (contexts[0] && **contexts[0] == i) {
+ SetExtension(unigrams[i].weights.backoff);
+ ++contexts[0];
+ }
+ }
+ RemoveOrThrow(name.c_str());
+ } catch (util::Exception &e) {
+ e << " while re-reading unigram probabilities";
+ throw;
+=======
WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
+>>>>>>> upstream/master
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
@@ -576,6 +687,17 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
}
longest.Init(start, quant_.Long(counts.size()), counts[0]);
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
+<<<<<<< HEAD
+}
+
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
+ unigram.LoadedBinary();
+ for (Middle *i = middle_begin_; i != middle_end_; ++i) {
+ i->LoadedBinary();
+ }
+ longest.LoadedBinary();
+}
+=======
}
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
@@ -593,6 +715,7 @@ bool IsDirectory(const char *path) {
return S_ISDIR(info.st_mode);
}
} // namespace
+>>>>>>> upstream/master
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 20075bb8..a1136b6f 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -91,6 +91,15 @@ template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
return false;
}
+<<<<<<< HEAD
+ uint64_t index = at_pointer;
+ at_pointer *= total_bits_;
+ at_pointer += word_bits_;
+ quant_.Read(base_, at_pointer, prob, backoff);
+ at_pointer += quant_.TotalBits();
+
+ bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
+=======
pointer = at_pointer;
at_pointer *= total_bits_;
at_pointer += word_bits_;
@@ -99,6 +108,7 @@ template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find
at_pointer += quant_.TotalBits();
bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range);
+>>>>>>> upstream/master
return true;
}