diff options
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 63631223..b830dfc3 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -535,13 +535,16 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st } } -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, const std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { { std::string unigram_name = file_prefix + "unigrams"; util::scoped_fd unigram_file; - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); + // In case <unk> appears. + size_t extra_count = counts[0] + 1; + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); CheckSpecials(config, vocab); + if (!vocab.SawUnk()) ++counts[0]; } // Only use as much buffer as we need. @@ -572,7 +575,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W return true; } -// Counting phrase +// Phase to count n-grams, including blanks inserted because they were pruned but have extensions class JustCount { public: JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) @@ -603,6 +606,7 @@ class JustCount { uint64_t *const counts_, *const longest_counts_; }; +// Phase to actually write n-grams to the trie. class WriteEntries { public: WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : @@ -764,7 +768,7 @@ template <class Doing> class RecursiveInsert { void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) { if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]); - if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant"); + if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant but it changed from " << initial.back() << " to " << fixed.back()); for (unsigned char i = 0; i < initial.size(); ++i) { if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen"); } @@ -789,6 +793,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } + for (SortedFileReader *i = inputs; i < inputs + counts.size() - 1; ++i) { + if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); + } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; @@ -805,7 +812,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co } // Fill unigram probabilities. - { + try { std::string name(file_prefix + "unigrams"); util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); for (WordIndex i = 0; i < counts[0]; ++i) { @@ -816,6 +823,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co } } RemoveOrThrow(name.c_str()); + } catch (util::Exception &e) { + e << " while re-reading unigram probabilities"; + throw; } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. |