summaryrefslogtreecommitdiff
path: root/klm/lm/builder
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/builder')
-rw-r--r--klm/lm/builder/README.md47
-rw-r--r--klm/lm/builder/TODO5
-rw-r--r--klm/lm/builder/adjust_counts.cc216
-rw-r--r--klm/lm/builder/adjust_counts.hh44
-rw-r--r--klm/lm/builder/adjust_counts_test.cc106
-rw-r--r--klm/lm/builder/corpus_count.cc223
-rw-r--r--klm/lm/builder/corpus_count.hh42
-rw-r--r--klm/lm/builder/corpus_count_test.cc76
-rw-r--r--klm/lm/builder/discount.hh26
-rw-r--r--klm/lm/builder/header_info.hh20
-rw-r--r--klm/lm/builder/initial_probabilities.cc136
-rw-r--r--klm/lm/builder/initial_probabilities.hh34
-rw-r--r--klm/lm/builder/interpolate.cc65
-rw-r--r--klm/lm/builder/interpolate.hh27
-rw-r--r--klm/lm/builder/joint_order.hh43
-rw-r--r--klm/lm/builder/main.cc94
-rw-r--r--klm/lm/builder/multi_stream.hh180
-rw-r--r--klm/lm/builder/ngram.hh84
-rw-r--r--klm/lm/builder/ngram_stream.hh55
-rw-r--r--klm/lm/builder/pipeline.cc320
-rw-r--r--klm/lm/builder/pipeline.hh40
-rw-r--r--klm/lm/builder/print.cc135
-rw-r--r--klm/lm/builder/print.hh102
-rw-r--r--klm/lm/builder/sort.hh103
24 files changed, 2223 insertions, 0 deletions
diff --git a/klm/lm/builder/README.md b/klm/lm/builder/README.md
new file mode 100644
index 00000000..be0d35e2
--- /dev/null
+++ b/klm/lm/builder/README.md
@@ -0,0 +1,47 @@
+Dependencies
+============
+
+Boost >= 1.42.0 is required.
+
+For Ubuntu,
+```bash
+sudo apt-get install libboost1.48-all-dev
+```
+
+Alternatively, you can download, compile, and install it yourself:
+
+```bash
+wget http://sourceforge.net/projects/boost/files/boost/1.52.0/boost_1_52_0.tar.gz/download -O boost_1_52_0.tar.gz
+tar -xvzf boost_1_52_0.tar.gz
+cd boost_1_52_0
+./bootstrap.sh
+./b2
+sudo ./b2 install
+```
+
+Local install options (in a user-space prefix directory) are also possible. See http://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html.
+
+
+Building
+========
+
+```bash
+bjam
+```
+Your distribution might package bjam and boost-build separately from Boost. Both are required.
+
+Usage
+=====
+
+Run
+```bash
+$ bin/lmplz
+```
+to see command line arguments
+
+Running
+=======
+
+```bash
+bin/lmplz -o 5 <text >text.arpa
+```
diff --git a/klm/lm/builder/TODO b/klm/lm/builder/TODO
new file mode 100644
index 00000000..cb5aef3a
--- /dev/null
+++ b/klm/lm/builder/TODO
@@ -0,0 +1,5 @@
+More tests!
+Sharding.
+Some way to manage all the crazy config options.
+Option to build the binary file directly.
+Interpolation of different orders.
diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc
new file mode 100644
index 00000000..a6f48011
--- /dev/null
+++ b/klm/lm/builder/adjust_counts.cc
@@ -0,0 +1,216 @@
+#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/multi_stream.hh"
+#include "util/stream/timer.hh"
+
+#include <algorithm>
+
+namespace lm { namespace builder {
+
+BadDiscountException::BadDiscountException() throw() {}
+BadDiscountException::~BadDiscountException() throw() {}
+
+namespace {
+// Return last word in full that is different.
+const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
+ const WordIndex *cur_word = full.end() - 1;
+ const WordIndex *pre_word = lower_last.end() - 1;
+ // Find last difference.
+ for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
+ return cur_word;
+}
+
+class StatCollector {
+ public:
+ StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+ : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) {
+ memset(&orders_[0], 0, sizeof(OrderStat) * order);
+ }
+
+ ~StatCollector() {}
+
+ void CalculateDiscounts() {
+ counts_.resize(orders_.size());
+ discounts_.resize(orders_.size());
+ for (std::size_t i = 0; i < orders_.size(); ++i) {
+ const OrderStat &s = orders_[i];
+ counts_[i] = s.count;
+
+ for (unsigned j = 1; j < 4; ++j) {
+ // TODO: Specialize error message for j == 3, meaning 3+
+ UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
+ << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
+ << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
+ }
+
+ // See equation (26) in Chen and Goodman.
+ discounts_[i].amount[0] = 0.0;
+ float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
+ for (unsigned j = 1; j < 4; ++j) {
+ discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
+ UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+ }
+ }
+ }
+
+ void Add(std::size_t order_minus_1, uint64_t count) {
+ OrderStat &stat = orders_[order_minus_1];
+ ++stat.count;
+ if (count < 5) ++stat.n[count];
+ }
+
+ void AddFull(uint64_t count) {
+ ++full_.count;
+ if (count < 5) ++full_.n[count];
+ }
+
+ private:
+ struct OrderStat {
+ // n_1 in equation 26 of Chen and Goodman etc
+ uint64_t n[5];
+ uint64_t count;
+ };
+
+ std::vector<OrderStat> orders_;
+ OrderStat &full_;
+
+ std::vector<uint64_t> &counts_;
+ std::vector<Discount> &discounts_;
+};
+
+// Reads all entries in order like NGramStream does.
+// But deletes any entries that have <s> in the 1st (not 0th) position on the
+// way out by putting other entries in their place. This disrupts the sort
+// order but we don't care because the data is going to be sorted again.
+class CollapseStream {
+ public:
+ CollapseStream(const util::stream::ChainPosition &position) :
+ current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
+ block_(position) {
+ StartBlock();
+ }
+
+ const NGram &operator*() const { return current_; }
+ const NGram *operator->() const { return &current_; }
+
+ operator bool() const { return block_; }
+
+ CollapseStream &operator++() {
+ assert(block_);
+ if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
+ memcpy(current_.Base(), copy_from_, current_.TotalSize());
+ UpdateCopyFrom();
+ }
+ current_.NextInMemory();
+ uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
+ if (current_.Base() == block_base + block_->ValidSize()) {
+ block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base);
+ ++block_;
+ StartBlock();
+ }
+ return *this;
+ }
+
+ private:
+ void StartBlock() {
+ for (; ; ++block_) {
+ if (!block_) return;
+ if (block_->ValidSize()) break;
+ }
+ current_.ReBase(block_->Get());
+ copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
+ UpdateCopyFrom();
+ }
+
+ // Find last without bos.
+ void UpdateCopyFrom() {
+ for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
+ if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
+ }
+ }
+
+ NGram current_;
+
+ // Goes backwards in the block
+ uint8_t *copy_from_;
+
+ util::stream::Link block_;
+};
+
+} // namespace
+
+void AdjustCounts::Run(const ChainPositions &positions) {
+ UTIL_TIMER("(%w s) Adjusted counts\n");
+
+ const std::size_t order = positions.size();
+ StatCollector stats(order, counts_, discounts_);
+ if (order == 1) {
+ // Only unigrams. Just collect stats.
+ for (NGramStream full(positions[0]); full; ++full)
+ stats.AddFull(full->Count());
+ stats.CalculateDiscounts();
+ return;
+ }
+
+ NGramStreams streams;
+ streams.Init(positions, positions.size() - 1);
+ CollapseStream full(positions[positions.size() - 1]);
+
+ // Initialization: <unk> has count 0 and so does <s>.
+ NGramStream *lower_valid = streams.begin();
+ streams[0]->Count() = 0;
+ *streams[0]->begin() = kUNK;
+ stats.Add(0, 0);
+ (++streams[0])->Count() = 0;
+ *streams[0]->begin() = kBOS;
+ // not in stats because it will get put in later.
+
+ // iterate over full (the stream of the highest order ngrams)
+ for (; full; ++full) {
+ const WordIndex *different = FindDifference(*full, **lower_valid);
+ std::size_t same = full->end() - 1 - different;
+ // Increment the adjusted count.
+ if (same) ++streams[same - 1]->Count();
+
+ // Output all the valid ones that changed.
+ for (; lower_valid >= &streams[same]; --lower_valid) {
+ stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count());
+ ++*lower_valid;
+ }
+
+ // This is here because bos is also const WordIndex *, so copy gets
+ // consistent argument types.
+ const WordIndex *full_end = full->end();
+ // Initialize and mark as valid up to bos.
+ const WordIndex *bos;
+ for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
+ ++lower_valid;
+ std::copy(bos, full_end, (*lower_valid)->begin());
+ (*lower_valid)->Count() = 1;
+ }
+ // Now bos indicates where <s> is or is the 0th word of full.
+ if (bos != full->begin()) {
+ // There is an <s> beyond the 0th word.
+ NGramStream &to = *++lower_valid;
+ std::copy(bos, full_end, to->begin());
+ to->Count() = full->Count();
+ } else {
+ stats.AddFull(full->Count());
+ }
+ assert(lower_valid >= &streams[0]);
+ }
+
+ // Output everything valid.
+ for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
+ stats.Add(s - streams.begin(), (*s)->Count());
+ ++*s;
+ }
+ // Poison everyone! Except the N-grams which were already poisoned by the input.
+ for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
+ s->Poison();
+
+ stats.CalculateDiscounts();
+
+ // NOTE: See special early-return case for unigrams near the top of this function
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh
new file mode 100644
index 00000000..f38ff79d
--- /dev/null
+++ b/klm/lm/builder/adjust_counts.hh
@@ -0,0 +1,44 @@
+#ifndef LM_BUILDER_ADJUST_COUNTS__
+#define LM_BUILDER_ADJUST_COUNTS__
+
+#include "lm/builder/discount.hh"
+#include "util/exception.hh"
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+
+class ChainPositions;
+
+class BadDiscountException : public util::Exception {
+ public:
+ BadDiscountException() throw();
+ ~BadDiscountException() throw();
+};
+
+/* Compute adjusted counts.
+ * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.
+ * Output: [1,N]-grams with adjusted counts.
+ * [1,N)-grams are in suffix order
+ * N-grams are in undefined order (they're going to be sorted anyway).
+ */
+class AdjustCounts {
+ public:
+ AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+ : counts_(counts), discounts_(discounts) {}
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ std::vector<uint64_t> &counts_;
+ std::vector<Discount> &discounts_;
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_ADJUST_COUNTS__
+
diff --git a/klm/lm/builder/adjust_counts_test.cc b/klm/lm/builder/adjust_counts_test.cc
new file mode 100644
index 00000000..68b5f33e
--- /dev/null
+++ b/klm/lm/builder/adjust_counts_test.cc
@@ -0,0 +1,106 @@
+#include "lm/builder/adjust_counts.hh"
+
+#include "lm/builder/multi_stream.hh"
+#include "util/scoped.hh"
+
+#include <boost/thread/thread.hpp>
+#define BOOST_TEST_MODULE AdjustCounts
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+class KeepCopy {
+ public:
+ KeepCopy() : size_(0) {}
+
+ void Run(const util::stream::ChainPosition &position) {
+ for (util::stream::Link link(position); link; ++link) {
+ mem_.call_realloc(size_ + link->ValidSize());
+ memcpy(static_cast<uint8_t*>(mem_.get()) + size_, link->Get(), link->ValidSize());
+ size_ += link->ValidSize();
+ }
+ }
+
+ uint8_t *Get() { return static_cast<uint8_t*>(mem_.get()); }
+ std::size_t Size() const { return size_; }
+
+ private:
+ util::scoped_malloc mem_;
+ std::size_t size_;
+};
+
+struct Gram4 {
+ WordIndex ids[4];
+ uint64_t count;
+};
+
+class WriteInput {
+ public:
+ void Run(const util::stream::ChainPosition &position) {
+ NGramStream input(position);
+ Gram4 grams[] = {
+ {{0,0,0,0},10},
+ {{0,0,3,0},3},
+ // bos
+ {{1,1,1,2},5},
+ {{0,0,3,2},5},
+ };
+ for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) {
+ memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4);
+ input->Count() = grams[i].count;
+ }
+ input.Poison();
+ }
+};
+
+BOOST_AUTO_TEST_CASE(Simple) {
+ KeepCopy outputs[4];
+ std::vector<uint64_t> counts;
+ std::vector<Discount> discount;
+ {
+ util::stream::ChainConfig config;
+ config.total_memory = 100;
+ config.block_count = 1;
+ Chains chains(4);
+ for (unsigned i = 0; i < 4; ++i) {
+ config.entry_size = NGram::TotalSize(i + 1);
+ chains.push_back(config);
+ }
+
+ chains[3] >> WriteInput();
+ ChainPositions for_adjust(chains);
+ for (unsigned i = 0; i < 4; ++i) {
+ chains[i] >> boost::ref(outputs[i]);
+ }
+ chains >> util::stream::kRecycle;
+ BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException);
+ }
+ BOOST_REQUIRE_EQUAL(4UL, counts.size());
+ BOOST_CHECK_EQUAL(4UL, counts[0]);
+ // These are no longer set because the discounts are bad.
+/* BOOST_CHECK_EQUAL(4UL, counts[1]);
+ BOOST_CHECK_EQUAL(3UL, counts[2]);
+ BOOST_CHECK_EQUAL(3UL, counts[3]);*/
+ BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size());
+ NGram uni(outputs[0].Get(), 1);
+ BOOST_CHECK_EQUAL(kUNK, *uni.begin());
+ BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(kBOS, *uni.begin());
+ BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(0UL, *uni.begin());
+ BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ BOOST_CHECK_EQUAL(2UL, *uni.begin());
+
+ BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size());
+ NGram bi(outputs[1].Get(), 2);
+ BOOST_CHECK_EQUAL(0UL, *bi.begin());
+ BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1));
+ BOOST_CHECK_EQUAL(1ULL, bi.Count());
+ bi.NextInMemory();
+}
+
+}}} // namespaces
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
new file mode 100644
index 00000000..8c3de57d
--- /dev/null
+++ b/klm/lm/builder/corpus_count.cc
@@ -0,0 +1,223 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/lm_exception.hh"
+#include "lm/word_index.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/murmur_hash.hh"
+#include "util/probing_hash_table.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/timer.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_set.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <functional>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+namespace {
+
+class VocabHandout {
+ public:
+ explicit VocabHandout(int fd) {
+ util::scoped_fd duped(util::DupOrThrow(fd));
+ word_list_.reset(util::FDOpenOrThrow(duped));
+
+ Lookup("<unk>"); // Force 0
+ Lookup("<s>"); // Force 1
+ Lookup("</s>"); // Force 2
+ }
+
+ WordIndex Lookup(const StringPiece &word) {
+ uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
+ std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
+ if (ret.second) {
+ char null_delimit = 0;
+ util::WriteOrThrow(word_list_.get(), word.data(), word.size());
+ util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
+ UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ }
+ return ret.first->second;
+ }
+
+ WordIndex Size() const {
+ return seen_.size();
+ }
+
+ private:
+ typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+
+ Seen seen_;
+
+ util::scoped_FILE word_list_;
+};
+
+class DedupeHash : public std::unary_function<const WordIndex *, bool> {
+ public:
+ explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+ std::size_t operator()(const WordIndex *start) const {
+ return util::MurmurHashNative(start, size_);
+ }
+
+ private:
+ const std::size_t size_;
+};
+
+class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> {
+ public:
+ explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+ bool operator()(const WordIndex *first, const WordIndex *second) const {
+ return !memcmp(first, second, size_);
+ }
+
+ private:
+ const std::size_t size_;
+};
+
+struct DedupeEntry {
+ typedef WordIndex *Key;
+ Key GetKey() const { return key; }
+ Key key;
+ static DedupeEntry Construct(WordIndex *at) {
+ DedupeEntry ret;
+ ret.key = at;
+ return ret;
+ }
+};
+
+typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
+
+const float kProbingMultiplier = 1.5;
+
+class Writer {
+ public:
+ Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
+ : block_(position), gram_(block_->Get(), order),
+ dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
+ dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
+ buffer_(new WordIndex[order - 1]),
+ block_size_(position.GetChain().BlockSize()) {
+ dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
+ if (order == 1) {
+ // Add special words. AdjustCounts is responsible if order != 1.
+ AddUnigramWord(kUNK);
+ AddUnigramWord(kBOS);
+ }
+ }
+
+ ~Writer() {
+ block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
+ (++block_).Poison();
+ }
+
+ // Write context with a bunch of <s>
+ void StartSentence() {
+ for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
+ *i = kBOS;
+ }
+ }
+
+ void Append(WordIndex word) {
+ *(gram_.end() - 1) = word;
+ Dedupe::MutableIterator at;
+ bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
+ if (found) {
+ // Already present.
+ NGram already(at->key, gram_.Order());
+ ++(already.Count());
+ // Shift left by one.
+ memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
+ return;
+ }
+ // Complete the write.
+ gram_.Count() = 1;
+ // Prepare the next n-gram.
+ if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
+ NGram last(gram_);
+ gram_.NextInMemory();
+ std::copy(last.begin() + 1, last.end(), gram_.begin());
+ return;
+ }
+ // Block end. Need to store the context in a temporary buffer.
+ std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
+ dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ block_->SetValidSize(block_size_);
+ gram_.ReBase((++block_)->Get());
+ std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
+ }
+
+ private:
+ void AddUnigramWord(WordIndex index) {
+ *gram_.begin() = index;
+ gram_.Count() = 0;
+ gram_.NextInMemory();
+ if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
+ block_->SetValidSize(block_size_);
+ gram_.ReBase((++block_)->Get());
+ }
+ }
+
+ util::stream::Link block_;
+
+ NGram gram_;
+
+ // This is the memory behind the invalid value in dedupe_.
+ std::vector<WordIndex> dedupe_invalid_;
+ // Hash table combiner implementation.
+ Dedupe dedupe_;
+
+ // Small buffer to hold existing ngrams when shifting across a block boundary.
+ boost::scoped_array<WordIndex> buffer_;
+
+ const std::size_t block_size_;
+};
+
+} // namespace
+
+float CorpusCount::DedupeMultiplier(std::size_t order) {
+ return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
+}
+
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
+ : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
+ dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
+ dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
+ token_count_ = 0;
+ type_count_ = 0;
+}
+
+void CorpusCount::Run(const util::stream::ChainPosition &position) {
+ UTIL_TIMER("(%w s) Counted n-grams\n");
+
+ VocabHandout vocab(vocab_write_);
+ const WordIndex end_sentence = vocab.Lookup("</s>");
+ Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
+ uint64_t count = 0;
+ try {
+ while(true) {
+ StringPiece line(from_.ReadLine());
+ writer.StartSentence();
+ for (util::TokenIter<util::AnyCharacter, true> w(line, " \t"); w; ++w) {
+ WordIndex word = vocab.Lookup(*w);
+ UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
+ writer.Append(word);
+ ++count;
+ }
+ writer.Append(end_sentence);
+ }
+ } catch (const util::EndOfFileException &e) {}
+ token_count_ = count;
+ type_count_ = vocab.Size();
+}
+
+} // namespace builder
+} // namespace lm
diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh
new file mode 100644
index 00000000..e255bad1
--- /dev/null
+++ b/klm/lm/builder/corpus_count.hh
@@ -0,0 +1,42 @@
+#ifndef LM_BUILDER_CORPUS_COUNT__
+#define LM_BUILDER_CORPUS_COUNT__
+
+#include "lm/word_index.hh"
+#include "util/scoped.hh"
+
+#include <cstddef>
+#include <string>
+#include <stdint.h>
+
+namespace util {
+class FilePiece;
+namespace stream {
+class ChainPosition;
+} // namespace stream
+} // namespace util
+
+namespace lm {
+namespace builder {
+
+class CorpusCount {
+ public:
+ // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size
+ static float DedupeMultiplier(std::size_t order);
+
+ CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
+
+ void Run(const util::stream::ChainPosition &position);
+
+ private:
+ util::FilePiece &from_;
+ int vocab_write_;
+ uint64_t &token_count_;
+ WordIndex &type_count_;
+
+ std::size_t dedupe_mem_size_;
+ util::scoped_malloc dedupe_mem_;
+};
+
+} // namespace builder
+} // namespace lm
+#endif // LM_BUILDER_CORPUS_COUNT__
diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc
new file mode 100644
index 00000000..8d53ca9d
--- /dev/null
+++ b/klm/lm/builder/corpus_count_test.cc
@@ -0,0 +1,76 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/ngram_stream.hh"
+
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/tokenize_piece.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/stream.hh"
+
+#define BOOST_TEST_MODULE CorpusCountTest
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+#define Check(str, count) { \
+ BOOST_REQUIRE(stream); \
+ w = stream->begin(); \
+ for (util::TokenIter<util::AnyCharacter, true> t(str, " "); t; ++t, ++w) { \
+ BOOST_CHECK_EQUAL(*t, v[*w]); \
+ } \
+ BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \
+ ++stream; \
+}
+
+BOOST_AUTO_TEST_CASE(Short) {
+ util::scoped_fd input_file(util::MakeTemp("corpus_count_test_temp"));
+ const char input[] = "looking on a little more loin\non a little more loin\non foo little more loin\nbar\n\n";
+ // Blocks of 10 are
+ // looking on a little more loin </s> on a little[duplicate] more[duplicate] loin[duplicate] </s>[duplicate] on[duplicate] foo
+ // little more loin </s> bar </s> </s>
+
+ util::WriteOrThrow(input_file.get(), input, sizeof(input) - 1);
+ util::FilePiece input_piece(input_file.release(), "temp file");
+
+ util::stream::ChainConfig config;
+ config.entry_size = NGram::TotalSize(3);
+ config.total_memory = config.entry_size * 20;
+ config.block_count = 2;
+
+ util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab"));
+
+ util::stream::Chain chain(config);
+ NGramStream stream;
+ uint64_t token_count;
+ WordIndex type_count;
+ CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
+
+ const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"};
+
+ WordIndex *w;
+
+ Check("<s> <s> looking", 1);
+ Check("<s> looking on", 1);
+ Check("looking on a", 1);
+ Check("on a little", 2);
+ Check("a little more", 2);
+ Check("little more loin", 2);
+ Check("more loin </s>", 2);
+ Check("<s> <s> on", 2);
+ Check("<s> on a", 1);
+ Check("<s> on foo", 1);
+ Check("on foo little", 1);
+ Check("foo little more", 1);
+ Check("little more loin", 1);
+ Check("more loin </s>", 1);
+ Check("<s> <s> bar", 1);
+ Check("<s> bar </s>", 1);
+ Check("<s> <s> </s>", 1);
+ BOOST_CHECK(!stream);
+ BOOST_CHECK_EQUAL(sizeof(v) / sizeof(const char*), type_count);
+}
+
+}}} // namespaces
diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh
new file mode 100644
index 00000000..754fb20d
--- /dev/null
+++ b/klm/lm/builder/discount.hh
@@ -0,0 +1,26 @@
+#ifndef BUILDER_DISCOUNT__
+#define BUILDER_DISCOUNT__
+
+#include <algorithm>
+
+#include <inttypes.h>
+
+namespace lm {
+namespace builder {
+
+struct Discount {
+ float amount[4];
+
+ float Get(uint64_t count) const {
+ return amount[std::min<uint64_t>(count, 3)];
+ }
+
+ float Apply(uint64_t count) const {
+ return static_cast<float>(count) - Get(count);
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // BUILDER_DISCOUNT__
diff --git a/klm/lm/builder/header_info.hh b/klm/lm/builder/header_info.hh
new file mode 100644
index 00000000..ccca1456
--- /dev/null
+++ b/klm/lm/builder/header_info.hh
@@ -0,0 +1,20 @@
+#ifndef LM_BUILDER_HEADER_INFO__
+#define LM_BUILDER_HEADER_INFO__
+
+#include <string>
+#include <stdint.h>
+
+// Some configuration info that is used to add
+// comments to the beginning of an ARPA file
+struct HeaderInfo {
+ const std::string input_file;
+ const uint64_t token_count;
+
+ HeaderInfo(const std::string& input_file_in, uint64_t token_count_in)
+ : input_file(input_file_in), token_count(token_count_in) {}
+
+ // TODO: Add smoothing type
+ // TODO: More info if multiple models were interpolated
+};
+
+#endif
diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc
new file mode 100644
index 00000000..58b42a20
--- /dev/null
+++ b/klm/lm/builder/initial_probabilities.cc
@@ -0,0 +1,136 @@
+#include "lm/builder/initial_probabilities.hh"
+
+#include "lm/builder/discount.hh"
+#include "lm/builder/ngram_stream.hh"
+#include "lm/builder/sort.hh"
+#include "util/file.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/io.hh"
+#include "util/stream/stream.hh"
+
+#include <vector>
+
+namespace lm { namespace builder {
+
+namespace {
+struct BufferEntry {
+ // Gamma from page 20 of Chen and Goodman.
+ float gamma;
+ // \sum_w a(c w) for all w.
+ float denominator;
+};
+
+// Extract an array of gamma from an array of BufferEntry.
+class OnlyGamma {
+ public:
+ void Run(const util::stream::ChainPosition &position) {
+ for (util::stream::Link block_it(position); block_it; ++block_it) {
+ float *out = static_cast<float*>(block_it->Get());
+ const float *in = out;
+ const float *end = static_cast<const float*>(block_it->ValidEnd());
+ for (out += 1, in += 2; in < end; out += 1, in += 2) {
+ *out = *in;
+ }
+ block_it->SetValidSize(block_it->ValidSize() / 2);
+ }
+ }
+};
+
+class AddRight {
+ public:
+ AddRight(const Discount &discount, const util::stream::ChainPosition &input)
+ : discount_(discount), input_(input) {}
+
+ void Run(const util::stream::ChainPosition &output) {
+ NGramStream in(input_);
+ util::stream::Stream out(output);
+
+ std::vector<WordIndex> previous(in->Order() - 1);
+ const std::size_t size = sizeof(WordIndex) * previous.size();
+ for(; in; ++out) {
+ memcpy(&previous[0], in->begin(), size);
+ uint64_t denominator = 0;
+ uint64_t counts[4];
+ memset(counts, 0, sizeof(counts));
+ do {
+ denominator += in->Count();
+ ++counts[std::min(in->Count(), static_cast<uint64_t>(3))];
+ } while (++in && !memcmp(&previous[0], in->begin(), size));
+ BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());
+ entry.denominator = static_cast<float>(denominator);
+ entry.gamma = 0.0;
+ for (unsigned i = 1; i <= 3; ++i) {
+ entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);
+ }
+ entry.gamma /= entry.denominator;
+ }
+ out.Poison();
+ }
+
+ private:
+ const Discount &discount_;
+ const util::stream::ChainPosition input_;
+};
+
+class MergeRight {
+ public:
+ MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount)
+ : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {}
+
+ // calculate the initial probability of each n-gram (before order-interpolation)
+ // Run() gets invoked once for each order
+ void Run(const util::stream::ChainPosition &primary) {
+ util::stream::Stream summed(from_adder_);
+
+ NGramStream grams(primary);
+
+ // Without interpolation, the interpolation weight goes to <unk>.
+ if (grams->Order() == 1 && !interpolate_unigrams_) {
+ BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get()));
+ assert(*grams->begin() == kUNK);
+ grams->Value().uninterp.prob = sums.gamma;
+ grams->Value().uninterp.gamma = 0.0;
+ while (++grams) {
+ grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator;
+ grams->Value().uninterp.gamma = 0.0;
+ }
+ ++summed;
+ return;
+ }
+
+ std::vector<WordIndex> previous(grams->Order() - 1);
+ const std::size_t size = sizeof(WordIndex) * previous.size();
+ for (; grams; ++summed) {
+ memcpy(&previous[0], grams->begin(), size);
+ const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
+ do {
+ Payload &pay = grams->Value();
+ pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator;
+ pay.uninterp.gamma = sums.gamma;
+ } while (++grams && !memcmp(&previous[0], grams->begin(), size));
+ }
+ }
+
+ private:
+ bool interpolate_unigrams_;
+ util::stream::ChainPosition from_adder_;
+ Discount discount_;
+};
+
+} // namespace
+
+void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) {
+ util::stream::ChainConfig gamma_config = config.adder_out;
+ gamma_config.entry_size = sizeof(BufferEntry);
+ for (size_t i = 0; i < primary.size(); ++i) {
+ util::stream::ChainPosition second(second_in[i].Add());
+ second_in[i] >> util::stream::kRecycle;
+ gamma_out.push_back(gamma_config);
+ gamma_out[i] >> AddRight(discounts[i], second);
+ primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
+ // Don't bother with the OnlyGamma thread for something to discard.
+ if (i) gamma_out[i] >> OnlyGamma();
+ }
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/initial_probabilities.hh b/klm/lm/builder/initial_probabilities.hh
new file mode 100644
index 00000000..626388eb
--- /dev/null
+++ b/klm/lm/builder/initial_probabilities.hh
@@ -0,0 +1,34 @@
+#ifndef LM_BUILDER_INITIAL_PROBABILITIES__
+#define LM_BUILDER_INITIAL_PROBABILITIES__
+
+#include "lm/builder/discount.hh"
+#include "util/stream/config.hh"
+
+#include <vector>
+
+namespace lm {
+namespace builder {
+class Chains;
+
+struct InitialProbabilitiesConfig {
+ // These should be small buffers to keep the adder from getting too far ahead
+ util::stream::ChainConfig adder_in;
+ util::stream::ChainConfig adder_out;
+ // SRILM doesn't normally interpolate unigrams.
+ bool interpolate_unigrams;
+};
+
+/* Compute initial (uninterpolated) probabilities
+ * primary: the normal chain of n-grams. Incoming is context sorted adjusted
+ * counts. Outgoing has uninterpolated probabilities for use by Interpolate.
+ * second_in: a second copy of the primary input. Discard the output.
+ * gamma_out: Computed gamma values are output on these chains in suffix order.
+ * The values are bare floats and should be buffered for interpolation to
+ * use.
+ */
+void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out);
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_INITIAL_PROBABILITIES__
diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc
new file mode 100644
index 00000000..50026806
--- /dev/null
+++ b/klm/lm/builder/interpolate.cc
@@ -0,0 +1,65 @@
+#include "lm/builder/interpolate.hh"
+
+#include "lm/builder/joint_order.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/sort.hh"
+#include "lm/lm_exception.hh"
+
+#include <assert.h>
+
+namespace lm { namespace builder {
+namespace {
+
+class Callback {
+ public:
+ Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) {
+ probs_[0] = uniform_prob;
+ for (std::size_t i = 0; i < backoffs.size(); ++i) {
+ backoffs_.push_back(backoffs[i]);
+ }
+ }
+
+ ~Callback() {
+ for (std::size_t i = 0; i < backoffs_.size(); ++i) {
+ if (backoffs_[i]) {
+ std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
+ abort();
+ }
+ }
+ }
+
+ void Enter(unsigned order_minus_1, NGram &gram) {
+ Payload &pay = gram.Value();
+ pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
+ probs_[order_minus_1 + 1] = pay.complete.prob;
+ pay.complete.prob = log10(pay.complete.prob);
+ // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
+ pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
+ ++backoffs_[order_minus_1];
+ } else {
+ // Not a context.
+ pay.complete.backoff = 0.0;
+ }
+ }
+
+ void Exit(unsigned, const NGram &) const {}
+
+ private:
+ FixedArray<util::stream::Stream> backoffs_;
+
+ std::vector<float> probs_;
+};
+} // namespace
+
+Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
+ : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {}
+
+// perform order-wise interpolation
+void Interpolate::Run(const ChainPositions &positions) {
+ assert(positions.size() == backoffs_.size() + 1);
+ Callback callback(uniform_prob_, backoffs_);
+ JointOrder<Callback, SuffixOrder>(positions, callback);
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/interpolate.hh b/klm/lm/builder/interpolate.hh
new file mode 100644
index 00000000..9268d404
--- /dev/null
+++ b/klm/lm/builder/interpolate.hh
@@ -0,0 +1,27 @@
+#ifndef LM_BUILDER_INTERPOLATE__
+#define LM_BUILDER_INTERPOLATE__
+
+#include <stdint.h>
+
+#include "lm/builder/multi_stream.hh"
+
+namespace lm { namespace builder {
+
+/* Interpolate step.
+ * Input: suffix sorted n-grams with (p_uninterpolated, gamma) from
+ * InitialProbabilities.
+ * Output: suffix sorted n-grams with complete probability
+ */
+class Interpolate {
+ public:
+ explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs);
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ float uniform_prob_;
+ ChainPositions backoffs_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_INTERPOLATE__
diff --git a/klm/lm/builder/joint_order.hh b/klm/lm/builder/joint_order.hh
new file mode 100644
index 00000000..b5620144
--- /dev/null
+++ b/klm/lm/builder/joint_order.hh
@@ -0,0 +1,43 @@
+#ifndef LM_BUILDER_JOINT_ORDER__
+#define LM_BUILDER_JOINT_ORDER__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/lm_exception.hh"
+
+#include <string.h>
+
+namespace lm { namespace builder {
+
+template <class Callback, class Compare> void JointOrder(const ChainPositions &positions, Callback &callback) {
+ // Allow matching to reference streams[-1].
+ NGramStreams streams_with_dummy;
+ streams_with_dummy.InitWithDummy(positions);
+ NGramStream *streams = streams_with_dummy.begin() + 1;
+
+ unsigned int order;
+ for (order = 0; order < positions.size() && streams[order]; ++order) {}
+ assert(order); // should always have <unk>.
+ unsigned int current = 0;
+ while (true) {
+ // Does the context match the lower one?
+ if (!memcmp(streams[static_cast<int>(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) {
+ callback.Enter(current, *streams[current]);
+ // Transition to looking for extensions.
+ if (++current < order) continue;
+ }
+ // No extension left.
+ while(true) {
+ assert(current > 0);
+ --current;
+ callback.Exit(current, *streams[current]);
+ if (++streams[current]) break;
+ UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix");
+ order = current;
+ if (!order) return;
+ }
+ }
+}
+
+}} // namespaces
+
+#endif // LM_BUILDER_JOINT_ORDER__
diff --git a/klm/lm/builder/main.cc b/klm/lm/builder/main.cc
new file mode 100644
index 00000000..90b9dca2
--- /dev/null
+++ b/klm/lm/builder/main.cc
@@ -0,0 +1,94 @@
+#include "lm/builder/pipeline.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/usage.hh"
+
+#include <iostream>
+
+#include <boost/program_options.hpp>
+
+namespace {
+class SizeNotify {
+ public:
+ SizeNotify(std::size_t &out) : behind_(out) {}
+
+ void operator()(const std::string &from) {
+ behind_ = util::ParseSize(from);
+ }
+
+ private:
+ std::size_t &behind_;
+};
+
+boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) {
+ return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
+}
+
+} // namespace
+
+int main(int argc, char *argv[]) {
+ try {
+ namespace po = boost::program_options;
+ po::options_description options("Language model building options");
+ lm::builder::PipelineConfig pipeline;
+
+ options.add_options()
+ ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model")
+ ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
+ ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
+ ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
+ ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
+ ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
+ ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+ ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
+ ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
+ ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
+ if (argc == 1) {
+ std::cerr <<
+ "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
+ "Please cite:\n"
+ "@inproceedings{kenlm,\n"
+ "author = {Kenneth Heafield},\n"
+ "title = {{KenLM}: Faster and Smaller Language Model Queries},\n"
+ "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
+ "month = {July}, year={2011},\n"
+ "address = {Edinburgh, UK},\n"
+ "publisher = {Association for Computational Linguistics},\n"
+ "}\n\n"
+ "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n"
+ "the model (-o) is the only mandatory option. As this is an on-disk program,\n"
+ "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
+ "Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
+ "Valid units are \% for percentage of memory (supported platforms only) and (in\n"
+ "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n";
+ std::cerr << options << std::endl;
+ return 1;
+ }
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, options), vm);
+ po::notify(vm);
+
+ util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
+
+ lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
+ // TODO: evaluate options for these.
+ initial.adder_in.total_memory = 32768;
+ initial.adder_in.block_count = 2;
+ initial.adder_out.total_memory = 32768;
+ initial.adder_out.block_count = 2;
+ pipeline.read_backoffs = initial.adder_out;
+
+ // Read from stdin
+ try {
+ lm::builder::Pipeline(pipeline, 0, 1);
+ } catch (const util::MallocException &e) {
+ std::cerr << e.what() << std::endl;
+ std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
+ return 1;
+ }
+ util::PrintUsage(std::cerr);
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
+ }
+}
diff --git a/klm/lm/builder/multi_stream.hh b/klm/lm/builder/multi_stream.hh
new file mode 100644
index 00000000..707a98c7
--- /dev/null
+++ b/klm/lm/builder/multi_stream.hh
@@ -0,0 +1,180 @@
+#ifndef LM_BUILDER_MULTI_STREAM__
+#define LM_BUILDER_MULTI_STREAM__
+
+#include "lm/builder/ngram_stream.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+
+#include <cstddef>
+#include <new>
+
+#include <assert.h>
+#include <stdlib.h>
+
+namespace lm { namespace builder {
+
+template <class T> class FixedArray {
+ public:
+ explicit FixedArray(std::size_t count) {
+ Init(count);
+ }
+
+ FixedArray() : newed_end_(NULL) {}
+
+ void Init(std::size_t count) {
+ assert(!block_.get());
+ block_.reset(malloc(sizeof(T) * count));
+ if (!block_.get()) throw std::bad_alloc();
+ newed_end_ = begin();
+ }
+
+ FixedArray(const FixedArray &from) {
+ std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get());
+ Init(size);
+ for (std::size_t i = 0; i < size; ++i) {
+ new(end()) T(from[i]);
+ Constructed();
+ }
+ }
+
+ ~FixedArray() { clear(); }
+
+ T *begin() { return static_cast<T*>(block_.get()); }
+ const T *begin() const { return static_cast<const T*>(block_.get()); }
+ // Always call Constructed after successful completion of new.
+ T *end() { return newed_end_; }
+ const T *end() const { return newed_end_; }
+
+ T &back() { return *(end() - 1); }
+ const T &back() const { return *(end() - 1); }
+
+ std::size_t size() const { return end() - begin(); }
+ bool empty() const { return begin() == end(); }
+
+ T &operator[](std::size_t i) { return begin()[i]; }
+ const T &operator[](std::size_t i) const { return begin()[i]; }
+
+ template <class C> void push_back(const C &c) {
+ new (end()) T(c);
+ Constructed();
+ }
+
+ void clear() {
+ for (T *i = begin(); i != end(); ++i)
+ i->~T();
+ newed_end_ = begin();
+ }
+
+ protected:
+ void Constructed() {
+ ++newed_end_;
+ }
+
+ private:
+ util::scoped_malloc block_;
+
+ T *newed_end_;
+};
+
+class Chains;
+
+class ChainPositions : public FixedArray<util::stream::ChainPosition> {
+ public:
+ ChainPositions() {}
+
+ void Init(Chains &chains);
+
+ explicit ChainPositions(Chains &chains) {
+ Init(chains);
+ }
+};
+
+class Chains : public FixedArray<util::stream::Chain> {
+ private:
+ template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun {
+ typedef Chains type;
+ };
+
+ public:
+ explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {}
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ Chains &operator>>(const util::stream::Recycler &recycler) {
+ for (util::stream::Chain *i = begin(); i != end(); ++i)
+ *i >> recycler;
+ return *this;
+ }
+
+ void Wait(bool release_memory = true) {
+ threads_.clear();
+ for (util::stream::Chain *i = begin(); i != end(); ++i) {
+ i->Wait(release_memory);
+ }
+ }
+
+ private:
+ boost::ptr_vector<util::stream::Thread> threads_;
+
+ Chains(const Chains &);
+ void operator=(const Chains &);
+};
+
+inline void ChainPositions::Init(Chains &chains) {
+ FixedArray<util::stream::ChainPosition>::Init(chains.size());
+ for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) {
+ new (end()) util::stream::ChainPosition(i->Add()); Constructed();
+ }
+}
+
+inline Chains &operator>>(Chains &chains, ChainPositions &positions) {
+ positions.Init(chains);
+ return chains;
+}
+
+class NGramStreams : public FixedArray<NGramStream> {
+ public:
+ NGramStreams() {}
+
+ // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning).
+ void InitWithDummy(const ChainPositions &positions) {
+ FixedArray<NGramStream>::Init(positions.size() + 1);
+ new (end()) NGramStream(); Constructed();
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) {
+ push_back(*i);
+ }
+ }
+
+ // Limit restricts to positions[0,limit)
+ void Init(const ChainPositions &positions, std::size_t limit) {
+ FixedArray<NGramStream>::Init(limit);
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) {
+ push_back(*i);
+ }
+ }
+ void Init(const ChainPositions &positions) {
+ Init(positions, positions.size());
+ }
+
+ NGramStreams(const ChainPositions &positions) {
+ Init(positions);
+ }
+};
+
+inline Chains &operator>>(Chains &chains, NGramStreams &streams) {
+ ChainPositions positions;
+ chains >> positions;
+ streams.Init(positions);
+ return chains;
+}
+
+}} // namespaces
+#endif // LM_BUILDER_MULTI_STREAM__
diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh
new file mode 100644
index 00000000..2984ed0b
--- /dev/null
+++ b/klm/lm/builder/ngram.hh
@@ -0,0 +1,84 @@
+#ifndef LM_BUILDER_NGRAM__
+#define LM_BUILDER_NGRAM__
+
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+
+#include <cstddef>
+
+#include <assert.h>
+#include <stdint.h>
+#include <string.h>
+
+namespace lm {
+namespace builder {
+
+struct Uninterpolated {
+ float prob; // Uninterpolated probability.
+ float gamma; // Interpolation weight for lower order.
+};
+
+union Payload {
+ uint64_t count;
+ Uninterpolated uninterp;
+ ProbBackoff complete;
+};
+
+class NGram {
+ public:
+ NGram(void *begin, std::size_t order)
+ : begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {}
+
+ const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); }
+ uint8_t *Base() { return reinterpret_cast<uint8_t*>(begin_); }
+
+ void ReBase(void *to) {
+ std::size_t difference = end_ - begin_;
+ begin_ = reinterpret_cast<WordIndex*>(to);
+ end_ = begin_ + difference;
+ }
+
+ // Would do operator++ but that can get confusing for a stream.
+ void NextInMemory() {
+ ReBase(&Value() + 1);
+ }
+
+ // Lower-case in deference to STL.
+ const WordIndex *begin() const { return begin_; }
+ WordIndex *begin() { return begin_; }
+ const WordIndex *end() const { return end_; }
+ WordIndex *end() { return end_; }
+
+ const Payload &Value() const { return *reinterpret_cast<const Payload *>(end_); }
+ Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
+
+ uint64_t &Count() { return Value().count; }
+ const uint64_t Count() const { return Value().count; }
+
+ std::size_t Order() const { return end_ - begin_; }
+
+ static std::size_t TotalSize(std::size_t order) {
+ return order * sizeof(WordIndex) + sizeof(Payload);
+ }
+ std::size_t TotalSize() const {
+ // Compiler should optimize this.
+ return TotalSize(Order());
+ }
+ static std::size_t OrderFromSize(std::size_t size) {
+ std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex);
+ assert(size == TotalSize(ret));
+ return ret;
+ }
+
+ private:
+ WordIndex *begin_, *end_;
+};
+
+const WordIndex kUNK = 0;
+const WordIndex kBOS = 1;
+const WordIndex kEOS = 2;
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_NGRAM__
diff --git a/klm/lm/builder/ngram_stream.hh b/klm/lm/builder/ngram_stream.hh
new file mode 100644
index 00000000..3c994664
--- /dev/null
+++ b/klm/lm/builder/ngram_stream.hh
@@ -0,0 +1,55 @@
+#ifndef LM_BUILDER_NGRAM_STREAM__
+#define LM_BUILDER_NGRAM_STREAM__
+
+#include "lm/builder/ngram.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/stream.hh"
+
+#include <cstddef>
+
+namespace lm { namespace builder {
+
+class NGramStream {
+ public:
+ NGramStream() : gram_(NULL, 0) {}
+
+ NGramStream(const util::stream::ChainPosition &position) : gram_(NULL, 0) {
+ Init(position);
+ }
+
+ void Init(const util::stream::ChainPosition &position) {
+ stream_.Init(position);
+ gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize()));
+ }
+
+ NGram &operator*() { return gram_; }
+ const NGram &operator*() const { return gram_; }
+
+ NGram *operator->() { return &gram_; }
+ const NGram *operator->() const { return &gram_; }
+
+ void *Get() { return stream_.Get(); }
+ const void *Get() const { return stream_.Get(); }
+
+ operator bool() const { return stream_; }
+ bool operator!() const { return !stream_; }
+ void Poison() { stream_.Poison(); }
+
+ NGramStream &operator++() {
+ ++stream_;
+ gram_.ReBase(stream_.Get());
+ return *this;
+ }
+
+ private:
+ NGram gram_;
+ util::stream::Stream stream_;
+};
+
+inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) {
+ str.Init(chain.Add());
+ return chain;
+}
+
+}} // namespaces
+#endif // LM_BUILDER_NGRAM_STREAM__
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc
new file mode 100644
index 00000000..14a1f721
--- /dev/null
+++ b/klm/lm/builder/pipeline.cc
@@ -0,0 +1,320 @@
+#include "lm/builder/pipeline.hh"
+
+#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/corpus_count.hh"
+#include "lm/builder/initial_probabilities.hh"
+#include "lm/builder/interpolate.hh"
+#include "lm/builder/print.hh"
+#include "lm/builder/sort.hh"
+
+#include "lm/sizes.hh"
+
+#include "util/exception.hh"
+#include "util/file.hh"
+#include "util/stream/io.hh"
+
+#include <algorithm>
+#include <iostream>
+#include <vector>
+
+namespace lm { namespace builder {
+
+namespace {
+void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) {
+ std::cerr << "Statistics:\n";
+ for (size_t i = 0; i < counts.size(); ++i) {
+ std::cerr << (i + 1) << ' ' << counts[i];
+ for (size_t d = 1; d <= 3; ++d)
+ std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d];
+ std::cerr << '\n';
+ }
+}
+
+class Master {
+ public:
+ explicit Master(const PipelineConfig &config)
+ : config_(config), chains_(config.order), files_(config.order) {
+ config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block);
+ }
+
+ const PipelineConfig &Config() const { return config_; }
+
+ Chains &MutableChains() { return chains_; }
+
+ template <class T> Master &operator>>(const T &worker) {
+ chains_ >> worker;
+ return *this;
+ }
+
+ // This takes the (partially) sorted ngrams and sets up for adjusted counts.
+ void InitForAdjust(util::stream::Sort<SuffixOrder, AddCombiner> &ngrams, WordIndex types) {
+ const std::size_t each_order_min = config_.minimum_block * config_.block_count;
+ // We know how many unigrams there are. Don't allocate more than needed to them.
+ const std::size_t min_chains = (config_.order - 1) * each_order_min +
+ std::min(types * NGram::TotalSize(1), each_order_min);
+ // Do merge sort with calculated laziness.
+ const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy()));
+
+ std::vector<uint64_t> count_bounds(1, types);
+ CreateChains(config_.TotalMemory() - merge_using, count_bounds);
+ ngrams.Output(chains_.back(), merge_using);
+
+ // Setup unigram file.
+ files_.push_back(util::MakeTemp(config_.TempPrefix()));
+ }
+
+ // For initial probabilities, but this is generic.
+ void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) {
+ // Do merge first before allocating chain memory.
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts[i - 1].Merge(0);
+ }
+ // There's no lazy merge, so just divide memory amongst the chains.
+ CreateChains(config_.TotalMemory(), counts);
+ chains_.back().ActivateProgress();
+ chains_[0] >> files_[0].Source();
+ second_config.entry_size = NGram::TotalSize(1);
+ second.push_back(second_config);
+ second.back() >> files_[0].Source();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ util::scoped_fd fd(sorts[i - 1].StealCompleted());
+ chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get()));
+ chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true);
+ second_config.entry_size = NGram::TotalSize(i + 1);
+ second.push_back(second_config);
+ second.back() >> util::stream::PRead(fd.release(), true);
+ }
+ }
+
+ // There is no sort after this, so go for broke on lazy merging.
+ template <class Compare> void MaximumLazyInput(const std::vector<uint64_t> &counts, Sorts<Compare> &sorts) {
+ // Determine the minimum we can use for all the chains.
+ std::size_t min_chains = 0;
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block));
+ }
+ std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains);
+ std::vector<std::size_t> laziness;
+ // Prioritize longer n-grams.
+ for (util::stream::Sort<SuffixOrder> *i = sorts.end() - 1; i >= sorts.begin(); --i) {
+ laziness.push_back(i->Merge(for_merge));
+ assert(for_merge >= laziness.back());
+ for_merge -= laziness.back();
+ }
+ std::reverse(laziness.begin(), laziness.end());
+
+ CreateChains(for_merge + min_chains, counts);
+ chains_.back().ActivateProgress();
+ chains_[0] >> files_[0].Source();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts[i - 1].Output(chains_[i], laziness[i - 1]);
+ }
+ }
+
+ void BufferFinal(const std::vector<uint64_t> &counts) {
+ chains_[0] >> files_[0].Sink();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ files_.push_back(util::MakeTemp(config_.TempPrefix()));
+ chains_[i] >> files_[i].Sink();
+ }
+ chains_.Wait(true);
+ // Use less memory. Because we can.
+ CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts);
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ chains_[i] >> files_[i].Source();
+ }
+ }
+
+ template <class Compare> void SetupSorts(Sorts<Compare> &sorts) {
+ sorts.Init(config_.order - 1);
+ // Unigrams don't get sorted because their order is always the same.
+ chains_[0] >> files_[0].Sink();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts.push_back(chains_[i], config_.sort, Compare(i + 1));
+ }
+ chains_.Wait(true);
+ }
+
+ private:
+ // Create chains, allocating memory to them. Totally heuristic. Count
+ // bounds are upper bounds on the counts or not present.
+ void CreateChains(std::size_t remaining_mem, const std::vector<uint64_t> &count_bounds) {
+ std::vector<std::size_t> assignments;
+ assignments.reserve(config_.order);
+ // Start by assigning maximum memory usage (to be refined later).
+ for (std::size_t i = 0; i < count_bounds.size(); ++i) {
+ assignments.push_back(static_cast<std::size_t>(std::min(
+ static_cast<uint64_t>(remaining_mem),
+ count_bounds[i] * static_cast<uint64_t>(NGram::TotalSize(i + 1)))));
+ }
+ assignments.resize(config_.order, remaining_mem);
+
+ // Now we know how much memory everybody wants. How much will they get?
+ // Proportional to this.
+ std::vector<float> portions;
+ // Indices of orders that have yet to be assigned.
+ std::vector<std::size_t> unassigned;
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ portions.push_back(static_cast<float>((i+1) * NGram::TotalSize(i+1)));
+ unassigned.push_back(i);
+ }
+ /*If somebody doesn't eat their full dinner, give it to the rest of the
+ * family. Then somebody else might not eat their full dinner etc. Ends
+ * when everybody unassigned is hungry.
+ */
+ float sum;
+ bool found_more;
+ std::vector<std::size_t> block_count(config_.order);
+ do {
+ sum = 0.0;
+ for (std::size_t i = 0; i < unassigned.size(); ++i) {
+ sum += portions[unassigned[i]];
+ }
+ found_more = false;
+ // If the proportional assignment is more than needed, give it just what it needs.
+ for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end();) {
+ if (assignments[*i] <= remaining_mem * (portions[*i] / sum)) {
+ remaining_mem -= assignments[*i];
+ block_count[*i] = 1;
+ i = unassigned.erase(i);
+ found_more = true;
+ } else {
+ ++i;
+ }
+ }
+ } while (found_more);
+ for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end(); ++i) {
+ assignments[*i] = remaining_mem * (portions[*i] / sum);
+ block_count[*i] = config_.block_count;
+ }
+ chains_.clear();
+ std::cerr << "Chain sizes:";
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ std::cerr << ' ' << (i+1) << ":" << assignments[i];
+ chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i]));
+ }
+ std::cerr << std::endl;
+ }
+
+ PipelineConfig config_;
+
+ Chains chains_;
+ // Often only unigrams, but sometimes all orders.
+ FixedArray<util::stream::FileBuffer> files_;
+};
+
+void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) {
+ const PipelineConfig &config = master.Config();
+ std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
+
+ UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory());
+ std::size_t memory_for_chain =
+ // This much memory to work with after vocab hash table.
+ static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) /
+ // Solve for block size including the dedupe multiplier for one block.
+ (static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
+ // Chain likes memory expressed in terms of total memory.
+ static_cast<float>(config.block_count);
+ util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
+
+ WordIndex type_count;
+ util::FilePiece text(text_file, NULL, &std::cerr);
+ text_file_name = text.FileName();
+ CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ chain >> boost::ref(counter);
+
+ util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
+ chain.Wait(true);
+ std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
+ master.InitForAdjust(sorter, type_count);
+}
+
+void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+ const PipelineConfig &config = master.Config();
+ Chains second(config.order);
+
+ {
+ Sorts<ContextOrder> sorts;
+ master.SetupSorts(sorts);
+ PrintStatistics(counts, discounts);
+ lm::ngram::ShowSizes(counts);
+ std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl;
+ master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in);
+ }
+
+ Chains gamma_chains(config.order);
+ InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains);
+ // Don't care about gamma for 0.
+ gamma_chains[0] >> util::stream::kRecycle;
+ gammas.Init(config.order - 1);
+ for (std::size_t i = 1; i < config.order; ++i) {
+ gammas.push_back(util::MakeTemp(config.TempPrefix()));
+ gamma_chains[i] >> gammas[i - 1].Sink();
+ }
+ // Has to be done here due to gamma_chains scope.
+ master.SetupSorts(primary);
+}
+
+void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+ std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
+ const PipelineConfig &config = master.Config();
+ master.MaximumLazyInput(counts, primary);
+
+ Chains gamma_chains(config.order - 1);
+ util::stream::ChainConfig read_backoffs(config.read_backoffs);
+ read_backoffs.entry_size = sizeof(float);
+ for (std::size_t i = 0; i < config.order - 1; ++i) {
+ gamma_chains.push_back(read_backoffs);
+ gamma_chains.back() >> gammas[i].Source();
+ }
+ master >> Interpolate(counts[0], ChainPositions(gamma_chains));
+ gamma_chains >> util::stream::kRecycle;
+ master.BufferFinal(counts);
+}
+
+} // namespace
+
+void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
+ // Some fail-fast sanity checks.
+ if (config.sort.buffer_size * 4 > config.TotalMemory()) {
+ config.sort.buffer_size = config.TotalMemory() / 4;
+ std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl;
+ }
+ if (config.minimum_block < NGram::TotalSize(config.order)) {
+ config.minimum_block = NGram::TotalSize(config.order);
+ std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl;
+ }
+ UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << ".");
+ UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception,
+ "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size.");
+
+ UTIL_TIMER("(%w s) Total wall time elapsed\n");
+ Master master(config);
+
+ util::scoped_fd vocab_file(config.vocab_file.empty() ?
+ util::MakeTemp(config.TempPrefix()) :
+ util::CreateOrThrow(config.vocab_file.c_str()));
+ uint64_t token_count;
+ std::string text_file_name;
+ CountText(text_file, vocab_file.get(), master, token_count, text_file_name);
+
+ std::vector<uint64_t> counts;
+ std::vector<Discount> discounts;
+ master >> AdjustCounts(counts, discounts);
+
+ {
+ FixedArray<util::stream::FileBuffer> gammas;
+ Sorts<SuffixOrder> primary;
+ InitialProbabilities(counts, discounts, master, primary, gammas);
+ InterpolateProbabilities(counts, master, primary, gammas);
+ }
+
+ std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
+ VocabReconstitute vocab(vocab_file.get());
+ UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?");
+ HeaderInfo header_info(text_file_name, token_count);
+ master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle;
+ master.MutableChains().Wait(true);
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh
new file mode 100644
index 00000000..f1d6c5f6
--- /dev/null
+++ b/klm/lm/builder/pipeline.hh
@@ -0,0 +1,40 @@
+#ifndef LM_BUILDER_PIPELINE__
+#define LM_BUILDER_PIPELINE__
+
+#include "lm/builder/initial_probabilities.hh"
+#include "lm/builder/header_info.hh"
+#include "util/stream/config.hh"
+#include "util/file_piece.hh"
+
+#include <string>
+#include <cstddef>
+
+namespace lm { namespace builder {
+
+struct PipelineConfig {
+ std::size_t order;
+ std::string vocab_file;
+ util::stream::SortConfig sort;
+ InitialProbabilitiesConfig initial_probs;
+ util::stream::ChainConfig read_backoffs;
+ bool verbose_header;
+
+ // Amount of memory to assume that the vocabulary hash table will use. This
+ // is subtracted from total memory for CorpusCount.
+ std::size_t assume_vocab_hash_size;
+
+ // Minimum block size to tolerate.
+ std::size_t minimum_block;
+
+ // Number of blocks to use. This will be overridden to 1 if everything fits.
+ std::size_t block_count;
+
+ const std::string &TempPrefix() const { return sort.temp_prefix; }
+ std::size_t TotalMemory() const { return sort.total_memory; }
+};
+
+// Takes ownership of text_file.
+void Pipeline(PipelineConfig config, int text_file, int out_arpa);
+
+}} // namespaces
+#endif // LM_BUILDER_PIPELINE__
diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc
new file mode 100644
index 00000000..b0323221
--- /dev/null
+++ b/klm/lm/builder/print.cc
@@ -0,0 +1,135 @@
+#include "lm/builder/print.hh"
+
+#include "util/double-conversion/double-conversion.h"
+#include "util/double-conversion/utils.h"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/scoped.hh"
+#include "util/stream/timer.hh"
+
+#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
+#include <boost/lexical_cast.hpp>
+
+#include <sstream>
+
+#include <string.h>
+
+namespace lm { namespace builder {
+
+VocabReconstitute::VocabReconstitute(int fd) {
+ uint64_t size = util::SizeOrThrow(fd);
+ util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_);
+ const char *const start = static_cast<const char*>(memory_.get());
+ const char *i;
+ for (i = start; i != start + size; i += strlen(i) + 1) {
+ map_.push_back(i);
+ }
+ // Last one for LookupPiece.
+ map_.push_back(i);
+}
+
+namespace {
+class OutputManager {
+ public:
+ static const std::size_t kOutBuf = 1048576;
+
+ // Does not take ownership of out.
+ explicit OutputManager(int out)
+ : buf_(util::MallocOrThrow(kOutBuf)),
+ builder_(static_cast<char*>(buf_.get()), kOutBuf),
+ // Mostly the default but with inf instead. And no flags.
+ convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
+ fd_(out) {}
+
+ ~OutputManager() {
+ Flush();
+ }
+
+ OutputManager &operator<<(float value) {
+ // Odd, but this is the largest number found in the comments.
+ EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
+ convert_.ToShortestSingle(value, &builder_);
+ return *this;
+ }
+
+ OutputManager &operator<<(StringPiece str) {
+ if (str.size() > kOutBuf) {
+ Flush();
+ util::WriteOrThrow(fd_, str.data(), str.size());
+ } else {
+ EnsureRemaining(str.size());
+ builder_.AddSubstring(str.data(), str.size());
+ }
+ return *this;
+ }
+
+ // Inefficient!
+ OutputManager &operator<<(unsigned val) {
+ return *this << boost::lexical_cast<std::string>(val);
+ }
+
+ OutputManager &operator<<(char c) {
+ EnsureRemaining(1);
+ builder_.AddCharacter(c);
+ return *this;
+ }
+
+ void Flush() {
+ util::WriteOrThrow(fd_, buf_.get(), builder_.position());
+ builder_.Reset();
+ }
+
+ private:
+ void EnsureRemaining(std::size_t amount) {
+ if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) {
+ Flush();
+ }
+ }
+
+ util::scoped_malloc buf_;
+ double_conversion::StringBuilder builder_;
+ double_conversion::DoubleToStringConverter convert_;
+ int fd_;
+};
+} // namespace
+
+PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
+ : vocab_(vocab), out_fd_(out_fd) {
+ std::stringstream stream;
+
+ if (header_info) {
+ stream << "# Input file: " << header_info->input_file << '\n';
+ stream << "# Token count: " << header_info->token_count << '\n';
+ stream << "# Smoothing: Modified Kneser-Ney" << '\n';
+ }
+ stream << "\\data\\\n";
+ for (size_t i = 0; i < counts.size(); ++i) {
+ stream << "ngram " << (i+1) << '=' << counts[i] << '\n';
+ }
+ stream << '\n';
+ std::string as_string(stream.str());
+ util::WriteOrThrow(out_fd, as_string.data(), as_string.size());
+}
+
+void PrintARPA::Run(const ChainPositions &positions) {
+ UTIL_TIMER("(%w s) Wrote ARPA file\n");
+ OutputManager out(out_fd_);
+ for (unsigned order = 1; order <= positions.size(); ++order) {
+ out << "\\" << order << "-grams:" << '\n';
+ for (NGramStream stream(positions[order - 1]); stream; ++stream) {
+ // Correcting for numerical precision issues. Take that IRST.
+ out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin());
+ for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
+ out << ' ' << vocab_.Lookup(*i);
+ }
+ float backoff = stream->Value().complete.backoff;
+ if (backoff != 0.0)
+ out << '\t' << backoff;
+ out << '\n';
+ }
+ out << '\n';
+ }
+ out << "\\end\\\n";
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh
new file mode 100644
index 00000000..aa932e75
--- /dev/null
+++ b/klm/lm/builder/print.hh
@@ -0,0 +1,102 @@
+#ifndef LM_BUILDER_PRINT__
+#define LM_BUILDER_PRINT__
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/header_info.hh"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/string_piece.hh"
+
+#include <ostream>
+
+#include <assert.h>
+
+// Warning: print routines read all unigrams before all bigrams before all
+// trigrams etc. So if other parts of the chain move jointly, you'll have to
+// buffer.
+
+namespace lm { namespace builder {
+
+class VocabReconstitute {
+ public:
+ // fd must be alive for life of this object; does not take ownership.
+ explicit VocabReconstitute(int fd);
+
+ const char *Lookup(WordIndex index) const {
+ assert(index < map_.size() - 1);
+ return map_[index];
+ }
+
+ StringPiece LookupPiece(WordIndex index) const {
+ return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]);
+ }
+
+ std::size_t Size() const {
+ // There's an extra entry to support StringPiece lengths.
+ return map_.size() - 1;
+ }
+
+ private:
+ util::scoped_memory memory_;
+ std::vector<const char*> map_;
+};
+
+// Not defined, only specialized.
+template <class T> void PrintPayload(std::ostream &to, const Payload &payload);
+template <> inline void PrintPayload<uint64_t>(std::ostream &to, const Payload &payload) {
+ to << payload.count;
+}
+template <> inline void PrintPayload<Uninterpolated>(std::ostream &to, const Payload &payload) {
+ to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma);
+}
+template <> inline void PrintPayload<ProbBackoff>(std::ostream &to, const Payload &payload) {
+ to << payload.complete.prob << ' ' << payload.complete.backoff;
+}
+
+// template parameter is the type stored.
+template <class V> class Print {
+ public:
+ explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {}
+
+ void Run(const ChainPositions &chains) {
+ NGramStreams streams(chains);
+ for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
+ DumpStream(*s);
+ }
+ }
+
+ void Run(const util::stream::ChainPosition &position) {
+ NGramStream stream(position);
+ DumpStream(stream);
+ }
+
+ private:
+ void DumpStream(NGramStream &stream) {
+ for (; stream; ++stream) {
+ PrintPayload<V>(to_, stream->Value());
+ for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) {
+ to_ << ' ' << vocab_.Lookup(*w) << '=' << *w;
+ }
+ to_ << '\n';
+ }
+ }
+
+ const VocabReconstitute &vocab_;
+ std::ostream &to_;
+};
+
+class PrintARPA {
+ public:
+ // header_info may be NULL to disable the header
+ explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ const VocabReconstitute &vocab_;
+ int out_fd_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_PRINT__
diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh
new file mode 100644
index 00000000..9989389b
--- /dev/null
+++ b/klm/lm/builder/sort.hh
@@ -0,0 +1,103 @@
+#ifndef LM_BUILDER_SORT__
+#define LM_BUILDER_SORT__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/ngram.hh"
+#include "lm/word_index.hh"
+#include "util/stream/sort.hh"
+
+#include "util/stream/timer.hh"
+
+#include <functional>
+#include <string>
+
+namespace lm {
+namespace builder {
+
+template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> {
+ public:
+ explicit Comparator(std::size_t order) : order_(order) {}
+
+ inline bool operator()(const void *lhs, const void *rhs) const {
+ return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs));
+ }
+
+ std::size_t Order() const { return order_; }
+
+ protected:
+ std::size_t order_;
+};
+
+class SuffixOrder : public Comparator<SuffixOrder> {
+ public:
+ explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = order_ - 1; i != 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[0] < rhs[0];
+ }
+
+ static const unsigned kMatchOffset = 1;
+};
+
+class ContextOrder : public Comparator<ContextOrder> {
+ public:
+ explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (int i = order_ - 2; i >= 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[order_ - 1] < rhs[order_ - 1];
+ }
+};
+
+class PrefixOrder : public Comparator<PrefixOrder> {
+ public:
+ explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = 0; i < order_; ++i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return false;
+ }
+
+ static const unsigned kMatchOffset = 0;
+};
+
+// Sum counts for the same n-gram.
+struct AddCombiner {
+ bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
+ NGram first(first_void, compare.Order());
+ // There isn't a const version of NGram.
+ NGram second(const_cast<void*>(second_void), compare.Order());
+ if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
+ first.Count() += second.Count();
+ return true;
+ }
+};
+
+// The combiner is only used on a single chain, so I didn't bother to allow
+// that template.
+template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > {
+ private:
+ typedef util::stream::Sort<Compare> S;
+ typedef FixedArray<S> P;
+
+ public:
+ void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
+ new (P::end()) S(chain, config, compare);
+ P::Constructed();
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_SORT__