summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--klm/lm/search_trie.cc50
1 files changed, 27 insertions, 23 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 1060ddef..63631223 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -11,6 +11,7 @@
#include "lm/word_index.hh"
#include "util/ersatz_progress.hh"
#include "util/file_piece.hh"
+#include "util/have.hh"
#include "util/proxy_iterator.hh"
#include "util/scoped.hh"
@@ -20,7 +21,6 @@
#include <cstdio>
#include <deque>
#include <limits>
-//#include <parallel/algorithm>
#include <vector>
#include <sys/mman.h>
@@ -170,7 +170,7 @@ template <class Proxy> class CompareRecords : public std::binary_function<const
return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices());
}
bool operator()(const std::string &first, const std::string &second) const {
- return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(first.data()));
+ return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data()));
}
private:
@@ -384,7 +384,6 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
- // TODO: __gnu_parallel::sort here.
std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1));
std::string name(ngram_file_name + kContextSuffix);
@@ -406,16 +405,16 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil
class ContextReader {
public:
- ContextReader() : length_(0) {}
+ ContextReader() : valid_(false) {}
- ContextReader(const char *name, size_t length) : file_(OpenOrThrow(name, "r")), length_(length), words_(length), valid_(true) {
- ++*this;
+ ContextReader(const char *name, unsigned char order) {
+ Reset(name, order);
}
- void Reset(const char *name, size_t length) {
+ void Reset(const char *name, unsigned char order) {
file_.reset(OpenOrThrow(name, "r"));
- length_ = length;
- words_.resize(length);
+ length_ = sizeof(WordIndex) * static_cast<size_t>(order);
+ words_.resize(order);
valid_ = true;
++*this;
}
@@ -449,14 +448,14 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_
const size_t context_size = sizeof(WordIndex) * (order - 1);
std::string first_name(first_base + kContextSuffix);
std::string second_name(second_base + kContextSuffix);
- ContextReader first(first_name.c_str(), context_size), second(second_name.c_str(), context_size);
+ ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1);
RemoveOrThrow(first_name.c_str());
RemoveOrThrow(second_name.c_str());
std::string out_name(out_base + kContextSuffix);
util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w"));
while (first && second) {
for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) {
- if (f == *first + order) {
+ if (f == *first + order - 1) {
// Equal.
WriteOrThrow(out.get(), *first, context_size);
++first;
@@ -475,7 +474,10 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_
}
}
}
- CopyRestOrThrow((first ? first : second).GetFile(), out.get());
+ ContextReader &remaining = first ? first : second;
+ if (!remaining) return;
+ WriteOrThrow(out.get(), *remaining, context_size);
+ CopyRestOrThrow(remaining.GetFile(), out.get());
}
void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
@@ -502,7 +504,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
}
// Sort full records by full n-gram.
EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
- // TODO: __gnu_parallel::sort here.
+ // parallel_sort uses too much RAM
std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order));
files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size));
WriteContextFile(begin, out_end, files.back(), entry_size, order);
@@ -533,21 +535,22 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
}
}
-void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+void ARPAToSortedFiles(const Config &config, util::FilePiece &f, const std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
{
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff));
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
+ CheckSpecials(config, vocab);
}
// Only use as much buffer as we need.
size_t buffer_use = 0;
for (unsigned int order = 2; order < counts.size(); ++order) {
- buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
+ buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
}
- buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
- buffer = std::min(buffer, buffer_use);
+ buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
+ buffer = std::min<size_t>(buffer, buffer_use);
util::scoped_memory mem;
mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
@@ -767,7 +770,7 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
}
}
-void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
+void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
SortedFileReader inputs[counts.size() - 1];
ContextReader contexts[counts.size() - 1];
@@ -777,7 +780,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
inputs[i-2].Init(assembled.str(), i);
RemoveOrThrow(assembled.str().c_str());
assembled << kContextSuffix;
- contexts[i-2].Reset(assembled.str().c_str(), (i-1) * sizeof(WordIndex));
+ contexts[i-2].Reset(assembled.str().c_str(), i-1);
RemoveOrThrow(assembled.str().c_str());
}
@@ -787,8 +790,9 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
}
SanityCheckCounts(counts, fixed_counts);
+ counts = fixed_counts;
- out.SetupMemory(GrowForSearch(config, TrieSearch::kModelType, fixed_counts, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config);
+ out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config);
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
@@ -811,7 +815,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
++contexts[0];
}
}
- unlink(name.c_str());
+ RemoveOrThrow(name.c_str());
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
@@ -823,7 +827,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
for (const WordIndex *i = *context; i != *context + order - 1; ++i) {
e << ' ' << *i;
}
- e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not.";
+ e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not";
throw e;
}
}
@@ -868,7 +872,7 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v
// Add directory delimiter. Assumes a real operating system.
temporary_directory += '/';
// At least 1MB sorting memory.
- ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
+ ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
BuildTrie(temporary_directory, counts, config, *this, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {