diff options
Diffstat (limited to 'klm/lm')
-rw-r--r-- | klm/lm/binary_format.cc | 4 | ||||
-rw-r--r-- | klm/lm/binary_format.hh | 2 | ||||
-rw-r--r-- | klm/lm/model_test.cc | 8 | ||||
-rw-r--r-- | klm/lm/search_trie.cc | 123 | ||||
-rw-r--r-- | klm/lm/trie.cc | 10 |
5 files changed, 147 insertions, 0 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 27cada13..eac8aa85 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -182,6 +182,10 @@ void SeekPastHeader(int fd, const Parameters ¶ms) { SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); } +void SeekPastHeader(int fd, const Parameters ¶ms) { + SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); +} + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { const off_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index e9df0892..a83f6b89 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -76,6 +76,8 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet void SeekPastHeader(int fd, const Parameters ¶ms); +void SeekPastHeader(int fd, const Parameters ¶ms); + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); void ComplainAboutARPA(const Config &config, ModelType model_type); diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 2654071f..3585d34b 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -264,6 +264,14 @@ template <class M> void NoUnkCheck(const M &model) { BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); } +template <class M> void NoUnkCheck(const M &model) { + WordIndex unk_index = 0; + State state; + + FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state); + BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); +} + template <class M> void Everything(const M &m) { Starters(m); Continuation(m); diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 5d8c70db..1bcfe27d 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -234,8 +234,19 @@ class FindBlanks { return unigrams_[index].prob; } +<<<<<<< HEAD +// Phase to count n-grams, including blanks inserted because they were pruned but have extensions +class JustCount { + public: + template <class Middle, class Longest> JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) + : counts_(counts), longest_counts_(counts + order - 1) {} + + void Unigrams(WordIndex begin, WordIndex end) { + counts_[0] += end - begin; +======= void Unigram(WordIndex /*index*/) { ++counts_[0]; +>>>>>>> upstream/master } void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) { @@ -267,7 +278,11 @@ class FindBlanks { // Phase to actually write n-grams to the trie. template <class Quant, class Bhiksha> class WriteEntries { public: +<<<<<<< HEAD + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) : +======= WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order, SRISucks &sri) : +>>>>>>> upstream/master contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -315,8 +330,16 @@ template <class Quant, class Bhiksha> class WriteEntries { SRISucks &sri_; }; +<<<<<<< HEAD +template <class Doing> class RecursiveInsert { + public: + template <class MiddleT, class LongestT> RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : + doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { + } +======= struct Gram { Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {} +>>>>>>> upstream/master const WordIndex *begin, *end; @@ -417,6 +440,29 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u } } +<<<<<<< HEAD +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} + +template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + ProbBackoff weights; + std::vector<float> probs, backoffs; + probs.reserve(count); + backoffs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; + } +======= template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const std::vector<float> &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { std::vector<float> probs(additional), backoffs; probs.reserve(count + additional.size()); @@ -426,10 +472,26 @@ template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const probs.push_back(weights.prob); if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); ++progress; +>>>>>>> upstream/master } quant.Train(order, probs, backoffs); } +<<<<<<< HEAD +template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + Prob weights; + std::vector<float> probs, backoffs; + probs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + ++progress; + } +======= template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { std::vector<float> probs, backoffs; probs.reserve(count); @@ -437,10 +499,18 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re const Prob &weights = *reinterpret_cast<const Prob*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order); probs.push_back(weights.prob); ++progress; +>>>>>>> upstream/master } quant.TrainProb(order, probs); } +<<<<<<< HEAD +} // namespace + +template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { + std::vector<SortedFileReader> inputs(counts.size() - 1); + std::vector<ContextReader> contexts(counts.size() - 1); +======= void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { // Fill unigram probabilities. try { @@ -463,6 +533,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { RecordReader inputs[kMaxOrder - 1]; RecordReader contexts[kMaxOrder - 1]; +>>>>>>> upstream/master for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; @@ -477,12 +548,17 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre SRISucks sri; std::vector<uint64_t> fixed_counts(counts.size()); { +<<<<<<< HEAD + RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); + counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); +======= std::string temp(file_prefix); temp += "unigrams"; util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str())); util::scoped_memory unigrams; MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); +>>>>>>> upstream/master } for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) { if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); @@ -490,6 +566,18 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; +<<<<<<< HEAD + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); + + if (Quant::kTrain) { + util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); + for (unsigned char i = 2; i < counts.size(); ++i) { + TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + } + TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); + quant.FinishedLoading(config); + } +======= util::scoped_FILE unigram_file; { std::string name(file_prefix + "unigrams"); @@ -499,6 +587,7 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); +>>>>>>> upstream/master for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -521,8 +610,30 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre } // Fill entries except unigram probabilities. { +<<<<<<< HEAD + RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); + inserter.Apply(config.messages, "Building trie", fixed_counts[0]); + } + + // Fill unigram probabilities. + try { + std::string name(file_prefix + "unigrams"); + util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); + for (WordIndex i = 0; i < counts[0]; ++i) { + ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); + if (contexts[0] && **contexts[0] == i) { + SetExtension(unigrams[i].weights.backoff); + ++contexts[0]; + } + } + RemoveOrThrow(name.c_str()); + } catch (util::Exception &e) { + e << " while re-reading unigram probabilities"; + throw; +======= WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); +>>>>>>> upstream/master } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -576,6 +687,17 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); +<<<<<<< HEAD +} + +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() { + unigram.LoadedBinary(); + for (Middle *i = middle_begin_; i != middle_end_; ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); +} +======= } template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() { @@ -593,6 +715,7 @@ bool IsDirectory(const char *path) { return S_ISDIR(info.st_mode); } } // namespace +>>>>>>> upstream/master template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 20075bb8..a1136b6f 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -91,6 +91,15 @@ template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } +<<<<<<< HEAD + uint64_t index = at_pointer; + at_pointer *= total_bits_; + at_pointer += word_bits_; + quant_.Read(base_, at_pointer, prob, backoff); + at_pointer += quant_.TotalBits(); + + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); +======= pointer = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; @@ -99,6 +108,7 @@ template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find at_pointer += quant_.TotalBits(); bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range); +>>>>>>> upstream/master return true; } |