Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--  klm/lm/search_trie.cc  123
1 file changed, 123 insertions, 0 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 5d8c70db..1bcfe27d 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -234,8 +234,19 @@ class FindBlanks {
       return unigrams_[index].prob;
     }
 
+<<<<<<< HEAD
+// Phase to count n-grams, including blanks inserted because they were pruned but have extensions
+class JustCount {
+  public:
+    template <class Middle, class Longest> JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order)
+      : counts_(counts), longest_counts_(counts + order - 1) {}
+
+    void Unigrams(WordIndex begin, WordIndex end) {
+      counts_[0] += end - begin;
+=======
     void Unigram(WordIndex /*index*/) {
       ++counts_[0];
+>>>>>>> upstream/master
     }
 
     void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) {
@@ -267,7 +278,11 @@ class FindBlanks {
 // Phase to actually write n-grams to the trie.
 template <class Quant, class Bhiksha> class WriteEntries {
   public:
+<<<<<<< HEAD
+    WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
+=======
     WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order, SRISucks &sri) :
+>>>>>>> upstream/master
       contexts_(contexts),
       unigrams_(unigrams),
       middle_(middle),
@@ -315,8 +330,16 @@ template <class Quant, class Bhiksha> class WriteEntries {
     SRISucks &sri_;
 };
 
+<<<<<<< HEAD
+template <class Doing> class RecursiveInsert {
+  public:
+    template <class MiddleT, class LongestT> RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) :
+      doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) {
+    }
+=======
 struct Gram {
   Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {}
+>>>>>>> upstream/master
 
   const WordIndex *begin, *end;
 
@@ -417,6 +440,29 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
   }
 }
 
+<<<<<<< HEAD
+bool IsDirectory(const char *path) {
+  struct stat info;
+  if (0 != stat(path, &info)) return false;
+  return S_ISDIR(info.st_mode);
+}
+
+template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) {
+  ProbBackoff weights;
+  std::vector<float> probs, backoffs;
+  probs.reserve(count);
+  backoffs.reserve(count);
+  for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) {
+    uint64_t entries = reader.ReadCount();
+    for (uint64_t c = 0; c < entries; ++c) {
+      reader.ReadWord();
+      reader.ReadWeights(weights);
+      // kBlankProb isn't added yet.
+      probs.push_back(weights.prob);
+      if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
+      ++progress;
+    }
+=======
 template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const std::vector<float> &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
   std::vector<float> probs(additional), backoffs;
   probs.reserve(count + additional.size());
@@ -426,10 +472,26 @@ template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const
     probs.push_back(weights.prob);
     if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
     ++progress;
+>>>>>>> upstream/master
   }
   quant.Train(order, probs, backoffs);
 }
 
+<<<<<<< HEAD
+template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) {
+  Prob weights;
+  std::vector<float> probs, backoffs;
+  probs.reserve(count);
+  for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) {
+    uint64_t entries = reader.ReadCount();
+    for (uint64_t c = 0; c < entries; ++c) {
+      reader.ReadWord();
+      reader.ReadWeights(weights);
+      // kBlankProb isn't added yet.
+      probs.push_back(weights.prob);
+      ++progress;
+    }
+=======
 template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
   std::vector<float> probs, backoffs;
   probs.reserve(count);
@@ -437,10 +499,18 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re
     const Prob &weights = *reinterpret_cast<const Prob*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
     probs.push_back(weights.prob);
     ++progress;
+>>>>>>> upstream/master
   }
   quant.TrainProb(order, probs);
 }
 
+<<<<<<< HEAD
+} // namespace
+
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
+  std::vector<SortedFileReader> inputs(counts.size() - 1);
+  std::vector<ContextReader> contexts(counts.size() - 1);
+=======
 void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
   // Fill unigram probabilities.
   try {
@@ -463,6 +533,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c
 template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
   RecordReader inputs[kMaxOrder - 1];
   RecordReader contexts[kMaxOrder - 1];
+>>>>>>> upstream/master
 
   for (unsigned char i = 2; i <= counts.size(); ++i) {
     std::stringstream assembled;
@@ -477,12 +548,17 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
   SRISucks sri;
   std::vector<uint64_t> fixed_counts(counts.size());
   {
+<<<<<<< HEAD
+    RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
+    counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
+=======
     std::string temp(file_prefix); temp += "unigrams";
     util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str()));
     util::scoped_memory unigrams;
     MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
     FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
     RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
+>>>>>>> upstream/master
   }
   for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) {
     if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
@@ -490,6 +566,18 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
   SanityCheckCounts(counts, fixed_counts);
   counts = fixed_counts;
 
+<<<<<<< HEAD
+  out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+
+  if (Quant::kTrain) {
+    util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
+    for (unsigned char i = 2; i < counts.size(); ++i) {
+      TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant);
+    }
+    TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant);
+    quant.FinishedLoading(config);
+  }
+=======
   util::scoped_FILE unigram_file;
   {
     std::string name(file_prefix + "unigrams");
@@ -499,6 +587,7 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
   sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
 
   out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+>>>>>>> upstream/master
 
   for (unsigned char i = 2; i <= counts.size(); ++i) {
     inputs[i-2].Rewind();
@@ -521,8 +610,30 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
   }
   // Fill entries except unigram probabilities.
   {
+<<<<<<< HEAD
+    RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
+    inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
+  }
+
+  // Fill unigram probabilities.
+  try {
+    std::string name(file_prefix + "unigrams");
+    util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));
+    for (WordIndex i = 0; i < counts[0]; ++i) {
+      ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
+      if (contexts[0] && **contexts[0] == i) {
+        SetExtension(unigrams[i].weights.backoff);
+        ++contexts[0];
+      }
+    }
+    RemoveOrThrow(name.c_str());
+  } catch (util::Exception &e) {
+    e << " while re-reading unigram probabilities";
+    throw;
+=======
     WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri);
     RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
+>>>>>>> upstream/master
   }
 
   // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
@@ -576,6 +687,17 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
   }
   longest.Init(start, quant_.Long(counts.size()), counts[0]);
   return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
+<<<<<<< HEAD
+}
+
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
+  unigram.LoadedBinary();
+  for (Middle *i = middle_begin_; i != middle_end_; ++i) {
+    i->LoadedBinary();
+  }
+  longest.LoadedBinary();
+}
+=======
 }
 
 template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
@@ -593,6 +715,7 @@ bool IsDirectory(const char *path) {
   return S_ISDIR(info.st_mode);
 }
 } // namespace
+>>>>>>> upstream/master
 
 template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
   std::string temporary_directory;