diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-11-10 02:02:04 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-11-10 02:02:04 +0000 |
commit | 15b03336564d5e57e50693f19dd81b45076af5d4 (patch) | |
tree | c2072893a43f4c75f0ad5ebe3080bfa901faf18f /klm/lm/search_trie.cc | |
parent | 1336aecfe930546f8836ffe65dd5ff78434084eb (diff) |
new version of klm
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@706 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 402 |
1 files changed, 402 insertions, 0 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc new file mode 100644 index 00000000..182e27f5 --- /dev/null +++ b/klm/lm/search_trie.cc @@ -0,0 +1,402 @@ +#include "lm/search_trie.hh" + +#include "lm/lm_exception.hh" +#include "lm/read_arpa.hh" +#include "lm/trie.hh" +#include "lm/vocab.hh" +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/ersatz_progress.hh" +#include "util/file_piece.hh" +#include "util/scoped.hh" + +#include <algorithm> +#include <cstring> +#include <cstdio> +#include <deque> +#include <iostream> +#include <limits> +//#include <parallel/algorithm> +#include <vector> + +#include <sys/mman.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <fcntl.h> +#include <stdlib.h> + +namespace lm { +namespace ngram { +namespace trie { +namespace { + +template <unsigned char Order> class FullEntry { + public: + typedef ProbBackoff Weights; + static const unsigned char kOrder = Order; + + // reverse order + WordIndex words[Order]; + Weights weights; + + bool operator<(const FullEntry<Order> &other) const { + for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { + if (*i < *j) return true; + if (*i > *j) return false; + } + return false; + } +}; + +template <unsigned char Order> class ProbEntry { + public: + typedef Prob Weights; + static const unsigned char kOrder = Order; + + // reverse order + WordIndex words[Order]; + Weights weights; + + bool operator<(const ProbEntry<Order> &other) const { + for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { + if (*i < *j) return true; + if (*i > *j) return false; + } + return false; + } +}; + +void WriteOrThrow(FILE *to, const void *data, size_t size) { + if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +} + +void ReadOrThrow(FILE *from, void *data, size_t size) { + if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size); +} + +void CopyOrThrow(FILE *from, FILE *to, size_t size) { + const size_t kBufSize = 512; + char buf[kBufSize]; + for (size_t i = 0; i < size; i += kBufSize) { + std::size_t amount = std::min(size - i, kBufSize); + ReadOrThrow(from, buf, amount); + WriteOrThrow(to, buf, amount); + } +} + +template <class Entry> std::string DiskFlush(const Entry *begin, const Entry *end, const std::string &file_prefix, std::size_t batch) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << '_' << batch; + std::string ret(assembled.str()); + util::scoped_FILE out(fopen(ret.c_str(), "w")); + if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); + for (const Entry *group_begin = begin; group_begin != end;) { + const Entry *group_end = group_begin; + for (++group_end; (group_end != end) && !memcmp(group_begin->words, group_end->words, sizeof(WordIndex) * (Entry::kOrder - 1)); ++group_end) {} + WriteOrThrow(out.get(), group_begin->words, sizeof(WordIndex) * (Entry::kOrder - 1)); + WordIndex group_size = group_end - group_begin; + WriteOrThrow(out.get(), &group_size, sizeof(group_size)); + for (const Entry *i = group_begin; i != group_end; ++i) { + WriteOrThrow(out.get(), &i->words[Entry::kOrder - 1], sizeof(WordIndex)); + WriteOrThrow(out.get(), &i->weights, sizeof(typename Entry::Weights)); + } + group_begin = group_end; + } + return ret; +} + +class SortedFileReader { + public: + SortedFileReader() {} + + void Init(const std::string &name, unsigned char order) { + file_.reset(fopen(name.c_str(), "r")); + if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read"); + header_.resize(order - 1); + NextHeader(); + } + + // Preceding words. + const WordIndex *Header() const { + return &*header_.begin(); + } + const std::vector<WordIndex> &HeaderVector() const { return header_;} + + std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); } + + void NextHeader() { + if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get()) && !Ended()) { + UTIL_THROW(util::ErrnoException, "Short read of counts"); + } + } + + void ReadCount(WordIndex &to) { + ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); + } + + void ReadWord(WordIndex &to) { + ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); + } + + template <class Weights> void ReadWeights(Weights &to) { + ReadOrThrow(file_.get(), &to, sizeof(Weights)); + } + + bool Ended() { + return feof(file_.get()); + } + + FILE *File() { return file_.get(); } + + private: + util::scoped_FILE file_; + + std::vector<WordIndex> header_; +}; + +void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) { + WriteOrThrow(to, from.Header(), from.HeaderBytes()); + WordIndex count; + from.ReadCount(count); + WriteOrThrow(to, &count, sizeof(WordIndex)); + + CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); +} + +void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) { + SortedFileReader first, second; + first.Init(first_name, order); + second.Init(second_name, order); + util::scoped_FILE out_file(fopen(out, "w")); + if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write"); + while (!first.Ended() && !second.Ended()) { + if (first.HeaderVector() < second.HeaderVector()) { + CopyFullRecord(first, out_file.get(), weights_size); + first.NextHeader(); + continue; + } + if (first.HeaderVector() > second.HeaderVector()) { + CopyFullRecord(second, out_file.get(), weights_size); + second.NextHeader(); + continue; + } + // Merge at the entry level. + WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes()); + WordIndex first_count, second_count; + first.ReadCount(first_count); second.ReadCount(second_count); + WordIndex total_count = first_count + second_count; + WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex)); + + WordIndex first_word, second_word; + first.ReadWord(first_word); second.ReadWord(second_word); + WordIndex first_index = 0, second_index = 0; + while (true) { + if (first_word < second_word) { + WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); + CopyOrThrow(first.File(), out_file.get(), weights_size); + if (++first_index == first_count) break; + first.ReadWord(first_word); + } else { + WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); + CopyOrThrow(second.File(), out_file.get(), weights_size); + if (++second_index == second_count) break; + second.ReadWord(second_word); + } + } + if (first_index == first_count) { + WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); + CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); + } else { + WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); + CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); + } + first.NextHeader(); + second.NextHeader(); + } + + for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) { + CopyFullRecord(remaining, out_file.get(), weights_size); + } +} + +template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix) { + ConvertToSorted<FullEntry<Entry::kOrder - 1> >(f, vocab, counts, mem, file_prefix); + + ReadNGramHeader(f, Entry::kOrder); + const size_t count = counts[Entry::kOrder - 1]; + const size_t batch_size = std::min(count, mem.size() / sizeof(Entry)); + Entry *const begin = reinterpret_cast<Entry*>(mem.get()); + std::deque<std::string> files; + for (std::size_t batch = 0, done = 0; done < count; ++batch) { + Entry *out = begin; + Entry *out_end = out + std::min(count - done, batch_size); + for (; out != out_end; ++out) { + ReadNGram(f, Entry::kOrder, vocab, out->words, out->weights); + } + //__gnu_parallel::sort(begin, out_end); + std::sort(begin, out_end); + + files.push_back(DiskFlush(begin, out_end, file_prefix, batch)); + done += out_end - begin; + } + + // All individual files created. Merge them. + + std::size_t merge_count = 0; + while (files.size() > 1) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merge_" << (merge_count++); + files.push_back(assembled.str()); + MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), sizeof(typename Entry::Weights), Entry::kOrder); + if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); + files.pop_front(); + if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); + files.pop_front(); + } + if (!files.empty()) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merged"; + std::string merged_name(assembled.str()); + if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); + } +} + +template <> void ConvertToSorted<FullEntry<1> >(util::FilePiece &/*f*/, const SortedVocabulary &/*vocab*/, const std::vector<uint64_t> &/*counts*/, util::scoped_memory &/*mem*/, const std::string &/*file_prefix*/) {} + +void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + { + std::string unigram_name = file_prefix + "unigrams"; + util::scoped_fd unigram_file; + util::scoped_mmap unigram_mmap; + unigram_mmap.reset(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); + Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); + } + + util::scoped_memory mem; + mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED); + if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + ConvertToSorted<ProbEntry<5> >(f, vocab, counts, mem, file_prefix); + ReadEnd(f); +} + +struct RecursiveInsertParams { + WordIndex *words; + SortedFileReader *files; + unsigned char max_order; + // This is an array of size order - 2. + BitPackedMiddle *middle; + // This has exactly one entry. + BitPackedLongest *longest; +}; + +uint64_t RecursiveInsert(RecursiveInsertParams ¶ms, unsigned char order) { + SortedFileReader &file = params.files[order - 2]; + const uint64_t ret = (order == params.max_order) ? params.longest->InsertIndex() : params.middle[order - 2].InsertIndex(); + if (std::memcmp(params.words, file.Header(), sizeof(WordIndex) * (order - 1))) + return ret; + WordIndex count; + file.ReadCount(count); + WordIndex key; + if (order == params.max_order) { + Prob value; + for (WordIndex i = 0; i < count; ++i) { + file.ReadWord(key); + file.ReadWeights(value); + params.longest->Insert(key, value.prob); + } + file.NextHeader(); + return ret; + } + ProbBackoff value; + for (WordIndex i = 0; i < count; ++i) { + file.ReadWord(params.words[order - 1]); + file.ReadWeights(value); + params.middle[order - 2].Insert( + params.words[order - 1], + value.prob, + value.backoff, + RecursiveInsert(params, order + 1)); + } + file.NextHeader(); + return ret; +} + +void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, std::ostream *messages, TrieSearch &out) { + UnigramValue *unigrams = out.unigram.Raw(); + // Load unigrams. Leave the next pointers uninitialized. + { + std::string name(file_prefix + "unigrams"); + util::scoped_FILE file(fopen(name.c_str(), "r")); + if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed"); + for (WordIndex i = 0; i < counts[0]; ++i) { + ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); + } + unlink(name.c_str()); + } + + // inputs[0] is bigrams. + SortedFileReader inputs[counts.size() - 1]; + for (unsigned char i = 2; i <= counts.size(); ++i) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(i) << "_merged"; + inputs[i-2].Init(assembled.str(), i); + unlink(assembled.str().c_str()); + } + + // words[0] is unigrams. + WordIndex words[counts.size()]; + RecursiveInsertParams params; + params.words = words; + params.files = inputs; + params.max_order = static_cast<unsigned char>(counts.size()); + params.middle = &*out.middle.begin(); + params.longest = &out.longest; + { + util::ErsatzProgress progress(messages, "Building trie", counts[0]); + for (words[0] = 0; words[0] < counts[0]; ++words[0], ++progress) { + unigrams[words[0]].next = RecursiveInsert(params, 2); + } + } + + /* Set ending offsets so the last entry will be sized properly */ + if (!out.middle.empty()) { + unigrams[counts[0]].next = out.middle.front().InsertIndex(); + for (size_t i = 0; i < out.middle.size() - 1; ++i) { + out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); + } + out.middle.back().FinishedLoading(out.longest.InsertIndex()); + } else { + unigrams[counts[0]].next = out.longest.InsertIndex(); + } +} + +} // namespace + +void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab) { + std::string temporary_directory; + if (config.temporary_directory_prefix) { + temporary_directory = config.temporary_directory_prefix; + } else if (config.write_mmap) { + temporary_directory = config.write_mmap; + } else { + temporary_directory = file; + } + // Null on end is kludge to ensure null termination. + temporary_directory += "-tmp-XXXXXX\0"; + if (!mkdtemp(&temporary_directory[0])) { + UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str()); + } + // Chop off null kludge. + temporary_directory.resize(strlen(temporary_directory.c_str())); + // Add directory delimiter. Assumes a real operating system. + temporary_directory += '/'; + ARPAToSortedFiles(f, counts, config.building_memory, temporary_directory.c_str(), vocab); + BuildTrie(temporary_directory.c_str(), counts, config.messages, *this); + if (rmdir(temporary_directory.c_str())) { + std::cerr << "Failed to delete " << temporary_directory << std::endl; + } +} + +} // namespace trie +} // namespace ngram +} // namespace lm |