From 0b9031042500d45a098762f0a930bd6a66a58fac Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 18 Jan 2013 17:12:51 +0000 Subject: KenLM dffafbf with lmplz source (but not built) --- klm/lm/Makefile.am | 4 +- klm/lm/build_binary.cc | 77 +++----- klm/lm/builder/README.md | 47 +++++ klm/lm/builder/TODO | 5 + klm/lm/builder/adjust_counts.cc | 216 +++++++++++++++++++++ klm/lm/builder/adjust_counts.hh | 44 +++++ klm/lm/builder/adjust_counts_test.cc | 106 +++++++++++ klm/lm/builder/corpus_count.cc | 223 ++++++++++++++++++++++ klm/lm/builder/corpus_count.hh | 42 +++++ klm/lm/builder/corpus_count_test.cc | 76 ++++++++ klm/lm/builder/discount.hh | 26 +++ klm/lm/builder/header_info.hh | 20 ++ klm/lm/builder/initial_probabilities.cc | 136 ++++++++++++++ klm/lm/builder/initial_probabilities.hh | 34 ++++ klm/lm/builder/interpolate.cc | 65 +++++++ klm/lm/builder/interpolate.hh | 27 +++ klm/lm/builder/joint_order.hh | 43 +++++ klm/lm/builder/main.cc | 94 ++++++++++ klm/lm/builder/multi_stream.hh | 180 ++++++++++++++++++ klm/lm/builder/ngram.hh | 84 +++++++++ klm/lm/builder/ngram_stream.hh | 55 ++++++ klm/lm/builder/pipeline.cc | 320 ++++++++++++++++++++++++++++++++ klm/lm/builder/pipeline.hh | 40 ++++ klm/lm/builder/print.cc | 135 ++++++++++++++ klm/lm/builder/print.hh | 102 ++++++++++ klm/lm/builder/sort.hh | 103 ++++++++++ klm/lm/filter/arpa_io.cc | 122 ++++++++++++ klm/lm/filter/arpa_io.hh | 122 ++++++++++++ klm/lm/filter/count_io.hh | 91 +++++++++ klm/lm/filter/format.hh | 250 +++++++++++++++++++++++++ klm/lm/filter/main.cc | 249 +++++++++++++++++++++++++ klm/lm/filter/phrase.cc | 281 ++++++++++++++++++++++++++++ klm/lm/filter/phrase.hh | 153 +++++++++++++++ klm/lm/filter/thread.hh | 167 +++++++++++++++++ klm/lm/filter/vocab.cc | 54 ++++++ klm/lm/filter/vocab.hh | 132 +++++++++++++ klm/lm/filter/wrapper.hh | 58 ++++++ klm/lm/model_test.cc | 10 +- klm/lm/read_arpa.cc | 11 +- klm/lm/sizes.cc | 63 +++++++ klm/lm/sizes.hh | 17 ++ klm/lm/state.hh | 6 +- klm/lm/trie_sort.cc | 27 ++- klm/lm/trie_sort.hh | 3 +- 44 files changed, 4043 insertions(+), 77 deletions(-) create mode 100644 klm/lm/builder/README.md create mode 100644 klm/lm/builder/TODO create mode 100644 klm/lm/builder/adjust_counts.cc create mode 100644 klm/lm/builder/adjust_counts.hh create mode 100644 klm/lm/builder/adjust_counts_test.cc create mode 100644 klm/lm/builder/corpus_count.cc create mode 100644 klm/lm/builder/corpus_count.hh create mode 100644 klm/lm/builder/corpus_count_test.cc create mode 100644 klm/lm/builder/discount.hh create mode 100644 klm/lm/builder/header_info.hh create mode 100644 klm/lm/builder/initial_probabilities.cc create mode 100644 klm/lm/builder/initial_probabilities.hh create mode 100644 klm/lm/builder/interpolate.cc create mode 100644 klm/lm/builder/interpolate.hh create mode 100644 klm/lm/builder/joint_order.hh create mode 100644 klm/lm/builder/main.cc create mode 100644 klm/lm/builder/multi_stream.hh create mode 100644 klm/lm/builder/ngram.hh create mode 100644 klm/lm/builder/ngram_stream.hh create mode 100644 klm/lm/builder/pipeline.cc create mode 100644 klm/lm/builder/pipeline.hh create mode 100644 klm/lm/builder/print.cc create mode 100644 klm/lm/builder/print.hh create mode 100644 klm/lm/builder/sort.hh create mode 100644 klm/lm/filter/arpa_io.cc create mode 100644 klm/lm/filter/arpa_io.hh create mode 100644 klm/lm/filter/count_io.hh create mode 100644 klm/lm/filter/format.hh create mode 100644 klm/lm/filter/main.cc create mode 100644 klm/lm/filter/phrase.cc create mode 100644 klm/lm/filter/phrase.hh create mode 100644 klm/lm/filter/thread.hh create mode 100644 klm/lm/filter/vocab.cc create mode 100644 klm/lm/filter/vocab.hh create mode 100644 klm/lm/filter/wrapper.hh create mode 100644 klm/lm/sizes.cc create mode 100644 klm/lm/sizes.hh (limited to 'klm/lm') diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 870f7128..f15cbd77 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = build_binary build_binary_SOURCES = build_binary.cc -build_binary_LDADD = libklm.a ../util/libklm_util.a -lz +build_binary_LDADD = libklm.a ../util/libklm_util.a ../util/double-conversion/libklm_util_double.a -lz #noinst_PROGRAMS = \ # ngram_test @@ -30,6 +30,7 @@ libklm_a_SOURCES = \ return.hh \ search_hashed.hh \ search_trie.hh \ + sizes.hh \ state.hh \ trie.hh \ trie_sort.hh \ @@ -49,6 +50,7 @@ libklm_a_SOURCES = \ read_arpa.cc \ search_hashed.cc \ search_trie.cc \ + sizes.cc \ trie.cc \ trie_sort.cc \ value_build.cc \ diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 2b8c9d5b..ab2c0c32 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -1,10 +1,14 @@ #include "lm/model.hh" +#include "lm/sizes.hh" #include "util/file_piece.hh" +#include "util/usage.hh" +#include #include #include #include #include +#include #include #include @@ -19,8 +23,8 @@ namespace lm { namespace ngram { namespace { -void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" +void Usage(const char *name, const char *default_mem) { + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" "-u sets the log10 probability for if the ARPA file does not have one.\n" " Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have and .\n" @@ -38,8 +42,11 @@ void Usage(const char *name) { "trie is a straightforward trie with bit-level packing. It uses the least\n" "memory and is still faster than SRI or IRST. Building the trie format uses an\n" "on-disk sort to save memory.\n" -"-t is the temporary directory prefix. Default is the output file name.\n" -"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" +"-T is the temporary directory prefix. Default is the output file name.\n" +"-S determines memory use for sorting. Default is " << default_mem << ". This is compatible\n" +" with GNU sort. The number is followed by a unit: \% for percent of physical\n" +" memory, b for bytes, K for Kilobytes, M for megabytes, then G,T,P,E,Z,Y. \n" +" Default unit is K for Kilobytes.\n" "-q turns quantization on and sets the number of bits (e.g. -q 8).\n" "-b sets backoff quantization bits. Requires -q and defaults to that value.\n" "-a compresses pointers using an array of offsets. The parameter is the\n" @@ -83,47 +90,6 @@ void ParseFileList(const char *from, std::vector &to) { } } -void ShowSizes(const char *file, const lm::ngram::Config &config) { - std::vector counts; - util::FilePiece f(file); - lm::ReadARPACounts(f, counts); - uint64_t sizes[6]; - sizes[0] = ProbingModel::Size(counts, config); - sizes[1] = RestProbingModel::Size(counts, config); - sizes[2] = TrieModel::Size(counts, config); - sizes[3] = QuantTrieModel::Size(counts, config); - sizes[4] = ArrayTrieModel::Size(counts, config); - sizes[5] = QuantArrayTrieModel::Size(counts, config); - uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); - uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); - uint64_t divide; - char prefix; - if (min_length < (1 << 10) * 10) { - prefix = ' '; - divide = 1; - } else if (min_length < (1 << 20) * 10) { - prefix = 'k'; - divide = 1 << 10; - } else if (min_length < (1ULL << 30) * 10) { - prefix = 'M'; - divide = 1 << 20; - } else { - prefix = 'G'; - divide = 1 << 30; - } - long int length = std::max(2, static_cast(ceil(log10((double) max_length / divide)))); - std::cout << "Memory estimate:\ntype "; - // right align bytes. - for (long int i = 0; i < length - 2; ++i) std::cout << ' '; - std::cout << prefix << "B\n" - "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" - "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n" - "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n" - "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" - "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" - "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; -} - void ProbingQuantizationUnsupported() { std::cerr << "Quantization is only implemented in the trie data structure." << std::endl; exit(1); @@ -136,11 +102,14 @@ void ProbingQuantizationUnsupported() { int main(int argc, char *argv[]) { using namespace lm::ngram; + const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G"; + try { bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false; lm::ngram::Config config; + config.building_memory = util::ParseSize(default_mem); int opt; - while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:sir:")) != -1) { + while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -161,12 +130,16 @@ int main(int argc, char *argv[]) { case 'p': config.probing_multiplier = ParseFloat(optarg); break; - case 't': + case 't': // legacy + case 'T': config.temporary_directory_prefix = optarg; break; - case 'm': + case 'm': // legacy config.building_memory = ParseUInt(optarg) * 1048576; break; + case 'S': + config.building_memory = std::min(static_cast(std::numeric_limits::max()), util::ParseSize(optarg)); + break; case 'w': set_write_method = true; if (!strcmp(optarg, "mmap")) { @@ -174,7 +147,7 @@ int main(int argc, char *argv[]) { } else if (!strcmp(optarg, "after")) { config.write_method = Config::WRITE_AFTER; } else { - Usage(argv[0]); + Usage(argv[0], default_mem); } break; case 's': @@ -189,7 +162,7 @@ int main(int argc, char *argv[]) { config.rest_function = Config::REST_LOWER; break; default: - Usage(argv[0]); + Usage(argv[0], default_mem); } } if (!quantize && set_backoff_bits) { @@ -212,7 +185,7 @@ int main(int argc, char *argv[]) { from_file = argv[optind + 1]; config.write_mmap = argv[optind + 2]; } else { - Usage(argv[0]); + Usage(argv[0], default_mem); } if (!strcmp(model_type, "probing")) { if (!set_write_method) config.write_method = Config::WRITE_AFTER; @@ -242,7 +215,7 @@ int main(int argc, char *argv[]) { } } } else { - Usage(argv[0]); + Usage(argv[0], default_mem); } } catch (const std::exception &e) { diff --git a/klm/lm/builder/README.md b/klm/lm/builder/README.md new file mode 100644 index 00000000..be0d35e2 --- /dev/null +++ b/klm/lm/builder/README.md @@ -0,0 +1,47 @@ +Dependencies +============ + +Boost >= 1.42.0 is required. + +For Ubuntu, +```bash +sudo apt-get install libboost1.48-all-dev +``` + +Alternatively, you can download, compile, and install it yourself: + +```bash +wget http://sourceforge.net/projects/boost/files/boost/1.52.0/boost_1_52_0.tar.gz/download -O boost_1_52_0.tar.gz +tar -xvzf boost_1_52_0.tar.gz +cd boost_1_52_0 +./bootstrap.sh +./b2 +sudo ./b2 install +``` + +Local install options (in a user-space prefix directory) are also possible. See http://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html. + + +Building +======== + +```bash +bjam +``` +Your distribution might package bjam and boost-build separately from Boost. Both are required. + +Usage +===== + +Run +```bash +$ bin/lmplz +``` +to see command line arguments + +Running +======= + +```bash +bin/lmplz -o 5 text.arpa +``` diff --git a/klm/lm/builder/TODO b/klm/lm/builder/TODO new file mode 100644 index 00000000..cb5aef3a --- /dev/null +++ b/klm/lm/builder/TODO @@ -0,0 +1,5 @@ +More tests! +Sharding. +Some way to manage all the crazy config options. +Option to build the binary file directly. +Interpolation of different orders. diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc new file mode 100644 index 00000000..a6f48011 --- /dev/null +++ b/klm/lm/builder/adjust_counts.cc @@ -0,0 +1,216 @@ +#include "lm/builder/adjust_counts.hh" +#include "lm/builder/multi_stream.hh" +#include "util/stream/timer.hh" + +#include + +namespace lm { namespace builder { + +BadDiscountException::BadDiscountException() throw() {} +BadDiscountException::~BadDiscountException() throw() {} + +namespace { +// Return last word in full that is different. +const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) { + const WordIndex *cur_word = full.end() - 1; + const WordIndex *pre_word = lower_last.end() - 1; + // Find last difference. + for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {} + return cur_word; +} + +class StatCollector { + public: + StatCollector(std::size_t order, std::vector &counts, std::vector &discounts) + : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) { + memset(&orders_[0], 0, sizeof(OrderStat) * order); + } + + ~StatCollector() {} + + void CalculateDiscounts() { + counts_.resize(orders_.size()); + discounts_.resize(orders_.size()); + for (std::size_t i = 0; i < orders_.size(); ++i) { + const OrderStat &s = orders_[i]; + counts_[i] = s.count; + + for (unsigned j = 1; j < 4; ++j) { + // TODO: Specialize error message for j == 3, meaning 3+ + UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for " + << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any " + << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?"); + } + + // See equation (26) in Chen and Goodman. + discounts_[i].amount[0] = 0.0; + float y = static_cast(s.n[1]) / static_cast(s.n[1] + 2.0 * s.n[2]); + for (unsigned j = 1; j < 4; ++j) { + discounts_[i].amount[j] = static_cast(j) - static_cast(j + 1) * y * static_cast(s.n[j+1]) / static_cast(s.n[j]); + UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]); + } + } + } + + void Add(std::size_t order_minus_1, uint64_t count) { + OrderStat &stat = orders_[order_minus_1]; + ++stat.count; + if (count < 5) ++stat.n[count]; + } + + void AddFull(uint64_t count) { + ++full_.count; + if (count < 5) ++full_.n[count]; + } + + private: + struct OrderStat { + // n_1 in equation 26 of Chen and Goodman etc + uint64_t n[5]; + uint64_t count; + }; + + std::vector orders_; + OrderStat &full_; + + std::vector &counts_; + std::vector &discounts_; +}; + +// Reads all entries in order like NGramStream does. +// But deletes any entries that have in the 1st (not 0th) position on the +// way out by putting other entries in their place. This disrupts the sort +// order but we don't care because the data is going to be sorted again. +class CollapseStream { + public: + CollapseStream(const util::stream::ChainPosition &position) : + current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + block_(position) { + StartBlock(); + } + + const NGram &operator*() const { return current_; } + const NGram *operator->() const { return ¤t_; } + + operator bool() const { return block_; } + + CollapseStream &operator++() { + assert(block_); + if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) { + memcpy(current_.Base(), copy_from_, current_.TotalSize()); + UpdateCopyFrom(); + } + current_.NextInMemory(); + uint8_t *block_base = static_cast(block_->Get()); + if (current_.Base() == block_base + block_->ValidSize()) { + block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base); + ++block_; + StartBlock(); + } + return *this; + } + + private: + void StartBlock() { + for (; ; ++block_) { + if (!block_) return; + if (block_->ValidSize()) break; + } + current_.ReBase(block_->Get()); + copy_from_ = static_cast(block_->Get()) + block_->ValidSize(); + UpdateCopyFrom(); + } + + // Find last without bos. + void UpdateCopyFrom() { + for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) { + if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break; + } + } + + NGram current_; + + // Goes backwards in the block + uint8_t *copy_from_; + + util::stream::Link block_; +}; + +} // namespace + +void AdjustCounts::Run(const ChainPositions &positions) { + UTIL_TIMER("(%w s) Adjusted counts\n"); + + const std::size_t order = positions.size(); + StatCollector stats(order, counts_, discounts_); + if (order == 1) { + // Only unigrams. Just collect stats. + for (NGramStream full(positions[0]); full; ++full) + stats.AddFull(full->Count()); + stats.CalculateDiscounts(); + return; + } + + NGramStreams streams; + streams.Init(positions, positions.size() - 1); + CollapseStream full(positions[positions.size() - 1]); + + // Initialization: has count 0 and so does . + NGramStream *lower_valid = streams.begin(); + streams[0]->Count() = 0; + *streams[0]->begin() = kUNK; + stats.Add(0, 0); + (++streams[0])->Count() = 0; + *streams[0]->begin() = kBOS; + // not in stats because it will get put in later. + + // iterate over full (the stream of the highest order ngrams) + for (; full; ++full) { + const WordIndex *different = FindDifference(*full, **lower_valid); + std::size_t same = full->end() - 1 - different; + // Increment the adjusted count. + if (same) ++streams[same - 1]->Count(); + + // Output all the valid ones that changed. + for (; lower_valid >= &streams[same]; --lower_valid) { + stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count()); + ++*lower_valid; + } + + // This is here because bos is also const WordIndex *, so copy gets + // consistent argument types. + const WordIndex *full_end = full->end(); + // Initialize and mark as valid up to bos. + const WordIndex *bos; + for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) { + ++lower_valid; + std::copy(bos, full_end, (*lower_valid)->begin()); + (*lower_valid)->Count() = 1; + } + // Now bos indicates where is or is the 0th word of full. + if (bos != full->begin()) { + // There is an beyond the 0th word. + NGramStream &to = *++lower_valid; + std::copy(bos, full_end, to->begin()); + to->Count() = full->Count(); + } else { + stats.AddFull(full->Count()); + } + assert(lower_valid >= &streams[0]); + } + + // Output everything valid. + for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { + stats.Add(s - streams.begin(), (*s)->Count()); + ++*s; + } + // Poison everyone! Except the N-grams which were already poisoned by the input. + for (NGramStream *s = streams.begin(); s != streams.end(); ++s) + s->Poison(); + + stats.CalculateDiscounts(); + + // NOTE: See special early-return case for unigrams near the top of this function +} + +}} // namespaces diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh new file mode 100644 index 00000000..f38ff79d --- /dev/null +++ b/klm/lm/builder/adjust_counts.hh @@ -0,0 +1,44 @@ +#ifndef LM_BUILDER_ADJUST_COUNTS__ +#define LM_BUILDER_ADJUST_COUNTS__ + +#include "lm/builder/discount.hh" +#include "util/exception.hh" + +#include + +#include + +namespace lm { +namespace builder { + +class ChainPositions; + +class BadDiscountException : public util::Exception { + public: + BadDiscountException() throw(); + ~BadDiscountException() throw(); +}; + +/* Compute adjusted counts. + * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. + * Output: [1,N]-grams with adjusted counts. + * [1,N)-grams are in suffix order + * N-grams are in undefined order (they're going to be sorted anyway). + */ +class AdjustCounts { + public: + AdjustCounts(std::vector &counts, std::vector &discounts) + : counts_(counts), discounts_(discounts) {} + + void Run(const ChainPositions &positions); + + private: + std::vector &counts_; + std::vector &discounts_; +}; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_ADJUST_COUNTS__ + diff --git a/klm/lm/builder/adjust_counts_test.cc b/klm/lm/builder/adjust_counts_test.cc new file mode 100644 index 00000000..68b5f33e --- /dev/null +++ b/klm/lm/builder/adjust_counts_test.cc @@ -0,0 +1,106 @@ +#include "lm/builder/adjust_counts.hh" + +#include "lm/builder/multi_stream.hh" +#include "util/scoped.hh" + +#include +#define BOOST_TEST_MODULE AdjustCounts +#include + +namespace lm { namespace builder { namespace { + +class KeepCopy { + public: + KeepCopy() : size_(0) {} + + void Run(const util::stream::ChainPosition &position) { + for (util::stream::Link link(position); link; ++link) { + mem_.call_realloc(size_ + link->ValidSize()); + memcpy(static_cast(mem_.get()) + size_, link->Get(), link->ValidSize()); + size_ += link->ValidSize(); + } + } + + uint8_t *Get() { return static_cast(mem_.get()); } + std::size_t Size() const { return size_; } + + private: + util::scoped_malloc mem_; + std::size_t size_; +}; + +struct Gram4 { + WordIndex ids[4]; + uint64_t count; +}; + +class WriteInput { + public: + void Run(const util::stream::ChainPosition &position) { + NGramStream input(position); + Gram4 grams[] = { + {{0,0,0,0},10}, + {{0,0,3,0},3}, + // bos + {{1,1,1,2},5}, + {{0,0,3,2},5}, + }; + for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) { + memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4); + input->Count() = grams[i].count; + } + input.Poison(); + } +}; + +BOOST_AUTO_TEST_CASE(Simple) { + KeepCopy outputs[4]; + std::vector counts; + std::vector discount; + { + util::stream::ChainConfig config; + config.total_memory = 100; + config.block_count = 1; + Chains chains(4); + for (unsigned i = 0; i < 4; ++i) { + config.entry_size = NGram::TotalSize(i + 1); + chains.push_back(config); + } + + chains[3] >> WriteInput(); + ChainPositions for_adjust(chains); + for (unsigned i = 0; i < 4; ++i) { + chains[i] >> boost::ref(outputs[i]); + } + chains >> util::stream::kRecycle; + BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException); + } + BOOST_REQUIRE_EQUAL(4UL, counts.size()); + BOOST_CHECK_EQUAL(4UL, counts[0]); + // These are no longer set because the discounts are bad. +/* BOOST_CHECK_EQUAL(4UL, counts[1]); + BOOST_CHECK_EQUAL(3UL, counts[2]); + BOOST_CHECK_EQUAL(3UL, counts[3]);*/ + BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size()); + NGram uni(outputs[0].Get(), 1); + BOOST_CHECK_EQUAL(kUNK, *uni.begin()); + BOOST_CHECK_EQUAL(0ULL, uni.Count()); + uni.NextInMemory(); + BOOST_CHECK_EQUAL(kBOS, *uni.begin()); + BOOST_CHECK_EQUAL(0ULL, uni.Count()); + uni.NextInMemory(); + BOOST_CHECK_EQUAL(0UL, *uni.begin()); + BOOST_CHECK_EQUAL(2ULL, uni.Count()); + uni.NextInMemory(); + BOOST_CHECK_EQUAL(2ULL, uni.Count()); + BOOST_CHECK_EQUAL(2UL, *uni.begin()); + + BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size()); + NGram bi(outputs[1].Get(), 2); + BOOST_CHECK_EQUAL(0UL, *bi.begin()); + BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1)); + BOOST_CHECK_EQUAL(1ULL, bi.Count()); + bi.NextInMemory(); +} + +}}} // namespaces diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc new file mode 100644 index 00000000..8c3de57d --- /dev/null +++ b/klm/lm/builder/corpus_count.cc @@ -0,0 +1,223 @@ +#include "lm/builder/corpus_count.hh" + +#include "lm/builder/ngram.hh" +#include "lm/lm_exception.hh" +#include "lm/word_index.hh" +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/murmur_hash.hh" +#include "util/probing_hash_table.hh" +#include "util/scoped.hh" +#include "util/stream/chain.hh" +#include "util/stream/timer.hh" +#include "util/tokenize_piece.hh" + +#include +#include + +#include + +#include + +namespace lm { +namespace builder { +namespace { + +class VocabHandout { + public: + explicit VocabHandout(int fd) { + util::scoped_fd duped(util::DupOrThrow(fd)); + word_list_.reset(util::FDOpenOrThrow(duped)); + + Lookup(""); // Force 0 + Lookup(""); // Force 1 + Lookup(""); // Force 2 + } + + WordIndex Lookup(const StringPiece &word) { + uint64_t hashed = util::MurmurHashNative(word.data(), word.size()); + std::pair ret(seen_.insert(std::pair(hashed, seen_.size()))); + if (ret.second) { + char null_delimit = 0; + util::WriteOrThrow(word_list_.get(), word.data(), word.size()); + util::WriteOrThrow(word_list_.get(), &null_delimit, 1); + UTIL_THROW_IF(seen_.size() >= std::numeric_limits::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); + } + return ret.first->second; + } + + WordIndex Size() const { + return seen_.size(); + } + + private: + typedef boost::unordered_map Seen; + + Seen seen_; + + util::scoped_FILE word_list_; +}; + +class DedupeHash : public std::unary_function { + public: + explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} + + std::size_t operator()(const WordIndex *start) const { + return util::MurmurHashNative(start, size_); + } + + private: + const std::size_t size_; +}; + +class DedupeEquals : public std::binary_function { + public: + explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {} + + bool operator()(const WordIndex *first, const WordIndex *second) const { + return !memcmp(first, second, size_); + } + + private: + const std::size_t size_; +}; + +struct DedupeEntry { + typedef WordIndex *Key; + Key GetKey() const { return key; } + Key key; + static DedupeEntry Construct(WordIndex *at) { + DedupeEntry ret; + ret.key = at; + return ret; + } +}; + +typedef util::ProbingHashTable Dedupe; + +const float kProbingMultiplier = 1.5; + +class Writer { + public: + Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size) + : block_(position), gram_(block_->Get(), order), + dedupe_invalid_(order, std::numeric_limits::max()), + dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)), + buffer_(new WordIndex[order - 1]), + block_size_(position.GetChain().BlockSize()) { + dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size); + if (order == 1) { + // Add special words. AdjustCounts is responsible if order != 1. + AddUnigramWord(kUNK); + AddUnigramWord(kBOS); + } + } + + ~Writer() { + block_->SetValidSize(reinterpret_cast(gram_.begin()) - static_cast(block_->Get())); + (++block_).Poison(); + } + + // Write context with a bunch of + void StartSentence() { + for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) { + *i = kBOS; + } + } + + void Append(WordIndex word) { + *(gram_.end() - 1) = word; + Dedupe::MutableIterator at; + bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at); + if (found) { + // Already present. + NGram already(at->key, gram_.Order()); + ++(already.Count()); + // Shift left by one. + memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1)); + return; + } + // Complete the write. + gram_.Count() = 1; + // Prepare the next n-gram. + if (reinterpret_cast(gram_.begin()) + gram_.TotalSize() != static_cast(block_->Get()) + block_size_) { + NGram last(gram_); + gram_.NextInMemory(); + std::copy(last.begin() + 1, last.end(), gram_.begin()); + return; + } + // Block end. Need to store the context in a temporary buffer. + std::copy(gram_.begin() + 1, gram_.end(), buffer_.get()); + dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + block_->SetValidSize(block_size_); + gram_.ReBase((++block_)->Get()); + std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin()); + } + + private: + void AddUnigramWord(WordIndex index) { + *gram_.begin() = index; + gram_.Count() = 0; + gram_.NextInMemory(); + if (gram_.Base() == static_cast(block_->Get()) + block_size_) { + block_->SetValidSize(block_size_); + gram_.ReBase((++block_)->Get()); + } + } + + util::stream::Link block_; + + NGram gram_; + + // This is the memory behind the invalid value in dedupe_. + std::vector dedupe_invalid_; + // Hash table combiner implementation. + Dedupe dedupe_; + + // Small buffer to hold existing ngrams when shifting across a block boundary. + boost::scoped_array buffer_; + + const std::size_t block_size_; +}; + +} // namespace + +float CorpusCount::DedupeMultiplier(std::size_t order) { + return kProbingMultiplier * static_cast(sizeof(DedupeEntry)) / static_cast(NGram::TotalSize(order)); +} + +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) + : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), + dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), + dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { + token_count_ = 0; + type_count_ = 0; +} + +void CorpusCount::Run(const util::stream::ChainPosition &position) { + UTIL_TIMER("(%w s) Counted n-grams\n"); + + VocabHandout vocab(vocab_write_); + const WordIndex end_sentence = vocab.Lookup(""); + Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); + uint64_t count = 0; + try { + while(true) { + StringPiece line(from_.ReadLine()); + writer.StartSentence(); + for (util::TokenIter w(line, " \t"); w; ++w) { + WordIndex word = vocab.Lookup(*w); + UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing in the future."); + writer.Append(word); + ++count; + } + writer.Append(end_sentence); + } + } catch (const util::EndOfFileException &e) {} + token_count_ = count; + type_count_ = vocab.Size(); +} + +} // namespace builder +} // namespace lm diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh new file mode 100644 index 00000000..e255bad1 --- /dev/null +++ b/klm/lm/builder/corpus_count.hh @@ -0,0 +1,42 @@ +#ifndef LM_BUILDER_CORPUS_COUNT__ +#define LM_BUILDER_CORPUS_COUNT__ + +#include "lm/word_index.hh" +#include "util/scoped.hh" + +#include +#include +#include + +namespace util { +class FilePiece; +namespace stream { +class ChainPosition; +} // namespace stream +} // namespace util + +namespace lm { +namespace builder { + +class CorpusCount { + public: + // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size + static float DedupeMultiplier(std::size_t order); + + CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block); + + void Run(const util::stream::ChainPosition &position); + + private: + util::FilePiece &from_; + int vocab_write_; + uint64_t &token_count_; + WordIndex &type_count_; + + std::size_t dedupe_mem_size_; + util::scoped_malloc dedupe_mem_; +}; + +} // namespace builder +} // namespace lm +#endif // LM_BUILDER_CORPUS_COUNT__ diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc new file mode 100644 index 00000000..8d53ca9d --- /dev/null +++ b/klm/lm/builder/corpus_count_test.cc @@ -0,0 +1,76 @@ +#include "lm/builder/corpus_count.hh" + +#include "lm/builder/ngram.hh" +#include "lm/builder/ngram_stream.hh" + +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/tokenize_piece.hh" +#include "util/stream/chain.hh" +#include "util/stream/stream.hh" + +#define BOOST_TEST_MODULE CorpusCountTest +#include + +namespace lm { namespace builder { namespace { + +#define Check(str, count) { \ + BOOST_REQUIRE(stream); \ + w = stream->begin(); \ + for (util::TokenIter t(str, " "); t; ++t, ++w) { \ + BOOST_CHECK_EQUAL(*t, v[*w]); \ + } \ + BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \ + ++stream; \ +} + +BOOST_AUTO_TEST_CASE(Short) { + util::scoped_fd input_file(util::MakeTemp("corpus_count_test_temp")); + const char input[] = "looking on a little more loin\non a little more loin\non foo little more loin\nbar\n\n"; + // Blocks of 10 are + // looking on a little more loin on a little[duplicate] more[duplicate] loin[duplicate] [duplicate] on[duplicate] foo + // little more loin bar + + util::WriteOrThrow(input_file.get(), input, sizeof(input) - 1); + util::FilePiece input_piece(input_file.release(), "temp file"); + + util::stream::ChainConfig config; + config.entry_size = NGram::TotalSize(3); + config.total_memory = config.entry_size * 20; + config.block_count = 2; + + util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab")); + + util::stream::Chain chain(config); + NGramStream stream; + uint64_t token_count; + WordIndex type_count; + CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize()); + chain >> boost::ref(counter) >> stream >> util::stream::kRecycle; + + const char *v[] = {"", "", "", "looking", "on", "a", "little", "more", "loin", "foo", "bar"}; + + WordIndex *w; + + Check(" looking", 1); + Check(" looking on", 1); + Check("looking on a", 1); + Check("on a little", 2); + Check("a little more", 2); + Check("little more loin", 2); + Check("more loin ", 2); + Check(" on", 2); + Check(" on a", 1); + Check(" on foo", 1); + Check("on foo little", 1); + Check("foo little more", 1); + Check("little more loin", 1); + Check("more loin ", 1); + Check(" bar", 1); + Check(" bar ", 1); + Check(" ", 1); + BOOST_CHECK(!stream); + BOOST_CHECK_EQUAL(sizeof(v) / sizeof(const char*), type_count); +} + +}}} // namespaces diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh new file mode 100644 index 00000000..754fb20d --- /dev/null +++ b/klm/lm/builder/discount.hh @@ -0,0 +1,26 @@ +#ifndef BUILDER_DISCOUNT__ +#define BUILDER_DISCOUNT__ + +#include + +#include + +namespace lm { +namespace builder { + +struct Discount { + float amount[4]; + + float Get(uint64_t count) const { + return amount[std::min(count, 3)]; + } + + float Apply(uint64_t count) const { + return static_cast(count) - Get(count); + } +}; + +} // namespace builder +} // namespace lm + +#endif // BUILDER_DISCOUNT__ diff --git a/klm/lm/builder/header_info.hh b/klm/lm/builder/header_info.hh new file mode 100644 index 00000000..ccca1456 --- /dev/null +++ b/klm/lm/builder/header_info.hh @@ -0,0 +1,20 @@ +#ifndef LM_BUILDER_HEADER_INFO__ +#define LM_BUILDER_HEADER_INFO__ + +#include +#include + +// Some configuration info that is used to add +// comments to the beginning of an ARPA file +struct HeaderInfo { + const std::string input_file; + const uint64_t token_count; + + HeaderInfo(const std::string& input_file_in, uint64_t token_count_in) + : input_file(input_file_in), token_count(token_count_in) {} + + // TODO: Add smoothing type + // TODO: More info if multiple models were interpolated +}; + +#endif diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc new file mode 100644 index 00000000..58b42a20 --- /dev/null +++ b/klm/lm/builder/initial_probabilities.cc @@ -0,0 +1,136 @@ +#include "lm/builder/initial_probabilities.hh" + +#include "lm/builder/discount.hh" +#include "lm/builder/ngram_stream.hh" +#include "lm/builder/sort.hh" +#include "util/file.hh" +#include "util/stream/chain.hh" +#include "util/stream/io.hh" +#include "util/stream/stream.hh" + +#include + +namespace lm { namespace builder { + +namespace { +struct BufferEntry { + // Gamma from page 20 of Chen and Goodman. + float gamma; + // \sum_w a(c w) for all w. + float denominator; +}; + +// Extract an array of gamma from an array of BufferEntry. +class OnlyGamma { + public: + void Run(const util::stream::ChainPosition &position) { + for (util::stream::Link block_it(position); block_it; ++block_it) { + float *out = static_cast(block_it->Get()); + const float *in = out; + const float *end = static_cast(block_it->ValidEnd()); + for (out += 1, in += 2; in < end; out += 1, in += 2) { + *out = *in; + } + block_it->SetValidSize(block_it->ValidSize() / 2); + } + } +}; + +class AddRight { + public: + AddRight(const Discount &discount, const util::stream::ChainPosition &input) + : discount_(discount), input_(input) {} + + void Run(const util::stream::ChainPosition &output) { + NGramStream in(input_); + util::stream::Stream out(output); + + std::vector previous(in->Order() - 1); + const std::size_t size = sizeof(WordIndex) * previous.size(); + for(; in; ++out) { + memcpy(&previous[0], in->begin(), size); + uint64_t denominator = 0; + uint64_t counts[4]; + memset(counts, 0, sizeof(counts)); + do { + denominator += in->Count(); + ++counts[std::min(in->Count(), static_cast(3))]; + } while (++in && !memcmp(&previous[0], in->begin(), size)); + BufferEntry &entry = *reinterpret_cast(out.Get()); + entry.denominator = static_cast(denominator); + entry.gamma = 0.0; + for (unsigned i = 1; i <= 3; ++i) { + entry.gamma += discount_.Get(i) * static_cast(counts[i]); + } + entry.gamma /= entry.denominator; + } + out.Poison(); + } + + private: + const Discount &discount_; + const util::stream::ChainPosition input_; +}; + +class MergeRight { + public: + MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount) + : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {} + + // calculate the initial probability of each n-gram (before order-interpolation) + // Run() gets invoked once for each order + void Run(const util::stream::ChainPosition &primary) { + util::stream::Stream summed(from_adder_); + + NGramStream grams(primary); + + // Without interpolation, the interpolation weight goes to . + if (grams->Order() == 1 && !interpolate_unigrams_) { + BufferEntry sums(*static_cast(summed.Get())); + assert(*grams->begin() == kUNK); + grams->Value().uninterp.prob = sums.gamma; + grams->Value().uninterp.gamma = 0.0; + while (++grams) { + grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator; + grams->Value().uninterp.gamma = 0.0; + } + ++summed; + return; + } + + std::vector previous(grams->Order() - 1); + const std::size_t size = sizeof(WordIndex) * previous.size(); + for (; grams; ++summed) { + memcpy(&previous[0], grams->begin(), size); + const BufferEntry &sums = *static_cast(summed.Get()); + do { + Payload &pay = grams->Value(); + pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator; + pay.uninterp.gamma = sums.gamma; + } while (++grams && !memcmp(&previous[0], grams->begin(), size)); + } + } + + private: + bool interpolate_unigrams_; + util::stream::ChainPosition from_adder_; + Discount discount_; +}; + +} // namespace + +void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) { + util::stream::ChainConfig gamma_config = config.adder_out; + gamma_config.entry_size = sizeof(BufferEntry); + for (size_t i = 0; i < primary.size(); ++i) { + util::stream::ChainPosition second(second_in[i].Add()); + second_in[i] >> util::stream::kRecycle; + gamma_out.push_back(gamma_config); + gamma_out[i] >> AddRight(discounts[i], second); + primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); + // Don't bother with the OnlyGamma thread for something to discard. + if (i) gamma_out[i] >> OnlyGamma(); + } +} + +}} // namespaces diff --git a/klm/lm/builder/initial_probabilities.hh b/klm/lm/builder/initial_probabilities.hh new file mode 100644 index 00000000..626388eb --- /dev/null +++ b/klm/lm/builder/initial_probabilities.hh @@ -0,0 +1,34 @@ +#ifndef LM_BUILDER_INITIAL_PROBABILITIES__ +#define LM_BUILDER_INITIAL_PROBABILITIES__ + +#include "lm/builder/discount.hh" +#include "util/stream/config.hh" + +#include + +namespace lm { +namespace builder { +class Chains; + +struct InitialProbabilitiesConfig { + // These should be small buffers to keep the adder from getting too far ahead + util::stream::ChainConfig adder_in; + util::stream::ChainConfig adder_out; + // SRILM doesn't normally interpolate unigrams. + bool interpolate_unigrams; +}; + +/* Compute initial (uninterpolated) probabilities + * primary: the normal chain of n-grams. Incoming is context sorted adjusted + * counts. Outgoing has uninterpolated probabilities for use by Interpolate. + * second_in: a second copy of the primary input. Discard the output. + * gamma_out: Computed gamma values are output on these chains in suffix order. + * The values are bare floats and should be buffered for interpolation to + * use. + */ +void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector &discounts, Chains &primary, Chains &second_in, Chains &gamma_out); + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_INITIAL_PROBABILITIES__ diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc new file mode 100644 index 00000000..50026806 --- /dev/null +++ b/klm/lm/builder/interpolate.cc @@ -0,0 +1,65 @@ +#include "lm/builder/interpolate.hh" + +#include "lm/builder/joint_order.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/sort.hh" +#include "lm/lm_exception.hh" + +#include + +namespace lm { namespace builder { +namespace { + +class Callback { + public: + Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { + probs_[0] = uniform_prob; + for (std::size_t i = 0; i < backoffs.size(); ++i) { + backoffs_.push_back(backoffs[i]); + } + } + + ~Callback() { + for (std::size_t i = 0; i < backoffs_.size(); ++i) { + if (backoffs_[i]) { + std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl; + abort(); + } + } + } + + void Enter(unsigned order_minus_1, NGram &gram) { + Payload &pay = gram.Value(); + pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; + probs_[order_minus_1 + 1] = pay.complete.prob; + pay.complete.prob = log10(pay.complete.prob); + // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. + if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { + pay.complete.backoff = log10(*static_cast(backoffs_[order_minus_1].Get())); + ++backoffs_[order_minus_1]; + } else { + // Not a context. + pay.complete.backoff = 0.0; + } + } + + void Exit(unsigned, const NGram &) const {} + + private: + FixedArray backoffs_; + + std::vector probs_; +}; +} // namespace + +Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) + : uniform_prob_(1.0 / static_cast(unigram_count - 1)), backoffs_(backoffs) {} + +// perform order-wise interpolation +void Interpolate::Run(const ChainPositions &positions) { + assert(positions.size() == backoffs_.size() + 1); + Callback callback(uniform_prob_, backoffs_); + JointOrder(positions, callback); +} + +}} // namespaces diff --git a/klm/lm/builder/interpolate.hh b/klm/lm/builder/interpolate.hh new file mode 100644 index 00000000..9268d404 --- /dev/null +++ b/klm/lm/builder/interpolate.hh @@ -0,0 +1,27 @@ +#ifndef LM_BUILDER_INTERPOLATE__ +#define LM_BUILDER_INTERPOLATE__ + +#include + +#include "lm/builder/multi_stream.hh" + +namespace lm { namespace builder { + +/* Interpolate step. + * Input: suffix sorted n-grams with (p_uninterpolated, gamma) from + * InitialProbabilities. + * Output: suffix sorted n-grams with complete probability + */ +class Interpolate { + public: + explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs); + + void Run(const ChainPositions &positions); + + private: + float uniform_prob_; + ChainPositions backoffs_; +}; + +}} // namespaces +#endif // LM_BUILDER_INTERPOLATE__ diff --git a/klm/lm/builder/joint_order.hh b/klm/lm/builder/joint_order.hh new file mode 100644 index 00000000..b5620144 --- /dev/null +++ b/klm/lm/builder/joint_order.hh @@ -0,0 +1,43 @@ +#ifndef LM_BUILDER_JOINT_ORDER__ +#define LM_BUILDER_JOINT_ORDER__ + +#include "lm/builder/multi_stream.hh" +#include "lm/lm_exception.hh" + +#include + +namespace lm { namespace builder { + +template void JointOrder(const ChainPositions &positions, Callback &callback) { + // Allow matching to reference streams[-1]. + NGramStreams streams_with_dummy; + streams_with_dummy.InitWithDummy(positions); + NGramStream *streams = streams_with_dummy.begin() + 1; + + unsigned int order; + for (order = 0; order < positions.size() && streams[order]; ++order) {} + assert(order); // should always have . + unsigned int current = 0; + while (true) { + // Does the context match the lower one? + if (!memcmp(streams[static_cast(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) { + callback.Enter(current, *streams[current]); + // Transition to looking for extensions. + if (++current < order) continue; + } + // No extension left. + while(true) { + assert(current > 0); + --current; + callback.Exit(current, *streams[current]); + if (++streams[current]) break; + UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix"); + order = current; + if (!order) return; + } + } +} + +}} // namespaces + +#endif // LM_BUILDER_JOINT_ORDER__ diff --git a/klm/lm/builder/main.cc b/klm/lm/builder/main.cc new file mode 100644 index 00000000..90b9dca2 --- /dev/null +++ b/klm/lm/builder/main.cc @@ -0,0 +1,94 @@ +#include "lm/builder/pipeline.hh" +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/usage.hh" + +#include + +#include + +namespace { +class SizeNotify { + public: + SizeNotify(std::size_t &out) : behind_(out) {} + + void operator()(const std::string &from) { + behind_ = util::ParseSize(from); + } + + private: + std::size_t &behind_; +}; + +boost::program_options::typed_value *SizeOption(std::size_t &to, const char *default_value) { + return boost::program_options::value()->notifier(SizeNotify(to))->default_value(default_value); +} + +} // namespace + +int main(int argc, char *argv[]) { + try { + namespace po = boost::program_options; + po::options_description options("Language model building options"); + lm::builder::PipelineConfig pipeline; + + options.add_options() + ("order,o", po::value(&pipeline.order)->required(), "Order of the model") + ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") + ("temp_prefix,T", po::value(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") + ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") + ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step") + ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") + ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") + ("block_count", po::value(&pipeline.block_count)->default_value(2), "Block count (per order)") + ("vocab_file", po::value(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") + ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc."); + if (argc == 1) { + std::cerr << + "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n" + "Please cite:\n" + "@inproceedings{kenlm,\n" + "author = {Kenneth Heafield},\n" + "title = {{KenLM}: Faster and Smaller Language Model Queries},\n" + "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n" + "month = {July}, year={2011},\n" + "address = {Edinburgh, UK},\n" + "publisher = {Association for Computational Linguistics},\n" + "}\n\n" + "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n" + "the model (-o) is the only mandatory option. As this is an on-disk program,\n" + "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n" + "Memory sizes are specified like GNU sort: a number followed by a unit character.\n" + "Valid units are \% for percentage of memory (supported platforms only) and (in\n" + "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n"; + std::cerr << options << std::endl; + return 1; + } + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, options), vm); + po::notify(vm); + + util::NormalizeTempPrefix(pipeline.sort.temp_prefix); + + lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; + // TODO: evaluate options for these. + initial.adder_in.total_memory = 32768; + initial.adder_in.block_count = 2; + initial.adder_out.total_memory = 32768; + initial.adder_out.block_count = 2; + pipeline.read_backoffs = initial.adder_out; + + // Read from stdin + try { + lm::builder::Pipeline(pipeline, 0, 1); + } catch (const util::MallocException &e) { + std::cerr << e.what() << std::endl; + std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as() << std::endl; + return 1; + } + util::PrintUsage(std::cerr); + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; + } +} diff --git a/klm/lm/builder/multi_stream.hh b/klm/lm/builder/multi_stream.hh new file mode 100644 index 00000000..707a98c7 --- /dev/null +++ b/klm/lm/builder/multi_stream.hh @@ -0,0 +1,180 @@ +#ifndef LM_BUILDER_MULTI_STREAM__ +#define LM_BUILDER_MULTI_STREAM__ + +#include "lm/builder/ngram_stream.hh" +#include "util/scoped.hh" +#include "util/stream/chain.hh" + +#include +#include + +#include +#include + +namespace lm { namespace builder { + +template class FixedArray { + public: + explicit FixedArray(std::size_t count) { + Init(count); + } + + FixedArray() : newed_end_(NULL) {} + + void Init(std::size_t count) { + assert(!block_.get()); + block_.reset(malloc(sizeof(T) * count)); + if (!block_.get()) throw std::bad_alloc(); + newed_end_ = begin(); + } + + FixedArray(const FixedArray &from) { + std::size_t size = from.newed_end_ - static_cast(from.block_.get()); + Init(size); + for (std::size_t i = 0; i < size; ++i) { + new(end()) T(from[i]); + Constructed(); + } + } + + ~FixedArray() { clear(); } + + T *begin() { return static_cast(block_.get()); } + const T *begin() const { return static_cast(block_.get()); } + // Always call Constructed after successful completion of new. + T *end() { return newed_end_; } + const T *end() const { return newed_end_; } + + T &back() { return *(end() - 1); } + const T &back() const { return *(end() - 1); } + + std::size_t size() const { return end() - begin(); } + bool empty() const { return begin() == end(); } + + T &operator[](std::size_t i) { return begin()[i]; } + const T &operator[](std::size_t i) const { return begin()[i]; } + + template void push_back(const C &c) { + new (end()) T(c); + Constructed(); + } + + void clear() { + for (T *i = begin(); i != end(); ++i) + i->~T(); + newed_end_ = begin(); + } + + protected: + void Constructed() { + ++newed_end_; + } + + private: + util::scoped_malloc block_; + + T *newed_end_; +}; + +class Chains; + +class ChainPositions : public FixedArray { + public: + ChainPositions() {} + + void Init(Chains &chains); + + explicit ChainPositions(Chains &chains) { + Init(chains); + } +}; + +class Chains : public FixedArray { + private: + template struct CheckForRun { + typedef Chains type; + }; + + public: + explicit Chains(std::size_t limit) : FixedArray(limit) {} + + template typename CheckForRun::type &operator>>(const Worker &worker) { + threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); + return *this; + } + + template typename CheckForRun::type &operator>>(const boost::reference_wrapper &worker) { + threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); + return *this; + } + + Chains &operator>>(const util::stream::Recycler &recycler) { + for (util::stream::Chain *i = begin(); i != end(); ++i) + *i >> recycler; + return *this; + } + + void Wait(bool release_memory = true) { + threads_.clear(); + for (util::stream::Chain *i = begin(); i != end(); ++i) { + i->Wait(release_memory); + } + } + + private: + boost::ptr_vector threads_; + + Chains(const Chains &); + void operator=(const Chains &); +}; + +inline void ChainPositions::Init(Chains &chains) { + FixedArray::Init(chains.size()); + for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) { + new (end()) util::stream::ChainPosition(i->Add()); Constructed(); + } +} + +inline Chains &operator>>(Chains &chains, ChainPositions &positions) { + positions.Init(chains); + return chains; +} + +class NGramStreams : public FixedArray { + public: + NGramStreams() {} + + // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning). + void InitWithDummy(const ChainPositions &positions) { + FixedArray::Init(positions.size() + 1); + new (end()) NGramStream(); Constructed(); + for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) { + push_back(*i); + } + } + + // Limit restricts to positions[0,limit) + void Init(const ChainPositions &positions, std::size_t limit) { + FixedArray::Init(limit); + for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) { + push_back(*i); + } + } + void Init(const ChainPositions &positions) { + Init(positions, positions.size()); + } + + NGramStreams(const ChainPositions &positions) { + Init(positions); + } +}; + +inline Chains &operator>>(Chains &chains, NGramStreams &streams) { + ChainPositions positions; + chains >> positions; + streams.Init(positions); + return chains; +} + +}} // namespaces +#endif // LM_BUILDER_MULTI_STREAM__ diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh new file mode 100644 index 00000000..2984ed0b --- /dev/null +++ b/klm/lm/builder/ngram.hh @@ -0,0 +1,84 @@ +#ifndef LM_BUILDER_NGRAM__ +#define LM_BUILDER_NGRAM__ + +#include "lm/weights.hh" +#include "lm/word_index.hh" + +#include + +#include +#include +#include + +namespace lm { +namespace builder { + +struct Uninterpolated { + float prob; // Uninterpolated probability. + float gamma; // Interpolation weight for lower order. +}; + +union Payload { + uint64_t count; + Uninterpolated uninterp; + ProbBackoff complete; +}; + +class NGram { + public: + NGram(void *begin, std::size_t order) + : begin_(static_cast(begin)), end_(begin_ + order) {} + + const uint8_t *Base() const { return reinterpret_cast(begin_); } + uint8_t *Base() { return reinterpret_cast(begin_); } + + void ReBase(void *to) { + std::size_t difference = end_ - begin_; + begin_ = reinterpret_cast(to); + end_ = begin_ + difference; + } + + // Would do operator++ but that can get confusing for a stream. + void NextInMemory() { + ReBase(&Value() + 1); + } + + // Lower-case in deference to STL. + const WordIndex *begin() const { return begin_; } + WordIndex *begin() { return begin_; } + const WordIndex *end() const { return end_; } + WordIndex *end() { return end_; } + + const Payload &Value() const { return *reinterpret_cast(end_); } + Payload &Value() { return *reinterpret_cast(end_); } + + uint64_t &Count() { return Value().count; } + const uint64_t Count() const { return Value().count; } + + std::size_t Order() const { return end_ - begin_; } + + static std::size_t TotalSize(std::size_t order) { + return order * sizeof(WordIndex) + sizeof(Payload); + } + std::size_t TotalSize() const { + // Compiler should optimize this. + return TotalSize(Order()); + } + static std::size_t OrderFromSize(std::size_t size) { + std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex); + assert(size == TotalSize(ret)); + return ret; + } + + private: + WordIndex *begin_, *end_; +}; + +const WordIndex kUNK = 0; +const WordIndex kBOS = 1; +const WordIndex kEOS = 2; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_NGRAM__ diff --git a/klm/lm/builder/ngram_stream.hh b/klm/lm/builder/ngram_stream.hh new file mode 100644 index 00000000..3c994664 --- /dev/null +++ b/klm/lm/builder/ngram_stream.hh @@ -0,0 +1,55 @@ +#ifndef LM_BUILDER_NGRAM_STREAM__ +#define LM_BUILDER_NGRAM_STREAM__ + +#include "lm/builder/ngram.hh" +#include "util/stream/chain.hh" +#include "util/stream/stream.hh" + +#include + +namespace lm { namespace builder { + +class NGramStream { + public: + NGramStream() : gram_(NULL, 0) {} + + NGramStream(const util::stream::ChainPosition &position) : gram_(NULL, 0) { + Init(position); + } + + void Init(const util::stream::ChainPosition &position) { + stream_.Init(position); + gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize())); + } + + NGram &operator*() { return gram_; } + const NGram &operator*() const { return gram_; } + + NGram *operator->() { return &gram_; } + const NGram *operator->() const { return &gram_; } + + void *Get() { return stream_.Get(); } + const void *Get() const { return stream_.Get(); } + + operator bool() const { return stream_; } + bool operator!() const { return !stream_; } + void Poison() { stream_.Poison(); } + + NGramStream &operator++() { + ++stream_; + gram_.ReBase(stream_.Get()); + return *this; + } + + private: + NGram gram_; + util::stream::Stream stream_; +}; + +inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) { + str.Init(chain.Add()); + return chain; +} + +}} // namespaces +#endif // LM_BUILDER_NGRAM_STREAM__ diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc new file mode 100644 index 00000000..14a1f721 --- /dev/null +++ b/klm/lm/builder/pipeline.cc @@ -0,0 +1,320 @@ +#include "lm/builder/pipeline.hh" + +#include "lm/builder/adjust_counts.hh" +#include "lm/builder/corpus_count.hh" +#include "lm/builder/initial_probabilities.hh" +#include "lm/builder/interpolate.hh" +#include "lm/builder/print.hh" +#include "lm/builder/sort.hh" + +#include "lm/sizes.hh" + +#include "util/exception.hh" +#include "util/file.hh" +#include "util/stream/io.hh" + +#include +#include +#include + +namespace lm { namespace builder { + +namespace { +void PrintStatistics(const std::vector &counts, const std::vector &discounts) { + std::cerr << "Statistics:\n"; + for (size_t i = 0; i < counts.size(); ++i) { + std::cerr << (i + 1) << ' ' << counts[i]; + for (size_t d = 1; d <= 3; ++d) + std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d]; + std::cerr << '\n'; + } +} + +class Master { + public: + explicit Master(const PipelineConfig &config) + : config_(config), chains_(config.order), files_(config.order) { + config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block); + } + + const PipelineConfig &Config() const { return config_; } + + Chains &MutableChains() { return chains_; } + + template Master &operator>>(const T &worker) { + chains_ >> worker; + return *this; + } + + // This takes the (partially) sorted ngrams and sets up for adjusted counts. + void InitForAdjust(util::stream::Sort &ngrams, WordIndex types) { + const std::size_t each_order_min = config_.minimum_block * config_.block_count; + // We know how many unigrams there are. Don't allocate more than needed to them. + const std::size_t min_chains = (config_.order - 1) * each_order_min + + std::min(types * NGram::TotalSize(1), each_order_min); + // Do merge sort with calculated laziness. + const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy())); + + std::vector count_bounds(1, types); + CreateChains(config_.TotalMemory() - merge_using, count_bounds); + ngrams.Output(chains_.back(), merge_using); + + // Setup unigram file. + files_.push_back(util::MakeTemp(config_.TempPrefix())); + } + + // For initial probabilities, but this is generic. + void SortAndReadTwice(const std::vector &counts, Sorts &sorts, Chains &second, util::stream::ChainConfig second_config) { + // Do merge first before allocating chain memory. + for (std::size_t i = 1; i < config_.order; ++i) { + sorts[i - 1].Merge(0); + } + // There's no lazy merge, so just divide memory amongst the chains. + CreateChains(config_.TotalMemory(), counts); + chains_.back().ActivateProgress(); + chains_[0] >> files_[0].Source(); + second_config.entry_size = NGram::TotalSize(1); + second.push_back(second_config); + second.back() >> files_[0].Source(); + for (std::size_t i = 1; i < config_.order; ++i) { + util::scoped_fd fd(sorts[i - 1].StealCompleted()); + chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get())); + chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true); + second_config.entry_size = NGram::TotalSize(i + 1); + second.push_back(second_config); + second.back() >> util::stream::PRead(fd.release(), true); + } + } + + // There is no sort after this, so go for broke on lazy merging. + template void MaximumLazyInput(const std::vector &counts, Sorts &sorts) { + // Determine the minimum we can use for all the chains. + std::size_t min_chains = 0; + for (std::size_t i = 0; i < config_.order; ++i) { + min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast(config_.minimum_block)); + } + std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains); + std::vector laziness; + // Prioritize longer n-grams. + for (util::stream::Sort *i = sorts.end() - 1; i >= sorts.begin(); --i) { + laziness.push_back(i->Merge(for_merge)); + assert(for_merge >= laziness.back()); + for_merge -= laziness.back(); + } + std::reverse(laziness.begin(), laziness.end()); + + CreateChains(for_merge + min_chains, counts); + chains_.back().ActivateProgress(); + chains_[0] >> files_[0].Source(); + for (std::size_t i = 1; i < config_.order; ++i) { + sorts[i - 1].Output(chains_[i], laziness[i - 1]); + } + } + + void BufferFinal(const std::vector &counts) { + chains_[0] >> files_[0].Sink(); + for (std::size_t i = 1; i < config_.order; ++i) { + files_.push_back(util::MakeTemp(config_.TempPrefix())); + chains_[i] >> files_[i].Sink(); + } + chains_.Wait(true); + // Use less memory. Because we can. + CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts); + for (std::size_t i = 0; i < config_.order; ++i) { + chains_[i] >> files_[i].Source(); + } + } + + template void SetupSorts(Sorts &sorts) { + sorts.Init(config_.order - 1); + // Unigrams don't get sorted because their order is always the same. + chains_[0] >> files_[0].Sink(); + for (std::size_t i = 1; i < config_.order; ++i) { + sorts.push_back(chains_[i], config_.sort, Compare(i + 1)); + } + chains_.Wait(true); + } + + private: + // Create chains, allocating memory to them. Totally heuristic. Count + // bounds are upper bounds on the counts or not present. + void CreateChains(std::size_t remaining_mem, const std::vector &count_bounds) { + std::vector assignments; + assignments.reserve(config_.order); + // Start by assigning maximum memory usage (to be refined later). + for (std::size_t i = 0; i < count_bounds.size(); ++i) { + assignments.push_back(static_cast(std::min( + static_cast(remaining_mem), + count_bounds[i] * static_cast(NGram::TotalSize(i + 1))))); + } + assignments.resize(config_.order, remaining_mem); + + // Now we know how much memory everybody wants. How much will they get? + // Proportional to this. + std::vector portions; + // Indices of orders that have yet to be assigned. + std::vector unassigned; + for (std::size_t i = 0; i < config_.order; ++i) { + portions.push_back(static_cast((i+1) * NGram::TotalSize(i+1))); + unassigned.push_back(i); + } + /*If somebody doesn't eat their full dinner, give it to the rest of the + * family. Then somebody else might not eat their full dinner etc. Ends + * when everybody unassigned is hungry. + */ + float sum; + bool found_more; + std::vector block_count(config_.order); + do { + sum = 0.0; + for (std::size_t i = 0; i < unassigned.size(); ++i) { + sum += portions[unassigned[i]]; + } + found_more = false; + // If the proportional assignment is more than needed, give it just what it needs. + for (std::vector::iterator i = unassigned.begin(); i != unassigned.end();) { + if (assignments[*i] <= remaining_mem * (portions[*i] / sum)) { + remaining_mem -= assignments[*i]; + block_count[*i] = 1; + i = unassigned.erase(i); + found_more = true; + } else { + ++i; + } + } + } while (found_more); + for (std::vector::iterator i = unassigned.begin(); i != unassigned.end(); ++i) { + assignments[*i] = remaining_mem * (portions[*i] / sum); + block_count[*i] = config_.block_count; + } + chains_.clear(); + std::cerr << "Chain sizes:"; + for (std::size_t i = 0; i < config_.order; ++i) { + std::cerr << ' ' << (i+1) << ":" << assignments[i]; + chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i])); + } + std::cerr << std::endl; + } + + PipelineConfig config_; + + Chains chains_; + // Often only unigrams, but sometimes all orders. + FixedArray files_; +}; + +void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) { + const PipelineConfig &config = master.Config(); + std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl; + + UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory()); + std::size_t memory_for_chain = + // This much memory to work with after vocab hash table. + static_cast(config.TotalMemory() - config.assume_vocab_hash_size) / + // Solve for block size including the dedupe multiplier for one block. + (static_cast(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) * + // Chain likes memory expressed in terms of total memory. + static_cast(config.block_count); + util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain)); + + WordIndex type_count; + util::FilePiece text(text_file, NULL, &std::cerr); + text_file_name = text.FileName(); + CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); + chain >> boost::ref(counter); + + util::stream::Sort sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); + chain.Wait(true); + std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl; + master.InitForAdjust(sorter, type_count); +} + +void InitialProbabilities(const std::vector &counts, const std::vector &discounts, Master &master, Sorts &primary, FixedArray &gammas) { + const PipelineConfig &config = master.Config(); + Chains second(config.order); + + { + Sorts sorts; + master.SetupSorts(sorts); + PrintStatistics(counts, discounts); + lm::ngram::ShowSizes(counts); + std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl; + master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in); + } + + Chains gamma_chains(config.order); + InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains); + // Don't care about gamma for 0. + gamma_chains[0] >> util::stream::kRecycle; + gammas.Init(config.order - 1); + for (std::size_t i = 1; i < config.order; ++i) { + gammas.push_back(util::MakeTemp(config.TempPrefix())); + gamma_chains[i] >> gammas[i - 1].Sink(); + } + // Has to be done here due to gamma_chains scope. + master.SetupSorts(primary); +} + +void InterpolateProbabilities(const std::vector &counts, Master &master, Sorts &primary, FixedArray &gammas) { + std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; + const PipelineConfig &config = master.Config(); + master.MaximumLazyInput(counts, primary); + + Chains gamma_chains(config.order - 1); + util::stream::ChainConfig read_backoffs(config.read_backoffs); + read_backoffs.entry_size = sizeof(float); + for (std::size_t i = 0; i < config.order - 1; ++i) { + gamma_chains.push_back(read_backoffs); + gamma_chains.back() >> gammas[i].Source(); + } + master >> Interpolate(counts[0], ChainPositions(gamma_chains)); + gamma_chains >> util::stream::kRecycle; + master.BufferFinal(counts); +} + +} // namespace + +void Pipeline(PipelineConfig config, int text_file, int out_arpa) { + // Some fail-fast sanity checks. + if (config.sort.buffer_size * 4 > config.TotalMemory()) { + config.sort.buffer_size = config.TotalMemory() / 4; + std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl; + } + if (config.minimum_block < NGram::TotalSize(config.order)) { + config.minimum_block = NGram::TotalSize(config.order); + std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl; + } + UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << "."); + UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception, + "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size."); + + UTIL_TIMER("(%w s) Total wall time elapsed\n"); + Master master(config); + + util::scoped_fd vocab_file(config.vocab_file.empty() ? + util::MakeTemp(config.TempPrefix()) : + util::CreateOrThrow(config.vocab_file.c_str())); + uint64_t token_count; + std::string text_file_name; + CountText(text_file, vocab_file.get(), master, token_count, text_file_name); + + std::vector counts; + std::vector discounts; + master >> AdjustCounts(counts, discounts); + + { + FixedArray gammas; + Sorts primary; + InitialProbabilities(counts, discounts, master, primary, gammas); + InterpolateProbabilities(counts, master, primary, gammas); + } + + std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; + VocabReconstitute vocab(vocab_file.get()); + UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); + HeaderInfo header_info(text_file_name, token_count); + master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; + master.MutableChains().Wait(true); +} + +}} // namespaces diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh new file mode 100644 index 00000000..f1d6c5f6 --- /dev/null +++ b/klm/lm/builder/pipeline.hh @@ -0,0 +1,40 @@ +#ifndef LM_BUILDER_PIPELINE__ +#define LM_BUILDER_PIPELINE__ + +#include "lm/builder/initial_probabilities.hh" +#include "lm/builder/header_info.hh" +#include "util/stream/config.hh" +#include "util/file_piece.hh" + +#include +#include + +namespace lm { namespace builder { + +struct PipelineConfig { + std::size_t order; + std::string vocab_file; + util::stream::SortConfig sort; + InitialProbabilitiesConfig initial_probs; + util::stream::ChainConfig read_backoffs; + bool verbose_header; + + // Amount of memory to assume that the vocabulary hash table will use. This + // is subtracted from total memory for CorpusCount. + std::size_t assume_vocab_hash_size; + + // Minimum block size to tolerate. + std::size_t minimum_block; + + // Number of blocks to use. This will be overridden to 1 if everything fits. + std::size_t block_count; + + const std::string &TempPrefix() const { return sort.temp_prefix; } + std::size_t TotalMemory() const { return sort.total_memory; } +}; + +// Takes ownership of text_file. +void Pipeline(PipelineConfig config, int text_file, int out_arpa); + +}} // namespaces +#endif // LM_BUILDER_PIPELINE__ diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc new file mode 100644 index 00000000..b0323221 --- /dev/null +++ b/klm/lm/builder/print.cc @@ -0,0 +1,135 @@ +#include "lm/builder/print.hh" + +#include "util/double-conversion/double-conversion.h" +#include "util/double-conversion/utils.h" +#include "util/file.hh" +#include "util/mmap.hh" +#include "util/scoped.hh" +#include "util/stream/timer.hh" + +#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE +#include + +#include + +#include + +namespace lm { namespace builder { + +VocabReconstitute::VocabReconstitute(int fd) { + uint64_t size = util::SizeOrThrow(fd); + util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_); + const char *const start = static_cast(memory_.get()); + const char *i; + for (i = start; i != start + size; i += strlen(i) + 1) { + map_.push_back(i); + } + // Last one for LookupPiece. + map_.push_back(i); +} + +namespace { +class OutputManager { + public: + static const std::size_t kOutBuf = 1048576; + + // Does not take ownership of out. + explicit OutputManager(int out) + : buf_(util::MallocOrThrow(kOutBuf)), + builder_(static_cast(buf_.get()), kOutBuf), + // Mostly the default but with inf instead. And no flags. + convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0), + fd_(out) {} + + ~OutputManager() { + Flush(); + } + + OutputManager &operator<<(float value) { + // Odd, but this is the largest number found in the comments. + EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); + convert_.ToShortestSingle(value, &builder_); + return *this; + } + + OutputManager &operator<<(StringPiece str) { + if (str.size() > kOutBuf) { + Flush(); + util::WriteOrThrow(fd_, str.data(), str.size()); + } else { + EnsureRemaining(str.size()); + builder_.AddSubstring(str.data(), str.size()); + } + return *this; + } + + // Inefficient! + OutputManager &operator<<(unsigned val) { + return *this << boost::lexical_cast(val); + } + + OutputManager &operator<<(char c) { + EnsureRemaining(1); + builder_.AddCharacter(c); + return *this; + } + + void Flush() { + util::WriteOrThrow(fd_, buf_.get(), builder_.position()); + builder_.Reset(); + } + + private: + void EnsureRemaining(std::size_t amount) { + if (static_cast(builder_.size() - builder_.position()) < amount) { + Flush(); + } + } + + util::scoped_malloc buf_; + double_conversion::StringBuilder builder_; + double_conversion::DoubleToStringConverter convert_; + int fd_; +}; +} // namespace + +PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector &counts, const HeaderInfo* header_info, int out_fd) + : vocab_(vocab), out_fd_(out_fd) { + std::stringstream stream; + + if (header_info) { + stream << "# Input file: " << header_info->input_file << '\n'; + stream << "# Token count: " << header_info->token_count << '\n'; + stream << "# Smoothing: Modified Kneser-Ney" << '\n'; + } + stream << "\\data\\\n"; + for (size_t i = 0; i < counts.size(); ++i) { + stream << "ngram " << (i+1) << '=' << counts[i] << '\n'; + } + stream << '\n'; + std::string as_string(stream.str()); + util::WriteOrThrow(out_fd, as_string.data(), as_string.size()); +} + +void PrintARPA::Run(const ChainPositions &positions) { + UTIL_TIMER("(%w s) Wrote ARPA file\n"); + OutputManager out(out_fd_); + for (unsigned order = 1; order <= positions.size(); ++order) { + out << "\\" << order << "-grams:" << '\n'; + for (NGramStream stream(positions[order - 1]); stream; ++stream) { + // Correcting for numerical precision issues. Take that IRST. + out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin()); + for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) { + out << ' ' << vocab_.Lookup(*i); + } + float backoff = stream->Value().complete.backoff; + if (backoff != 0.0) + out << '\t' << backoff; + out << '\n'; + } + out << '\n'; + } + out << "\\end\\\n"; +} + +}} // namespaces diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh new file mode 100644 index 00000000..aa932e75 --- /dev/null +++ b/klm/lm/builder/print.hh @@ -0,0 +1,102 @@ +#ifndef LM_BUILDER_PRINT__ +#define LM_BUILDER_PRINT__ + +#include "lm/builder/ngram.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/header_info.hh" +#include "util/file.hh" +#include "util/mmap.hh" +#include "util/string_piece.hh" + +#include + +#include + +// Warning: print routines read all unigrams before all bigrams before all +// trigrams etc. So if other parts of the chain move jointly, you'll have to +// buffer. + +namespace lm { namespace builder { + +class VocabReconstitute { + public: + // fd must be alive for life of this object; does not take ownership. + explicit VocabReconstitute(int fd); + + const char *Lookup(WordIndex index) const { + assert(index < map_.size() - 1); + return map_[index]; + } + + StringPiece LookupPiece(WordIndex index) const { + return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]); + } + + std::size_t Size() const { + // There's an extra entry to support StringPiece lengths. + return map_.size() - 1; + } + + private: + util::scoped_memory memory_; + std::vector map_; +}; + +// Not defined, only specialized. +template void PrintPayload(std::ostream &to, const Payload &payload); +template <> inline void PrintPayload(std::ostream &to, const Payload &payload) { + to << payload.count; +} +template <> inline void PrintPayload(std::ostream &to, const Payload &payload) { + to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma); +} +template <> inline void PrintPayload(std::ostream &to, const Payload &payload) { + to << payload.complete.prob << ' ' << payload.complete.backoff; +} + +// template parameter is the type stored. +template class Print { + public: + explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {} + + void Run(const ChainPositions &chains) { + NGramStreams streams(chains); + for (NGramStream *s = streams.begin(); s != streams.end(); ++s) { + DumpStream(*s); + } + } + + void Run(const util::stream::ChainPosition &position) { + NGramStream stream(position); + DumpStream(stream); + } + + private: + void DumpStream(NGramStream &stream) { + for (; stream; ++stream) { + PrintPayload(to_, stream->Value()); + for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) { + to_ << ' ' << vocab_.Lookup(*w) << '=' << *w; + } + to_ << '\n'; + } + } + + const VocabReconstitute &vocab_; + std::ostream &to_; +}; + +class PrintARPA { + public: + // header_info may be NULL to disable the header + explicit PrintARPA(const VocabReconstitute &vocab, const std::vector &counts, const HeaderInfo* header_info, int out_fd); + + void Run(const ChainPositions &positions); + + private: + const VocabReconstitute &vocab_; + int out_fd_; +}; + +}} // namespaces +#endif // LM_BUILDER_PRINT__ diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh new file mode 100644 index 00000000..9989389b --- /dev/null +++ b/klm/lm/builder/sort.hh @@ -0,0 +1,103 @@ +#ifndef LM_BUILDER_SORT__ +#define LM_BUILDER_SORT__ + +#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram.hh" +#include "lm/word_index.hh" +#include "util/stream/sort.hh" + +#include "util/stream/timer.hh" + +#include +#include + +namespace lm { +namespace builder { + +template class Comparator : public std::binary_function { + public: + explicit Comparator(std::size_t order) : order_(order) {} + + inline bool operator()(const void *lhs, const void *rhs) const { + return static_cast(this)->Compare(static_cast(lhs), static_cast(rhs)); + } + + std::size_t Order() const { return order_; } + + protected: + std::size_t order_; +}; + +class SuffixOrder : public Comparator { + public: + explicit SuffixOrder(std::size_t order) : Comparator(order) {} + + inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { + for (std::size_t i = order_ - 1; i != 0; --i) { + if (lhs[i] != rhs[i]) + return lhs[i] < rhs[i]; + } + return lhs[0] < rhs[0]; + } + + static const unsigned kMatchOffset = 1; +}; + +class ContextOrder : public Comparator { + public: + explicit ContextOrder(std::size_t order) : Comparator(order) {} + + inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { + for (int i = order_ - 2; i >= 0; --i) { + if (lhs[i] != rhs[i]) + return lhs[i] < rhs[i]; + } + return lhs[order_ - 1] < rhs[order_ - 1]; + } +}; + +class PrefixOrder : public Comparator { + public: + explicit PrefixOrder(std::size_t order) : Comparator(order) {} + + inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { + for (std::size_t i = 0; i < order_; ++i) { + if (lhs[i] != rhs[i]) + return lhs[i] < rhs[i]; + } + return false; + } + + static const unsigned kMatchOffset = 0; +}; + +// Sum counts for the same n-gram. +struct AddCombiner { + bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const { + NGram first(first_void, compare.Order()); + // There isn't a const version of NGram. + NGram second(const_cast(second_void), compare.Order()); + if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false; + first.Count() += second.Count(); + return true; + } +}; + +// The combiner is only used on a single chain, so I didn't bother to allow +// that template. +template class Sorts : public FixedArray > { + private: + typedef util::stream::Sort S; + typedef FixedArray P; + + public: + void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { + new (P::end()) S(chain, config, compare); + P::Constructed(); + } +}; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_SORT__ diff --git a/klm/lm/filter/arpa_io.cc b/klm/lm/filter/arpa_io.cc new file mode 100644 index 00000000..caf8df95 --- /dev/null +++ b/klm/lm/filter/arpa_io.cc @@ -0,0 +1,122 @@ +#include "lm/filter/arpa_io.hh" +#include "util/file_piece.hh" + +#include +#include +#include +#include + +#include +#include +#include + +namespace lm { + +ARPAInputException::ARPAInputException(const StringPiece &message) throw() : what_("Error: ") { + what_.append(message.data(), message.size()); +} + +ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() { + what_ = "Error: "; + what_.append(message.data(), message.size()); + what_ += " in line '"; + what_.append(line.data(), line.size()); + what_ += "'."; +} + +ARPAOutputException::ARPAOutputException(const char *message, const std::string &file_name) throw() + : what_(std::string(message) + " file " + file_name), file_name_(file_name) { + if (errno) { + char buf[1024]; + buf[0] = 0; +#if (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600) && ! _GNU_SOURCE + const char *add = buf; + if (!strerror_r(errno, buf, 1024)) { +#else + const char *add = strerror_r(errno, buf, 1024); + if (add) { +#endif + what_ += " :"; + what_ += add; + } + } +} + +// Seeking is the responsibility of the caller. +void WriteCounts(std::ostream &out, const std::vector &number) { + out << "\n\\data\\\n"; + for (unsigned int i = 0; i < number.size(); ++i) { + out << "ngram " << i+1 << "=" << number[i] << '\n'; + } + out << '\n'; +} + +size_t SizeNeededForCounts(const std::vector &number) { + std::ostringstream buf; + WriteCounts(buf, number); + return buf.tellp(); +} + +bool IsEntirelyWhiteSpace(const StringPiece &line) { + for (size_t i = 0; i < static_cast(line.size()); ++i) { + if (!isspace(line.data()[i])) return false; + } + return true; +} + +ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) : file_name_(name), buffer_(new char[buffer_size]) { + try { + file_.exceptions(std::ostream::eofbit | std::ostream::failbit | std::ostream::badbit); + if (!file_.rdbuf()->pubsetbuf(buffer_.get(), buffer_size)) { + std::cerr << "Warning: could not enlarge buffer for " << name << std::endl; + buffer_.reset(); + } + file_.open(name, std::ios::out | std::ios::binary); + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Opening", file_name_); + } +} + +void ARPAOutput::ReserveForCounts(std::streampos reserve) { + try { + for (std::streampos i = 0; i < reserve; i += std::streampos(1)) { + file_ << '\n'; + } + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing blanks to reserve space for counts to ", file_name_); + } +} + +void ARPAOutput::BeginLength(unsigned int length) { + fast_counter_ = 0; + try { + file_ << '\\' << length << "-grams:" << '\n'; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing n-gram header to ", file_name_); + } +} + +void ARPAOutput::EndLength(unsigned int length) { + try { + file_ << '\n'; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing blank at end of count list to ", file_name_); + } + if (length > counts_.size()) { + counts_.resize(length); + } + counts_[length - 1] = fast_counter_; +} + +void ARPAOutput::Finish() { + try { + file_ << "\\end\\\n"; + file_.seekp(0); + WriteCounts(file_, counts_); + file_ << std::flush; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Finishing including writing counts at beginning to ", file_name_); + } +} + +} // namespace lm diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh new file mode 100644 index 00000000..90f48447 --- /dev/null +++ b/klm/lm/filter/arpa_io.hh @@ -0,0 +1,122 @@ +#ifndef LM_FILTER_ARPA_IO__ +#define LM_FILTER_ARPA_IO__ +/* Input and output for ARPA format language model files. + */ +#include "lm/read_arpa.hh" +#include "util/exception.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include +#include + +#include +#include +#include + +#include +#include + +namespace util { class FilePiece; } + +namespace lm { + +class ARPAInputException : public util::Exception { + public: + explicit ARPAInputException(const StringPiece &message) throw(); + explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw(); + virtual ~ARPAInputException() throw() {} + + const char *what() const throw() { return what_.c_str(); } + + private: + std::string what_; +}; + +class ARPAOutputException : public std::exception { + public: + ARPAOutputException(const char *prefix, const std::string &file_name) throw(); + virtual ~ARPAOutputException() throw() {} + + const char *what() const throw() { return what_.c_str(); } + + const std::string &File() const throw() { return file_name_; } + + private: + std::string what_; + const std::string file_name_; +}; + +// Handling for the counts of n-grams at the beginning of ARPA files. +size_t SizeNeededForCounts(const std::vector &number); + +/* Writes an ARPA file. This has to be seekable so the counts can be written + * at the end. Hence, I just have it own a std::fstream instead of accepting + * a separately held std::ostream. + */ +class ARPAOutput : boost::noncopyable { + public: + explicit ARPAOutput(const char *name, size_t buffer_size = 65536); + + void ReserveForCounts(std::streampos reserve); + + void BeginLength(unsigned int length); + + void AddNGram(const StringPiece &line) { + try { + file_ << line << '\n'; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing an n-gram", file_name_); + } + ++fast_counter_; + } + + void AddNGram(const StringPiece &ngram, const StringPiece &line) { + AddNGram(line); + } + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + AddNGram(line); + } + + void EndLength(unsigned int length); + + void Finish(); + + private: + const std::string file_name_; + boost::scoped_array buffer_; + std::fstream file_; + size_t fast_counter_; + std::vector counts_; +}; + + +template void ReadNGrams(util::FilePiece &in, unsigned int length, size_t number, Output &out) { + ReadNGramHeader(in, length); + out.BeginLength(length); + for (size_t i = 0; i < number; ++i) { + StringPiece line = in.ReadLine(); + util::TokenIter tabber(line, '\t'); + if (!tabber) throw ARPAInputException("blank line", line); + if (!++tabber) throw ARPAInputException("no tab", line); + + out.AddNGram(*tabber, line); + } + out.EndLength(length); +} + +template void ReadARPA(util::FilePiece &in_lm, Output &out) { + std::vector number; + ReadARPACounts(in_lm, number); + out.ReserveForCounts(SizeNeededForCounts(number)); + for (unsigned int i = 0; i < number.size(); ++i) { + ReadNGrams(in_lm, i + 1, number[i], out); + } + ReadEnd(in_lm); + out.Finish(); +} + +} // namespace lm + +#endif // LM_FILTER_ARPA_IO__ diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh new file mode 100644 index 00000000..97c0fa25 --- /dev/null +++ b/klm/lm/filter/count_io.hh @@ -0,0 +1,91 @@ +#ifndef LM_FILTER_COUNT_IO__ +#define LM_FILTER_COUNT_IO__ + +#include +#include +#include + +#include + +#include "util/file_piece.hh" + +namespace lm { + +class CountOutput : boost::noncopyable { + public: + explicit CountOutput(const char *name) : file_(name, std::ios::out) {} + + void AddNGram(const StringPiece &line) { + if (!(file_ << line << '\n')) { + err(3, "Writing counts file failed"); + } + } + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + AddNGram(line); + } + + void AddNGram(const StringPiece &ngram, const StringPiece &line) { + AddNGram(line); + } + + private: + std::fstream file_; +}; + +class CountBatch { + public: + explicit CountBatch(std::streamsize initial_read) + : initial_read_(initial_read) { + buffer_.reserve(initial_read); + } + + void Read(std::istream &in) { + buffer_.resize(initial_read_); + in.read(&*buffer_.begin(), initial_read_); + buffer_.resize(in.gcount()); + char got; + while (in.get(got) && got != '\n') + buffer_.push_back(got); + } + + template void Send(Output &out) { + for (util::TokenIter line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) { + util::TokenIter tabber(*line, '\t'); + if (!tabber) { + std::cerr << "Warning: empty n-gram count line being removed\n"; + continue; + } + util::TokenIter words(*tabber, ' '); + if (!words) { + std::cerr << "Line has a tab but no words.\n"; + continue; + } + out.AddNGram(words, util::TokenIter::end(), *line); + } + } + + private: + std::streamsize initial_read_; + + // This could have been a std::string but that's less happy with raw writes. + std::vector buffer_; +}; + +template void ReadCount(util::FilePiece &in_file, Output &out) { + try { + while (true) { + StringPiece line = in_file.ReadLine(); + util::TokenIter tabber(line, '\t'); + if (!tabber) { + std::cerr << "Warning: empty n-gram count line being removed\n"; + continue; + } + out.AddNGram(*tabber, line); + } + } catch (const util::EndOfFileException &e) {} +} + +} // namespace lm + +#endif // LM_FILTER_COUNT_IO__ diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh new file mode 100644 index 00000000..7f945b0d --- /dev/null +++ b/klm/lm/filter/format.hh @@ -0,0 +1,250 @@ +#ifndef LM_FILTER_FORMAT_H__ +#define LM_FITLER_FORMAT_H__ + +#include "lm/filter/arpa_io.hh" +#include "lm/filter/count_io.hh" + +#include +#include + +#include + +namespace lm { + +template class MultipleOutput { + private: + typedef boost::ptr_vector Singles; + typedef typename Singles::iterator SinglesIterator; + + public: + MultipleOutput(const char *prefix, size_t number) { + files_.reserve(number); + std::string tmp; + for (unsigned int i = 0; i < number; ++i) { + tmp = prefix; + tmp += boost::lexical_cast(i); + files_.push_back(new Single(tmp.c_str())); + } + } + + void AddNGram(const StringPiece &line) { + for (SinglesIterator i = files_.begin(); i != files_.end(); ++i) + i->AddNGram(line); + } + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + for (SinglesIterator i = files_.begin(); i != files_.end(); ++i) + i->AddNGram(begin, end, line); + } + + void SingleAddNGram(size_t offset, const StringPiece &line) { + files_[offset].AddNGram(line); + } + + template void SingleAddNGram(size_t offset, const Iterator &begin, const Iterator &end, const StringPiece &line) { + files_[offset].AddNGram(begin, end, line); + } + + protected: + Singles files_; +}; + +class MultipleARPAOutput : public MultipleOutput { + public: + MultipleARPAOutput(const char *prefix, size_t number) : MultipleOutput(prefix, number) {} + + void ReserveForCounts(std::streampos reserve) { + for (boost::ptr_vector::iterator i = files_.begin(); i != files_.end(); ++i) + i->ReserveForCounts(reserve); + } + + void BeginLength(unsigned int length) { + for (boost::ptr_vector::iterator i = files_.begin(); i != files_.end(); ++i) + i->BeginLength(length); + } + + void EndLength(unsigned int length) { + for (boost::ptr_vector::iterator i = files_.begin(); i != files_.end(); ++i) + i->EndLength(length); + } + + void Finish() { + for (boost::ptr_vector::iterator i = files_.begin(); i != files_.end(); ++i) + i->Finish(); + } +}; + +template class DispatchInput { + public: + DispatchInput(Filter &filter, Output &output) : filter_(filter), output_(output) {} + +/* template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + filter_.AddNGram(begin, end, line, output_); + }*/ + + void AddNGram(const StringPiece &ngram, const StringPiece &line) { + filter_.AddNGram(ngram, line, output_); + } + + protected: + Filter &filter_; + Output &output_; +}; + +template class DispatchARPAInput : public DispatchInput { + private: + typedef DispatchInput B; + + public: + DispatchARPAInput(Filter &filter, Output &output) : B(filter, output) {} + + void ReserveForCounts(std::streampos reserve) { B::output_.ReserveForCounts(reserve); } + void BeginLength(unsigned int length) { B::output_.BeginLength(length); } + + void EndLength(unsigned int length) { + B::filter_.Flush(); + B::output_.EndLength(length); + } + void Finish() { B::output_.Finish(); } +}; + +struct ARPAFormat { + typedef ARPAOutput Output; + typedef MultipleARPAOutput Multiple; + static void Copy(util::FilePiece &in, Output &out) { + ReadARPA(in, out); + } + template static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) { + DispatchARPAInput dispatcher(filter, output); + ReadARPA(in, dispatcher); + } +}; + +struct CountFormat { + typedef CountOutput Output; + typedef MultipleOutput Multiple; + static void Copy(util::FilePiece &in, Output &out) { + ReadCount(in, out); + } + template static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) { + DispatchInput dispatcher(filter, output); + ReadCount(in, dispatcher); + } +}; + +/* For multithreading, the buffer classes hold batches of filter inputs and + * outputs in memory. The strings get reused a lot, so keep them around + * instead of clearing each time. + */ +class InputBuffer { + public: + InputBuffer() : actual_(0) {} + + void Reserve(size_t size) { lines_.reserve(size); } + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + if (lines_.size() == actual_) lines_.resize(lines_.size() + 1); + // TODO avoid this copy. + std::string &copied = lines_[actual_].line; + copied.assign(line.data(), line.size()); + lines_[actual_].ngram.set(copied.data() + (ngram.data() - line.data()), ngram.size()); + ++actual_; + } + + template void CallFilter(Filter &filter, Output &output) const { + for (std::vector::const_iterator i = lines_.begin(); i != lines_.begin() + actual_; ++i) { + filter.AddNGram(i->ngram, i->line, output); + } + } + + void Clear() { actual_ = 0; } + bool Empty() { return actual_ == 0; } + size_t Size() { return actual_; } + + private: + struct Line { + std::string line; + StringPiece ngram; + }; + + size_t actual_; + + std::vector lines_; +}; + +class BinaryOutputBuffer { + public: + BinaryOutputBuffer() {} + + void Reserve(size_t size) { + lines_.reserve(size); + } + + void AddNGram(const StringPiece &line) { + lines_.push_back(line); + } + + template void Flush(Output &output) { + for (std::vector::const_iterator i = lines_.begin(); i != lines_.end(); ++i) { + output.AddNGram(*i); + } + lines_.clear(); + } + + private: + std::vector lines_; +}; + +class MultipleOutputBuffer { + public: + MultipleOutputBuffer() : last_(NULL) {} + + void Reserve(size_t size) { + annotated_.reserve(size); + } + + void AddNGram(const StringPiece &line) { + annotated_.resize(annotated_.size() + 1); + annotated_.back().line = line; + } + + void SingleAddNGram(size_t offset, const StringPiece &line) { + if ((line.data() == last_.data()) && (line.length() == last_.length())) { + annotated_.back().systems.push_back(offset); + } else { + annotated_.resize(annotated_.size() + 1); + annotated_.back().systems.push_back(offset); + annotated_.back().line = line; + last_ = line; + } + } + + template void Flush(Output &output) { + for (std::vector::const_iterator i = annotated_.begin(); i != annotated_.end(); ++i) { + if (i->systems.empty()) { + output.AddNGram(i->line); + } else { + for (std::vector::const_iterator j = i->systems.begin(); j != i->systems.end(); ++j) { + output.SingleAddNGram(*j, i->line); + } + } + } + annotated_.clear(); + } + + private: + struct Annotated { + // If this is empty, send to all systems. + // A filter should never send to all systems and send to a single one. + std::vector systems; + StringPiece line; + }; + + StringPiece last_; + + std::vector annotated_; +}; + +} // namespace lm + +#endif // LM_FILTER_FORMAT_H__ diff --git a/klm/lm/filter/main.cc b/klm/lm/filter/main.cc new file mode 100644 index 00000000..c42243e2 --- /dev/null +++ b/klm/lm/filter/main.cc @@ -0,0 +1,249 @@ +#include "lm/filter/arpa_io.hh" +#include "lm/filter/format.hh" +#include "lm/filter/phrase.hh" +#ifndef NTHREAD +#include "lm/filter/thread.hh" +#endif +#include "lm/filter/vocab.hh" +#include "lm/filter/wrapper.hh" +#include "util/file_piece.hh" + +#include + +#include +#include +#include +#include + +namespace lm { +namespace { + +void DisplayHelp(const char *name) { + std::cerr + << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n" + "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n" + " parser.\n" + "single mode treats the entire input as a single sentence.\n" + "multiple mode filters to multiple sentences in parallel. Each sentence is on\n" + " a separate line. A separate file is created for each file by appending the\n" + " 0-indexed line number to the output file name.\n" + "union mode produces one filtered model that is the union of models created by\n" + " multiple mode.\n\n" + "context means only the context (all but last word) has to pass the filter, but\n" + " the entire n-gram is output.\n\n" + "phrase means that the vocabulary is actually tab-delimited phrases and that the\n" + " phrases can generate the n-gram when assembled in arbitrary order and\n" + " clipped. Currently works with multiple or union mode.\n\n" + "The file format is set by [raw|arpa] with default arpa:\n" + "raw means space-separated tokens, optionally followed by a tab and arbitrary\n" + " text. This is useful for ngram count files.\n" + "arpa means the ARPA file format for n-gram language models.\n\n" +#ifndef NTHREAD + "threads:m sets m threads (default: conccurrency detected by boost)\n" + "batch_size:m sets the batch size for threading. Expect memory usage from this\n" + " of 2*threads*batch_size n-grams.\n\n" +#else + "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n" + " threading, compile without this flag against Boost >=1.42.0.\n\n" +#endif + "There are two inputs: vocabulary and model. Either may be given as a file\n" + " while the other is on stdin. Specify the type given as a file using\n" + " vocab: or model: before the file name. \n\n" + "For ARPA format, the output must be seekable. For raw format, it can be a\n" + " stream i.e. /dev/stdout\n"; +} + +typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION} FilterMode; +typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format; + +struct Config { + Config() : +#ifndef NTHREAD + batch_size(25000), + threads(boost::thread::hardware_concurrency()), +#endif + phrase(false), + context(false), + format(FORMAT_ARPA) + { +#ifndef NTHREAD + if (!threads) threads = 1; +#endif + } + +#ifndef NTHREAD + size_t batch_size; + size_t threads; +#endif + bool phrase; + bool context; + FilterMode mode; + Format format; +}; + +template void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) { +#ifndef NTHREAD + if (config.threads == 1) { +#endif + Format::RunFilter(in_lm, filter, output); +#ifndef NTHREAD + } else { + typedef Controller Threaded; + Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output); + Format::RunFilter(in_lm, threading, output); + } +#endif +} + +template void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) { + if (config.context) { + ContextFilter context_filter(filter); + RunThreadedFilter, OutputBuffer, Output>(config, in_lm, context_filter, output); + } else { + RunThreadedFilter(config, in_lm, filter, output); + } +} + +template void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) { + typedef BinaryFilter Filter; + RunContextFilter(config, in_lm, Filter(binary), out); +} + +template void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) { + if (config.mode == MODE_MULTIPLE) { + if (config.phrase) { + typedef phrase::Multiple Filter; + phrase::Substrings substrings; + typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings)); + RunContextFilter(config, in_lm, Filter(substrings), out); + } else { + typedef vocab::Multiple Filter; + boost::unordered_map > words; + typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words)); + RunContextFilter(config, in_lm, Filter(words), out); + } + return; + } + + typename Format::Output out(out_name); + + if (config.mode == MODE_COPY) { + Format::Copy(in_lm, out); + return; + } + + if (config.mode == MODE_SINGLE) { + vocab::Single::Words words; + vocab::ReadSingle(in_vocab, words); + DispatchBinaryFilter(config, in_lm, vocab::Single(words), out); + return; + } + + if (config.mode == MODE_UNION) { + if (config.phrase) { + phrase::Substrings substrings; + phrase::ReadMultiple(in_vocab, substrings); + DispatchBinaryFilter(config, in_lm, phrase::Union(substrings), out); + } else { + vocab::Union::Words words; + vocab::ReadMultiple(in_vocab, words); + DispatchBinaryFilter(config, in_lm, vocab::Union(words), out); + } + return; + } +} + +} // namespace +} // namespace lm + +int main(int argc, char *argv[]) { + if (argc < 4) { + lm::DisplayHelp(argv[0]); + return 1; + } + + // I used to have boost::program_options, but some users didn't want to compile boost. + lm::Config config; + boost::optional mode; + for (int i = 1; i < argc - 2; ++i) { + const char *str = argv[i]; + if (!std::strcmp(str, "copy")) { + mode = lm::MODE_COPY; + } else if (!std::strcmp(str, "single")) { + mode = lm::MODE_SINGLE; + } else if (!std::strcmp(str, "multiple")) { + mode = lm::MODE_MULTIPLE; + } else if (!std::strcmp(str, "union")) { + mode = lm::MODE_UNION; + } else if (!std::strcmp(str, "phrase")) { + config.phrase = true; + } else if (!std::strcmp(str, "context")) { + config.context = true; + } else if (!std::strcmp(str, "arpa")) { + config.format = lm::FORMAT_ARPA; + } else if (!std::strcmp(str, "raw")) { + config.format = lm::FORMAT_COUNT; +#ifndef NTHREAD + } else if (!std::strncmp(str, "threads:", 8)) { + config.threads = boost::lexical_cast(str + 8); + if (!config.threads) { + std::cerr << "Specify at least one thread." << std::endl; + return 1; + } + } else if (!std::strncmp(str, "batch_size:", 11)) { + config.batch_size = boost::lexical_cast(str + 11); + if (config.batch_size < 5000) { + std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; + if (!config.batch_size) return 1; + } +#endif + } else { + lm::DisplayHelp(argv[0]); + return 1; + } + } + + if (!mode) { + lm::DisplayHelp(argv[0]); + return 1; + } + config.mode = *mode; + + if (config.phrase && config.mode != lm::MODE_UNION && mode != lm::MODE_MULTIPLE) { + std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; + return 1; + } + + bool cmd_is_model = true; + const char *cmd_input = argv[argc - 2]; + if (!strncmp(cmd_input, "vocab:", 6)) { + cmd_is_model = false; + cmd_input += 6; + } else if (!strncmp(cmd_input, "model:", 6)) { + cmd_input += 6; + } else if (strchr(cmd_input, ':')) { + errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); + } else { + std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; + } + std::ifstream cmd_file; + std::istream *vocab; + if (cmd_is_model) { + vocab = &std::cin; + } else { + cmd_file.open(cmd_input, std::ios::in); + if (!cmd_file) { + err(2, "Could not open input file %s", cmd_input); + } + vocab = &cmd_file; + } + + util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); + + if (config.format == lm::FORMAT_ARPA) { + lm::DispatchFilterModes(config, *vocab, model, argv[argc - 1]); + } else if (config.format == lm::FORMAT_COUNT) { + lm::DispatchFilterModes(config, *vocab, model, argv[argc - 1]); + } + return 0; +} diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc new file mode 100644 index 00000000..1bef2a3f --- /dev/null +++ b/klm/lm/filter/phrase.cc @@ -0,0 +1,281 @@ +#include "lm/filter/phrase.hh" + +#include "lm/filter/format.hh" + +#include +#include +#include +#include +#include +#include + +#include + +namespace lm { +namespace phrase { + +unsigned int ReadMultiple(std::istream &in, Substrings &out) { + bool sentence_content = false; + unsigned int sentence_id = 0; + std::vector phrase; + std::string word; + while (in) { + char c; + // Gather a word. + while (!isspace(c = in.get()) && in) word += c; + // Treat EOF like a newline. + if (!in) c = '\n'; + // Add the word to the phrase. + if (!word.empty()) { + phrase.push_back(util::MurmurHashNative(word.data(), word.size())); + word.clear(); + } + if (c == ' ') continue; + // It's more than just a space. Close out the phrase. + if (!phrase.empty()) { + sentence_content = true; + out.AddPhrase(sentence_id, phrase.begin(), phrase.end()); + phrase.clear(); + } + if (c == '\t' || c == '\v') continue; + // It's more than a space or tab: a newline. + if (sentence_content) { + ++sentence_id; + sentence_content = false; + } + } + if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit); + return sentence_id + sentence_content; +} + +namespace detail { const StringPiece kEndSentence(""); } + +namespace { + +typedef unsigned int Sentence; +typedef std::vector Sentences; + +class Vertex; + +class Arc { + public: + Arc() {} + + // For arcs from one vertex to another. + void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) { + Set(to, intersect); + from_ = &from; + } + + /* For arcs from before the n-gram begins to somewhere in the n-gram (right + * aligned). These have no from_ vertex; it implictly matches every + * sentence. This also handles when the n-gram is a substring of a phrase. + */ + void SetRight(Vertex &to, const Sentences &complete) { + Set(to, complete); + from_ = NULL; + } + + Sentence Current() const { + return *current_; + } + + bool Empty() const { + return current_ == last_; + } + + /* When this function returns: + * If Empty() then there's nothing left from this intersection. + * + * If Current() == to then to is part of the intersection. + * + * Otherwise, Current() > to. In this case, to is not part of the + * intersection and neither is anything < Current(). To determine if + * any value >= Current() is in the intersection, call LowerBound again + * with the value. + */ + void LowerBound(const Sentence to); + + private: + void Set(Vertex &to, const Sentences &sentences); + + const Sentence *current_; + const Sentence *last_; + Vertex *from_; +}; + +struct ArcGreater : public std::binary_function { + bool operator()(const Arc *first, const Arc *second) const { + return first->Current() > second->Current(); + } +}; + +class Vertex { + public: + Vertex() : current_(0) {} + + Sentence Current() const { + return current_; + } + + bool Empty() const { + return incoming_.empty(); + } + + void LowerBound(const Sentence to); + + private: + friend class Arc; + + void AddIncoming(Arc *arc) { + if (!arc->Empty()) incoming_.push(arc); + } + + unsigned int current_; + std::priority_queue, ArcGreater> incoming_; +}; + +void Arc::LowerBound(const Sentence to) { + current_ = std::lower_bound(current_, last_, to); + // If *current_ > to, don't advance from_. The intervening values of + // from_ may be useful for another one of its outgoing arcs. + if (!from_ || Empty() || (Current() > to)) return; + assert(Current() == to); + from_->LowerBound(to); + if (from_->Empty()) { + current_ = last_; + return; + } + assert(from_->Current() >= to); + if (from_->Current() > to) { + current_ = std::lower_bound(current_ + 1, last_, from_->Current()); + } +} + +void Arc::Set(Vertex &to, const Sentences &sentences) { + current_ = &*sentences.begin(); + last_ = &*sentences.end(); + to.AddIncoming(this); +} + +void Vertex::LowerBound(const Sentence to) { + if (Empty()) return; + // Union lower bound. + while (true) { + Arc *top = incoming_.top(); + if (top->Current() > to) { + current_ = top->Current(); + return; + } + // If top->Current() == to, we still need to verify that's an actual + // element and not just a bound. + incoming_.pop(); + top->LowerBound(to); + if (!top->Empty()) { + incoming_.push(top); + if (top->Current() == to) { + current_ = to; + return; + } + } else if (Empty()) { + return; + } + } +} + +void BuildGraph(const Substrings &phrase, const std::vector &hashes, Vertex *const vertices, Arc *free_arc) { + assert(!hashes.empty()); + + const Hash *const first_word = &*hashes.begin(); + const Hash *const last_word = &*hashes.end() - 1; + + Hash hash = 0; + const Sentences *found; + // Phrases starting at or before the first word in the n-gram. + { + Vertex *vertex = vertices; + for (const Hash *word = first_word; ; ++word, ++vertex) { + hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word); + // Now hash is [hashes.begin(), word]. + if (word == last_word) { + if (phrase.FindSubstring(hash, found)) + (free_arc++)->SetRight(*vertex, *found); + break; + } + if (!phrase.FindRight(hash, found)) break; + (free_arc++)->SetRight(*vertex, *found); + } + } + + // Phrases starting at the second or later word in the n-gram. + Vertex *vertex_from = vertices; + for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) { + hash = 0; + Vertex *vertex_to = vertex_from + 1; + for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) { + // Notice that word_to and vertex_to have the same index. + hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to); + // Now hash covers [word_from, word_to]. + if (word_to == last_word) { + if (phrase.FindLeft(hash, found)) + (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); + break; + } + if (!phrase.FindPhrase(hash, found)) break; + (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); + } + } +} + +} // namespace + +namespace detail { + +} // namespace detail + +bool Union::Evaluate() { + assert(!hashes_.empty()); + // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. + Vertex vertices[hashes_.size()]; + // One for every substring. + Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; + BuildGraph(substrings_, hashes_, vertices, arcs); + Vertex &last_vertex = vertices[hashes_.size() - 1]; + + unsigned int lower = 0; + while (true) { + last_vertex.LowerBound(lower); + if (last_vertex.Empty()) return false; + if (last_vertex.Current() == lower) return true; + lower = last_vertex.Current(); + } +} + +template void Multiple::Evaluate(const StringPiece &line, Output &output) { + assert(!hashes_.empty()); + // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. + Vertex vertices[hashes_.size()]; + // One for every substring. + Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; + BuildGraph(substrings_, hashes_, vertices, arcs); + Vertex &last_vertex = vertices[hashes_.size() - 1]; + + unsigned int lower = 0; + while (true) { + last_vertex.LowerBound(lower); + if (last_vertex.Empty()) return; + if (last_vertex.Current() == lower) { + output.SingleAddNGram(lower, line); + ++lower; + } else { + lower = last_vertex.Current(); + } + } +} + +template void Multiple::Evaluate(const StringPiece &line, CountFormat::Multiple &output); +template void Multiple::Evaluate(const StringPiece &line, ARPAFormat::Multiple &output); +template void Multiple::Evaluate(const StringPiece &line, MultipleOutputBuffer &output); + +} // namespace phrase +} // namespace lm diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh new file mode 100644 index 00000000..07479dea --- /dev/null +++ b/klm/lm/filter/phrase.hh @@ -0,0 +1,153 @@ +#ifndef LM_FILTER_PHRASE_H__ +#define LM_FILTER_PHRASE_H__ + +#include "util/murmur_hash.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include + +#include +#include + +#define LM_FILTER_PHRASE_METHOD(caps, lower) \ +bool Find##caps(Hash key, const std::vector *&out) const {\ + Table::const_iterator i(table_.find(key));\ + if (i==table_.end()) return false; \ + out = &i->second.lower; \ + return true; \ +} + +namespace lm { +namespace phrase { + +typedef uint64_t Hash; + +class Substrings { + private: + /* This is the value in a hash table where the key is a string. It indicates + * four sets of sentences: + * substring is sentences with a phrase containing the key as a substring. + * left is sentencess with a phrase that begins with the key (left aligned). + * right is sentences with a phrase that ends with the key (right aligned). + * phrase is sentences where the key is a phrase. + * Each set is encoded as a vector of sentence ids in increasing order. + */ + struct SentenceRelation { + std::vector substring, left, right, phrase; + }; + /* Most of the CPU is hash table lookups, so let's not complicate it with + * vector equality comparisons. If a collision happens, the SentenceRelation + * structure will contain the union of sentence ids over the colliding strings. + * In that case, the filter will be slightly more permissive. + * The key here is the same as boost's hash of std::vector. + */ + typedef boost::unordered_map Table; + + public: + Substrings() {} + + /* If the string isn't a substring of any phrase, return NULL. Otherwise, + * return a pointer to std::vector listing sentences with + * matching phrases. This set may be empty for Left, Right, or Phrase. + * Example: const std::vector *FindSubstring(Hash key) + */ + LM_FILTER_PHRASE_METHOD(Substring, substring) + LM_FILTER_PHRASE_METHOD(Left, left) + LM_FILTER_PHRASE_METHOD(Right, right) + LM_FILTER_PHRASE_METHOD(Phrase, phrase) + + // sentence_id must be non-decreasing. Iterators are over words in the phrase. + template void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) { + // Iterate over all substrings. + for (Iterator start = begin; start != end; ++start) { + Hash hash = 0; + SentenceRelation *relation; + for (Iterator finish = start; finish != end; ++finish) { + hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish); + // Now hash is of [start, finish]. + relation = &table_[hash]; + AppendSentence(relation->substring, sentence_id); + if (start == begin) AppendSentence(relation->left, sentence_id); + } + AppendSentence(relation->right, sentence_id); + if (start == begin) AppendSentence(relation->phrase, sentence_id); + } + } + + private: + void AppendSentence(std::vector &vec, unsigned int sentence_id) { + if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id); + } + + Table table_; +}; + +// Read a file with one sentence per line containing tab-delimited phrases of +// space-separated words. +unsigned int ReadMultiple(std::istream &in, Substrings &out); + +namespace detail { +extern const StringPiece kEndSentence; + +template void MakeHashes(Iterator i, const Iterator &end, std::vector &hashes) { + hashes.clear(); + if (i == end) return; + // TODO: check strict phrase boundaries after and before . For now, just skip tags. + if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) { + ++i; + } + for (; i != end && (*i != kEndSentence); ++i) { + hashes.push_back(util::MurmurHashNative(i->data(), i->size())); + } +} + +} // namespace detail + +class Union { + public: + explicit Union(const Substrings &substrings) : substrings_(substrings) {} + + template bool PassNGram(const Iterator &begin, const Iterator &end) { + detail::MakeHashes(begin, end, hashes_); + return hashes_.empty() || Evaluate(); + } + + private: + bool Evaluate(); + + std::vector hashes_; + + const Substrings &substrings_; +}; + +class Multiple { + public: + explicit Multiple(const Substrings &substrings) : substrings_(substrings) {} + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { + detail::MakeHashes(begin, end, hashes_); + if (hashes_.empty()) { + output.AddNGram(line); + return; + } + Evaluate(line, output); + } + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + AddNGram(util::TokenIter(ngram, ' '), util::TokenIter::end(), line, output); + } + + void Flush() const {} + + private: + template void Evaluate(const StringPiece &line, Output &output); + + std::vector hashes_; + + const Substrings &substrings_; +}; + +} // namespace phrase +} // namespace lm +#endif // LM_FILTER_PHRASE_H__ diff --git a/klm/lm/filter/thread.hh b/klm/lm/filter/thread.hh new file mode 100644 index 00000000..e785b263 --- /dev/null +++ b/klm/lm/filter/thread.hh @@ -0,0 +1,167 @@ +#ifndef LM_FILTER_THREAD_H__ +#define LM_FILTER_THREAD_H__ + +#include "util/thread_pool.hh" + +#include + +#include +#include + +namespace lm { + +template class ThreadBatch { + public: + ThreadBatch() {} + + void Reserve(size_t size) { + input_.Reserve(size); + output_.Reserve(size); + } + + // File reading thread. + InputBuffer &Fill(uint64_t sequence) { + sequence_ = sequence; + // Why wait until now to clear instead of after output? free in the same + // thread as allocated. + input_.Clear(); + return input_; + } + + // Filter worker thread. + template void CallFilter(Filter &filter) { + input_.CallFilter(filter, output_); + } + + uint64_t Sequence() const { return sequence_; } + + // File writing thread. + template void Flush(RealOutput &output) { + output_.Flush(output); + } + + private: + InputBuffer input_; + OutputBuffer output_; + + uint64_t sequence_; +}; + +template class FilterWorker { + public: + typedef Batch *Request; + + FilterWorker(const Filter &filter, util::PCQueue &done) : filter_(filter), done_(done) {} + + void operator()(Request request) { + request->CallFilter(filter_); + done_.Produce(request); + } + + private: + Filter filter_; + + util::PCQueue &done_; +}; + +// There should only be one OutputWorker. +template class OutputWorker { + public: + typedef Batch *Request; + + OutputWorker(Output &output, util::PCQueue &done) : output_(output), done_(done), base_sequence_(0) {} + + void operator()(Request request) { + assert(request->Sequence() >= base_sequence_); + // Assemble the output in order. + uint64_t pos = request->Sequence() - base_sequence_; + if (pos >= ordering_.size()) { + ordering_.resize(pos + 1, NULL); + } + ordering_[pos] = request; + while (!ordering_.empty() && ordering_.front()) { + ordering_.front()->Flush(output_); + done_.Produce(ordering_.front()); + ordering_.pop_front(); + ++base_sequence_; + } + } + + private: + Output &output_; + + util::PCQueue &done_; + + std::deque ordering_; + + uint64_t base_sequence_; +}; + +template class Controller : boost::noncopyable { + private: + typedef ThreadBatch Batch; + + public: + Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output) + : batch_size_(batch_size), queue_size_(queue), + batches_(queue), + to_read_(queue), + output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL), + filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL), + sequence_(0) { + for (size_t i = 0; i < queue; ++i) { + batches_[i].Reserve(batch_size); + local_read_.push(&batches_[i]); + } + NewInput(); + } + + void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) { + input_->AddNGram(ngram, line, output); + if (input_->Size() == batch_size_) { + FlushInput(); + NewInput(); + } + } + + void Flush() { + FlushInput(); + while (local_read_.size() < queue_size_) { + MoveRead(); + } + NewInput(); + } + + private: + void FlushInput() { + if (input_->Empty()) return; + filter_.Produce(local_read_.top()); + local_read_.pop(); + if (local_read_.empty()) MoveRead(); + } + + void NewInput() { + input_ = &local_read_.top()->Fill(sequence_++); + } + + void MoveRead() { + local_read_.push(to_read_.Consume()); + } + + const size_t batch_size_; + const size_t queue_size_; + + std::vector batches_; + + util::PCQueue to_read_; + std::stack local_read_; + util::ThreadPool > output_; + util::ThreadPool > filter_; + + uint64_t sequence_; + InputBuffer *input_; +}; + +} // namespace lm + +#endif // LM_FILTER_THREAD_H__ diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc new file mode 100644 index 00000000..7ee4e84b --- /dev/null +++ b/klm/lm/filter/vocab.cc @@ -0,0 +1,54 @@ +#include "lm/filter/vocab.hh" + +#include +#include + +#include +#include + +namespace lm { +namespace vocab { + +void ReadSingle(std::istream &in, boost::unordered_set &out) { + in.exceptions(std::istream::badbit); + std::string word; + while (in >> word) { + out.insert(word); + } +} + +namespace { +bool IsLineEnd(std::istream &in) { + int got; + do { + got = in.get(); + if (!in) return true; + if (got == '\n') return true; + } while (isspace(got)); + in.unget(); + return false; +} +}// namespace + +// Read space separated words in enter separated lines. These lines can be +// very long, so don't read an entire line at a time. +unsigned int ReadMultiple(std::istream &in, boost::unordered_map > &out) { + in.exceptions(std::istream::badbit); + unsigned int sentence = 0; + bool used_id = false; + std::string word; + while (in >> word) { + used_id = true; + std::vector &posting = out[word]; + if (posting.empty() || (posting.back() != sentence)) + posting.push_back(sentence); + if (IsLineEnd(in)) { + ++sentence; + used_id = false; + } + } + return sentence + used_id; +} + +} // namespace vocab +} // namespace lm diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh new file mode 100644 index 00000000..e2b6adff --- /dev/null +++ b/klm/lm/filter/vocab.hh @@ -0,0 +1,132 @@ +#ifndef LM_FILTER_VOCAB_H__ +#define LM_FILTER_VOCAB_H__ + +// Vocabulary-based filters for language models. + +#include "util/multi_intersection.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include +#include +#include +#include + +#include +#include + +namespace lm { +namespace vocab { + +void ReadSingle(std::istream &in, boost::unordered_set &out); + +// Read one sentence vocabulary per line. Return the number of sentences. +unsigned int ReadMultiple(std::istream &in, boost::unordered_map > &out); + +/* Is this a special tag like or ? This actually includes anything + * surrounded with < and >, which most tokenizers separate for real words, so + * this should not catch real words as it looks at a single token. + */ +inline bool IsTag(const StringPiece &value) { + // The parser should never give an empty string. + assert(!value.empty()); + return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>'); +} + +class Single { + public: + typedef boost::unordered_set Words; + + explicit Single(const Words &vocab) : vocab_(vocab) {} + + template bool PassNGram(const Iterator &begin, const Iterator &end) { + for (Iterator i = begin; i != end; ++i) { + if (IsTag(*i)) continue; + if (FindStringPiece(vocab_, *i) == vocab_.end()) return false; + } + return true; + } + + private: + const Words &vocab_; +}; + +class Union { + public: + typedef boost::unordered_map > Words; + + explicit Union(const Words &vocabs) : vocabs_(vocabs) {} + + template bool PassNGram(const Iterator &begin, const Iterator &end) { + sets_.clear(); + + for (Iterator i(begin); i != end; ++i) { + if (IsTag(*i)) continue; + Words::const_iterator found(FindStringPiece(vocabs_, *i)); + if (vocabs_.end() == found) return false; + sets_.push_back(boost::iterator_range(&*found->second.begin(), &*found->second.end())); + } + return (sets_.empty() || util::FirstIntersection(sets_)); + } + + private: + const Words &vocabs_; + + std::vector > sets_; +}; + +class Multiple { + public: + typedef boost::unordered_map > Words; + + Multiple(const Words &vocabs) : vocabs_(vocabs) {} + + private: + // Callback from AllIntersection that does AddNGram. + template class Callback { + public: + Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {} + + void operator()(unsigned int index) { + out_.SingleAddNGram(index, line_); + } + + private: + Output &out_; + const StringPiece &line_; + }; + + public: + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { + sets_.clear(); + for (Iterator i(begin); i != end; ++i) { + if (IsTag(*i)) continue; + Words::const_iterator found(FindStringPiece(vocabs_, *i)); + if (vocabs_.end() == found) return; + sets_.push_back(boost::iterator_range(&*found->second.begin(), &*found->second.end())); + } + if (sets_.empty()) { + output.AddNGram(line); + return; + } + + Callback cb(output, line); + util::AllIntersection(sets_, cb); + } + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + AddNGram(util::TokenIter(ngram, ' '), util::TokenIter::end(), line, output); + } + + void Flush() const {} + + private: + const Words &vocabs_; + + std::vector > sets_; +}; + +} // namespace vocab +} // namespace lm + +#endif // LM_FILTER_VOCAB_H__ diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh new file mode 100644 index 00000000..90b07a08 --- /dev/null +++ b/klm/lm/filter/wrapper.hh @@ -0,0 +1,58 @@ +#ifndef LM_FILTER_WRAPPER_H__ +#define LM_FILTER_WRAPPER_H__ + +#include "util/string_piece.hh" + +#include +#include +#include + +namespace lm { + +// Provide a single-output filter with the same interface as a +// multiple-output filter so clients code against one interface. +template class BinaryFilter { + public: + // Binary modes are just references (and a set) and it makes the API cleaner to copy them. + explicit BinaryFilter(Binary binary) : binary_(binary) {} + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { + if (binary_.PassNGram(begin, end)) + output.AddNGram(line); + } + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + AddNGram(util::TokenIter(ngram, ' '), util::TokenIter::end(), line, output); + } + + void Flush() const {} + + private: + Binary binary_; +}; + +// Wrap another filter to pay attention only to context words +template class ContextFilter { + public: + typedef FilterT Filter; + + explicit ContextFilter(Filter &backend) : backend_(backend) {} + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + pieces_.clear(); + // TODO: this copy could be avoided by a lookahead iterator. + std::copy(util::TokenIter(ngram, ' '), util::TokenIter::end(), std::back_insert_iterator >(pieces_)); + backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output); + } + + void Flush() const {} + + private: + std::vector pieces_; + + Filter backend_; +}; + +} // namespace lm + +#endif // LM_FILTER_WRAPPER_H__ diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 32084b5b..eb159094 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -1,6 +1,7 @@ #include "lm/model.hh" #include +#include #define BOOST_TEST_MODULE ModelTest #include @@ -22,17 +23,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) { namespace { +// Stupid bjam reverses the command line arguments randomly. const char *TestLocation() { - if (boost::unit_test::framework::master_test_suite().argc < 2) { + if (boost::unit_test::framework::master_test_suite().argc < 3) { return "test.arpa"; } - return boost::unit_test::framework::master_test_suite().argv[1]; + char **argv = boost::unit_test::framework::master_test_suite().argv; + return argv[strstr(argv[1], "nounk") ? 2 : 1]; } const char *TestNoUnkLocation() { if (boost::unit_test::framework::master_test_suite().argc < 3) { return "test_nounk.arpa"; } - return boost::unit_test::framework::master_test_suite().argv[2]; + char **argv = boost::unit_test::framework::master_test_suite().argv; + return argv[strstr(argv[1], "nounk") ? 1 : 2]; } template State GetState(const Model &model, const char *word, const State &in) { diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index b709fef9..9ea08798 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -1,6 +1,7 @@ #include "lm/read_arpa.hh" #include "lm/blank.hh" +#include "util/file.hh" #include #include @@ -45,8 +46,14 @@ uint64_t ReadCount(const std::string &from) { void ReadARPACounts(util::FilePiece &in, std::vector &number) { number.clear(); - StringPiece line; - while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + StringPiece line = in.ReadLine(); + // In general, ARPA files can have arbitrary text before "\data\" + // But in KenLM, we require such lines to start with "#", so that + // we can do stricter error checking + while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) { + line = in.ReadLine(); + } + if (line != "\\data\\") { if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast(line.data()[1]) == 0x8b)) { UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); diff --git a/klm/lm/sizes.cc b/klm/lm/sizes.cc new file mode 100644 index 00000000..55ad586c --- /dev/null +++ b/klm/lm/sizes.cc @@ -0,0 +1,63 @@ +#include "lm/sizes.hh" +#include "lm/model.hh" +#include "util/file_piece.hh" + +#include +#include + +namespace lm { +namespace ngram { + +void ShowSizes(const std::vector &counts, const lm::ngram::Config &config) { + uint64_t sizes[6]; + sizes[0] = ProbingModel::Size(counts, config); + sizes[1] = RestProbingModel::Size(counts, config); + sizes[2] = TrieModel::Size(counts, config); + sizes[3] = QuantTrieModel::Size(counts, config); + sizes[4] = ArrayTrieModel::Size(counts, config); + sizes[5] = QuantArrayTrieModel::Size(counts, config); + uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t divide; + char prefix; + if (min_length < (1 << 10) * 10) { + prefix = ' '; + divide = 1; + } else if (min_length < (1 << 20) * 10) { + prefix = 'k'; + divide = 1 << 10; + } else if (min_length < (1ULL << 30) * 10) { + prefix = 'M'; + divide = 1 << 20; + } else { + prefix = 'G'; + divide = 1 << 30; + } + long int length = std::max(2, static_cast(ceil(log10((double) max_length / divide)))); + std::cerr << "Memory estimate for binary LM:\ntype "; + + // right align bytes. + for (long int i = 0; i < length - 2; ++i) std::cerr << ' '; + + std::cerr << prefix << "B\n" + "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" + "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n" + "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n" + "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" + "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" + "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; +} + +void ShowSizes(const std::vector &counts) { + lm::ngram::Config config; + ShowSizes(counts, config); +} + +void ShowSizes(const char *file, const lm::ngram::Config &config) { + std::vector counts; + util::FilePiece f(file); + lm::ReadARPACounts(f, counts); + ShowSizes(counts, config); +} + +}} //namespaces diff --git a/klm/lm/sizes.hh b/klm/lm/sizes.hh new file mode 100644 index 00000000..85abade7 --- /dev/null +++ b/klm/lm/sizes.hh @@ -0,0 +1,17 @@ +#ifndef LM_SIZES__ +#define LM_SIZES__ + +#include + +#include + +namespace lm { namespace ngram { + +struct Config; + +void ShowSizes(const std::vector &counts, const lm::ngram::Config &config); +void ShowSizes(const std::vector &counts); +void ShowSizes(const char *file, const lm::ngram::Config &config); + +}} // namespaces +#endif // LM_SIZES__ diff --git a/klm/lm/state.hh b/klm/lm/state.hh index 551510a8..d8e6c132 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -56,14 +56,14 @@ inline uint64_t hash_value(const State &state, uint64_t seed = 0) { struct Left { bool operator==(const Left &other) const { return - (length == other.length) && - pointers[length - 1] == other.pointers[length - 1] && - full == other.full; + length == other.length && + (!length || (pointers[length - 1] == other.pointers[length - 1] && full == other.full)); } int Compare(const Left &other) const { if (length < other.length) return -1; if (length > other.length) return 1; + if (length == 0) return 0; // Must be full. if (pointers[length - 1] > other.pointers[length - 1]) return 1; if (pointers[length - 1] < other.pointers[length - 1]) return -1; return (int)full - (int)other.full; diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 8663e94e..dc542bb3 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -65,13 +65,13 @@ class PartialViewProxy { typedef util::ProxyIterator PartialIter; -FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) { - util::scoped_fd file(maker.Make()); +FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) { + util::scoped_fd file(util::MakeTemp(temp_prefix)); util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); return util::FDOpenOrThrow(file); } -FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) { +FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) { const size_t context_size = sizeof(WordIndex) * (order - 1); // Sort just the contexts using the same memory. PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); @@ -84,7 +84,7 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make #endif (context_begin, context_end, util::SizedCompare(EntryCompare(order - 1))); - util::scoped_FILE out(maker.MakeFile()); + util::scoped_FILE out(util::FMakeTemp(temp_prefix)); // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. if (context_begin == context_end) return out.release(); @@ -114,12 +114,12 @@ struct FirstCombine { } }; -template FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) { +template FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) { std::size_t entry_size = sizeof(WordIndex) * order + weights_size; RecordReader first, second; first.Init(first_file, entry_size); second.Init(second_file, entry_size); - util::scoped_FILE out_file(maker.MakeFile()); + util::scoped_FILE out_file(util::FMakeTemp(temp_prefix)); EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { @@ -177,9 +177,8 @@ void RecordReader::Rewind() { } SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { - util::TempMaker maker(file_prefix); PositiveProbWarn warn(config.positive_log_probability); - unigram_.reset(maker.Make()); + unigram_.reset(util::MakeTemp(file_prefix)); { // In case appears. size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff); @@ -202,7 +201,7 @@ SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { +void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { ReadNGramHeader(f, order); const size_t count = counts[order - 1]; // Size of weights. Does it include backoff? @@ -261,8 +260,8 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo std::sort #endif (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare(EntryCompare(order))); - files.push_back(DiskFlush(begin, out_end, maker)); - contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order)); + files.push_back(DiskFlush(begin, out_end, file_prefix)); + contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order)); done += (out_end - begin) / entry_size; } @@ -270,10 +269,10 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo // All individual files created. Merge them. while (files.size() > 1) { - files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine())); + files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine())); files_closer.PopFront(); files_closer.PopFront(); - contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine())); + contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine())); contexts_closer.PopFront(); contexts_closer.PopFront(); } diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index 2197b80c..1afd9562 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -18,7 +18,6 @@ namespace util { class FilePiece; -class TempMaker; } // namespace util namespace lm { @@ -101,7 +100,7 @@ class SortedFiles { } private: - void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size); + void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, const std::string &prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size); util::scoped_fd unigram_; -- cgit v1.2.3