summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--klm/lm/search_trie.cc20
1 files changed, 15 insertions, 5 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 63631223..b830dfc3 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -535,13 +535,16 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
}
}
-void ARPAToSortedFiles(const Config &config, util::FilePiece &f, const std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
{
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
- util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff));
+ // In case <unk> appears.
+ size_t extra_count = counts[0] + 1;
+ util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff));
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
CheckSpecials(config, vocab);
+ if (!vocab.SawUnk()) ++counts[0];
}
// Only use as much buffer as we need.
@@ -572,7 +575,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W
return true;
}
-// Counting phrase
+// Phase to count n-grams, including blanks inserted because they were pruned but have extensions
class JustCount {
public:
JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
@@ -603,6 +606,7 @@ class JustCount {
uint64_t *const counts_, *const longest_counts_;
};
+// Phase to actually write n-grams to the trie.
class WriteEntries {
public:
WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
@@ -764,7 +768,7 @@ template <class Doing> class RecursiveInsert {
void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]);
- if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant");
+ if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant but it changed from " << initial.back() << " to " << fixed.back());
for (unsigned char i = 0; i < initial.size(); ++i) {
if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen");
}
@@ -789,6 +793,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
}
+ for (SortedFileReader *i = inputs; i < inputs + counts.size() - 1; ++i) {
+ if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
+ }
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
@@ -805,7 +812,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
}
// Fill unigram probabilities.
- {
+ try {
std::string name(file_prefix + "unigrams");
util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));
for (WordIndex i = 0; i < counts[0]; ++i) {
@@ -816,6 +823,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
}
}
RemoveOrThrow(name.c_str());
+ } catch (util::Exception &e) {
+ e << " while re-reading unigram probabilities";
+ throw;
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.