summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2013-01-18 17:12:51 +0000
committerKenneth Heafield <github@kheafield.com>2013-01-18 17:12:51 +0000
commit0b9031042500d45a098762f0a930bd6a66a58fac (patch)
tree38903f3e29225aa8d444ee66b6963c7148050fee /klm/lm
parent9d7167751a3712a79ad356764d803106a71ce5e3 (diff)
KenLM dffafbf with lmplz source (but not built)
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/Makefile.am4
-rw-r--r--klm/lm/build_binary.cc77
-rw-r--r--klm/lm/builder/README.md47
-rw-r--r--klm/lm/builder/TODO5
-rw-r--r--klm/lm/builder/adjust_counts.cc216
-rw-r--r--klm/lm/builder/adjust_counts.hh44
-rw-r--r--klm/lm/builder/adjust_counts_test.cc106
-rw-r--r--klm/lm/builder/corpus_count.cc223
-rw-r--r--klm/lm/builder/corpus_count.hh42
-rw-r--r--klm/lm/builder/corpus_count_test.cc76
-rw-r--r--klm/lm/builder/discount.hh26
-rw-r--r--klm/lm/builder/header_info.hh20
-rw-r--r--klm/lm/builder/initial_probabilities.cc136
-rw-r--r--klm/lm/builder/initial_probabilities.hh34
-rw-r--r--klm/lm/builder/interpolate.cc65
-rw-r--r--klm/lm/builder/interpolate.hh27
-rw-r--r--klm/lm/builder/joint_order.hh43
-rw-r--r--klm/lm/builder/main.cc94
-rw-r--r--klm/lm/builder/multi_stream.hh180
-rw-r--r--klm/lm/builder/ngram.hh84
-rw-r--r--klm/lm/builder/ngram_stream.hh55
-rw-r--r--klm/lm/builder/pipeline.cc320
-rw-r--r--klm/lm/builder/pipeline.hh40
-rw-r--r--klm/lm/builder/print.cc135
-rw-r--r--klm/lm/builder/print.hh102
-rw-r--r--klm/lm/builder/sort.hh103
-rw-r--r--klm/lm/filter/arpa_io.cc122
-rw-r--r--klm/lm/filter/arpa_io.hh122
-rw-r--r--klm/lm/filter/count_io.hh91
-rw-r--r--klm/lm/filter/format.hh250
-rw-r--r--klm/lm/filter/main.cc249
-rw-r--r--klm/lm/filter/phrase.cc281
-rw-r--r--klm/lm/filter/phrase.hh153
-rw-r--r--klm/lm/filter/thread.hh167
-rw-r--r--klm/lm/filter/vocab.cc54
-rw-r--r--klm/lm/filter/vocab.hh132
-rw-r--r--klm/lm/filter/wrapper.hh58
-rw-r--r--klm/lm/model_test.cc10
-rw-r--r--klm/lm/read_arpa.cc11
-rw-r--r--klm/lm/sizes.cc63
-rw-r--r--klm/lm/sizes.hh17
-rw-r--r--klm/lm/state.hh6
-rw-r--r--klm/lm/trie_sort.cc27
-rw-r--r--klm/lm/trie_sort.hh3
44 files changed, 4043 insertions, 77 deletions
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am
index 870f7128..f15cbd77 100644
--- a/klm/lm/Makefile.am
+++ b/klm/lm/Makefile.am
@@ -1,7 +1,7 @@
bin_PROGRAMS = build_binary
build_binary_SOURCES = build_binary.cc
-build_binary_LDADD = libklm.a ../util/libklm_util.a -lz
+build_binary_LDADD = libklm.a ../util/libklm_util.a ../util/double-conversion/libklm_util_double.a -lz
#noinst_PROGRAMS = \
# ngram_test
@@ -30,6 +30,7 @@ libklm_a_SOURCES = \
return.hh \
search_hashed.hh \
search_trie.hh \
+ sizes.hh \
state.hh \
trie.hh \
trie_sort.hh \
@@ -49,6 +50,7 @@ libklm_a_SOURCES = \
read_arpa.cc \
search_hashed.cc \
search_trie.cc \
+ sizes.cc \
trie.cc \
trie_sort.cc \
value_build.cc \
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index 2b8c9d5b..ab2c0c32 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -1,10 +1,14 @@
#include "lm/model.hh"
+#include "lm/sizes.hh"
#include "util/file_piece.hh"
+#include "util/usage.hh"
+#include <algorithm>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <iomanip>
+#include <limits>
#include <math.h>
#include <stdlib.h>
@@ -19,8 +23,8 @@ namespace lm {
namespace ngram {
namespace {
-void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
+void Usage(const char *name, const char *default_mem) {
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
@@ -38,8 +42,11 @@ void Usage(const char *name) {
"trie is a straightforward trie with bit-level packing. It uses the least\n"
"memory and is still faster than SRI or IRST. Building the trie format uses an\n"
"on-disk sort to save memory.\n"
-"-t is the temporary directory prefix. Default is the output file name.\n"
-"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n"
+"-T is the temporary directory prefix. Default is the output file name.\n"
+"-S determines memory use for sorting. Default is " << default_mem << ". This is compatible\n"
+" with GNU sort. The number is followed by a unit: \% for percent of physical\n"
+" memory, b for bytes, K for Kilobytes, M for megabytes, then G,T,P,E,Z,Y. \n"
+" Default unit is K for Kilobytes.\n"
"-q turns quantization on and sets the number of bits (e.g. -q 8).\n"
"-b sets backoff quantization bits. Requires -q and defaults to that value.\n"
"-a compresses pointers using an array of offsets. The parameter is the\n"
@@ -83,47 +90,6 @@ void ParseFileList(const char *from, std::vector<std::string> &to) {
}
}
-void ShowSizes(const char *file, const lm::ngram::Config &config) {
- std::vector<uint64_t> counts;
- util::FilePiece f(file);
- lm::ReadARPACounts(f, counts);
- uint64_t sizes[6];
- sizes[0] = ProbingModel::Size(counts, config);
- sizes[1] = RestProbingModel::Size(counts, config);
- sizes[2] = TrieModel::Size(counts, config);
- sizes[3] = QuantTrieModel::Size(counts, config);
- sizes[4] = ArrayTrieModel::Size(counts, config);
- sizes[5] = QuantArrayTrieModel::Size(counts, config);
- uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
- uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
- uint64_t divide;
- char prefix;
- if (min_length < (1 << 10) * 10) {
- prefix = ' ';
- divide = 1;
- } else if (min_length < (1 << 20) * 10) {
- prefix = 'k';
- divide = 1 << 10;
- } else if (min_length < (1ULL << 30) * 10) {
- prefix = 'M';
- divide = 1 << 20;
- } else {
- prefix = 'G';
- divide = 1 << 30;
- }
- long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide))));
- std::cout << "Memory estimate:\ntype ";
- // right align bytes.
- for (long int i = 0; i < length - 2; ++i) std::cout << ' ';
- std::cout << prefix << "B\n"
- "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
- "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n"
- "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
- "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
- "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
- "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
-}
-
void ProbingQuantizationUnsupported() {
std::cerr << "Quantization is only implemented in the trie data structure." << std::endl;
exit(1);
@@ -136,11 +102,14 @@ void ProbingQuantizationUnsupported() {
int main(int argc, char *argv[]) {
using namespace lm::ngram;
+ const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G";
+
try {
bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;
lm::ngram::Config config;
+ config.building_memory = util::ParseSize(default_mem);
int opt;
- while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:sir:")) != -1) {
+ while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -161,12 +130,16 @@ int main(int argc, char *argv[]) {
case 'p':
config.probing_multiplier = ParseFloat(optarg);
break;
- case 't':
+ case 't': // legacy
+ case 'T':
config.temporary_directory_prefix = optarg;
break;
- case 'm':
+ case 'm': // legacy
config.building_memory = ParseUInt(optarg) * 1048576;
break;
+ case 'S':
+ config.building_memory = std::min(static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), util::ParseSize(optarg));
+ break;
case 'w':
set_write_method = true;
if (!strcmp(optarg, "mmap")) {
@@ -174,7 +147,7 @@ int main(int argc, char *argv[]) {
} else if (!strcmp(optarg, "after")) {
config.write_method = Config::WRITE_AFTER;
} else {
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
break;
case 's':
@@ -189,7 +162,7 @@ int main(int argc, char *argv[]) {
config.rest_function = Config::REST_LOWER;
break;
default:
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
}
if (!quantize && set_backoff_bits) {
@@ -212,7 +185,7 @@ int main(int argc, char *argv[]) {
from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
} else {
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
if (!strcmp(model_type, "probing")) {
if (!set_write_method) config.write_method = Config::WRITE_AFTER;
@@ -242,7 +215,7 @@ int main(int argc, char *argv[]) {
}
}
} else {
- Usage(argv[0]);
+ Usage(argv[0], default_mem);
}
}
catch (const std::exception &e) {
diff --git a/klm/lm/builder/README.md b/klm/lm/builder/README.md
new file mode 100644
index 00000000..be0d35e2
--- /dev/null
+++ b/klm/lm/builder/README.md
@@ -0,0 +1,47 @@
+Dependencies
+============
+
+Boost >= 1.42.0 is required.
+
+For Ubuntu,
+```bash
+sudo apt-get install libboost1.48-all-dev
+```
+
+Alternatively, you can download, compile, and install it yourself:
+
+```bash
+wget http://sourceforge.net/projects/boost/files/boost/1.52.0/boost_1_52_0.tar.gz/download -O boost_1_52_0.tar.gz
+tar -xvzf boost_1_52_0.tar.gz
+cd boost_1_52_0
+./bootstrap.sh
+./b2
+sudo ./b2 install
+```
+
+Local install options (in a user-space prefix directory) are also possible. See http://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html.
+
+
+Building
+========
+
+```bash
+bjam
+```
+Your distribution might package bjam and boost-build separately from Boost. Both are required.
+
+Usage
+=====
+
+Run
+```bash
+$ bin/lmplz
+```
+to see command line arguments
+
+Running
+=======
+
+```bash
+bin/lmplz -o 5 <text >text.arpa
+```
diff --git a/klm/lm/builder/TODO b/klm/lm/builder/TODO
new file mode 100644
index 00000000..cb5aef3a
--- /dev/null
+++ b/klm/lm/builder/TODO
@@ -0,0 +1,5 @@
+More tests!
+Sharding.
+Some way to manage all the crazy config options.
+Option to build the binary file directly.
+Interpolation of different orders.
diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc
new file mode 100644
index 00000000..a6f48011
--- /dev/null
+++ b/klm/lm/builder/adjust_counts.cc
@@ -0,0 +1,216 @@
+#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/multi_stream.hh"
+#include "util/stream/timer.hh"
+
+#include <algorithm>
+
+namespace lm { namespace builder {
+
+BadDiscountException::BadDiscountException() throw() {}
+BadDiscountException::~BadDiscountException() throw() {}
+
+namespace {
+// Return last word in full that is different.
+const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
+ const WordIndex *cur_word = full.end() - 1;
+ const WordIndex *pre_word = lower_last.end() - 1;
+ // Find last difference.
+ for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
+ return cur_word;
+}
+
+class StatCollector {
+ public:
+ StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+ : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) {
+ memset(&orders_[0], 0, sizeof(OrderStat) * order);
+ }
+
+ ~StatCollector() {}
+
+ void CalculateDiscounts() {
+ counts_.resize(orders_.size());
+ discounts_.resize(orders_.size());
+ for (std::size_t i = 0; i < orders_.size(); ++i) {
+ const OrderStat &s = orders_[i];
+ counts_[i] = s.count;
+
+ for (unsigned j = 1; j < 4; ++j) {
+ // TODO: Specialize error message for j == 3, meaning 3+
+ UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
+ << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
+ << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
+ }
+
+ // See equation (26) in Chen and Goodman.
+ discounts_[i].amount[0] = 0.0;
+ float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
+ for (unsigned j = 1; j < 4; ++j) {
+ discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
+ UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+ }
+ }
+ }
+
+ void Add(std::size_t order_minus_1, uint64_t count) {
+ OrderStat &stat = orders_[order_minus_1];
+ ++stat.count;
+ if (count < 5) ++stat.n[count];
+ }
+
+ void AddFull(uint64_t count) {
+ ++full_.count;
+ if (count < 5) ++full_.n[count];
+ }
+
+ private:
+ struct OrderStat {
+ // n_1 in equation 26 of Chen and Goodman etc
+ uint64_t n[5];
+ uint64_t count;
+ };
+
+ std::vector<OrderStat> orders_;
+ OrderStat &full_;
+
+ std::vector<uint64_t> &counts_;
+ std::vector<Discount> &discounts_;
+};
+
+// Reads all entries in order like NGramStream does.
+// But deletes any entries that have <s> in the 1st (not 0th) position on the
+// way out by putting other entries in their place. This disrupts the sort
+// order but we don't care because the data is going to be sorted again.
+class CollapseStream {
+ public:
+ CollapseStream(const util::stream::ChainPosition &position) :
+ current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
+ block_(position) {
+ StartBlock();
+ }
+
+ const NGram &operator*() const { return current_; }
+ const NGram *operator->() const { return &current_; }
+
+ operator bool() const { return block_; }
+
+ CollapseStream &operator++() {
+ assert(block_);
+ if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
+ memcpy(current_.Base(), copy_from_, current_.TotalSize());
+ UpdateCopyFrom();
+ }
+ current_.NextInMemory();
+ uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
+ if (current_.Base() == block_base + block_->ValidSize()) {
+ block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base);
+ ++block_;
+ StartBlock();
+ }
+ return *this;
+ }
+
+ private:
+ void StartBlock() {
+ for (; ; ++block_) {
+ if (!block_) return;
+ if (block_->ValidSize()) break;
+ }
+ current_.ReBase(block_->Get());
+ copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
+ UpdateCopyFrom();
+ }
+
+ // Find last without bos.
+ void UpdateCopyFrom() {
+ for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
+ if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
+ }
+ }
+
+ NGram current_;
+
+ // Goes backwards in the block
+ uint8_t *copy_from_;
+
+ util::stream::Link block_;
+};
+
+} // namespace
+
+void AdjustCounts::Run(const ChainPositions &positions) {
+ UTIL_TIMER("(%w s) Adjusted counts\n");
+
+ const std::size_t order = positions.size();
+ StatCollector stats(order, counts_, discounts_);
+ if (order == 1) {
+ // Only unigrams. Just collect stats.
+ for (NGramStream full(positions[0]); full; ++full)
+ stats.AddFull(full->Count());
+ stats.CalculateDiscounts();
+ return;
+ }
+
+ NGramStreams streams;
+ streams.Init(positions, positions.size() - 1);
+ CollapseStream full(positions[positions.size() - 1]);
+
+ // Initialization: <unk> has count 0 and so does <s>.
+ NGramStream *lower_valid = streams.begin();
+ streams[0]->Count() = 0;
+ *streams[0]->begin() = kUNK;
+ stats.Add(0, 0);
+ (++streams[0])->Count() = 0;
+ *streams[0]->begin() = kBOS;
+ // not in stats because it will get put in later.
+
+ // iterate over full (the stream of the highest order ngrams)
+ for (; full; ++full) {
+ const WordIndex *different = FindDifference(*full, **lower_valid);
+ std::size_t same = full->end() - 1 - different;
+ // Increment the adjusted count.
+ if (same) ++streams[same - 1]->Count();
+
+ // Output all the valid ones that changed.
+ for (; lower_valid >= &streams[same]; --lower_valid) {
+ stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count());
+ ++*lower_valid;
+ }
+
+ // This is here because bos is also const WordIndex *, so copy gets
+ // consistent argument types.
+ const WordIndex *full_end = full->end();
+ // Initialize and mark as valid up to bos.
+ const WordIndex *bos;
+ for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
+ ++lower_valid;
+ std::copy(bos, full_end, (*lower_valid)->begin());
+ (*lower_valid)->Count() = 1;
+ }
+ // Now bos indicates where <s> is or is the 0th word of full.
+ if (bos != full->begin()) {
+ // There is an <s> beyond the 0th word.
+ NGramStream &to = *++lower_valid;
+ std::copy(bos, full_end, to->begin());
+ to->Count() = full->Count();
+ } else {
+ stats.AddFull(full->Count());
+ }
+ assert(lower_valid >= &streams[0]);
+ }
+
+ // Output everything valid.
+ for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
+ stats.Add(s - streams.begin(), (*s)->Count());
+ ++*s;
+ }
+ // Poison everyone! Except the N-grams which were already poisoned by the input.
+ for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
+ s->Poison();
+
+ stats.CalculateDiscounts();
+
+ // NOTE: See special early-return case for unigrams near the top of this function
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh
new file mode 100644
index 00000000..f38ff79d
--- /dev/null
+++ b/klm/lm/builder/adjust_counts.hh
@@ -0,0 +1,44 @@
+#ifndef LM_BUILDER_ADJUST_COUNTS__
+#define LM_BUILDER_ADJUST_COUNTS__
+
+#include "lm/builder/discount.hh"
+#include "util/exception.hh"
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+
+class ChainPositions;
+
+class BadDiscountException : public util::Exception {
+ public:
+ BadDiscountException() throw();
+ ~BadDiscountException() throw();
+};
+
+/* Compute adjusted counts.
+ * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.
+ * Output: [1,N]-grams with adjusted counts.
+ * [1,N)-grams are in suffix order
+ * N-grams are in undefined order (they're going to be sorted anyway).
+ */
+class AdjustCounts {
+ public:
+ AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+ : counts_(counts), discounts_(discounts) {}
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ std::vector<uint64_t> &counts_;
+ std::vector<Discount> &discounts_;
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_ADJUST_COUNTS__
+
diff --git a/klm/lm/builder/adjust_counts_test.cc b/klm/lm/builder/adjust_counts_test.cc
new file mode 100644
index 00000000..68b5f33e
--- /dev/null
+++ b/klm/lm/builder/adjust_counts_test.cc
@@ -0,0 +1,106 @@
+#include "lm/builder/adjust_counts.hh"
+
+#include "lm/builder/multi_stream.hh"
+#include "util/scoped.hh"
+
+#include <boost/thread/thread.hpp>
+#define BOOST_TEST_MODULE AdjustCounts
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+class KeepCopy {
+ public:
+ KeepCopy() : size_(0) {}
+
+ void Run(const util::stream::ChainPosition &position) {
+ for (util::stream::Link link(position); link; ++link) {
+ mem_.call_realloc(size_ + link->ValidSize());
+ memcpy(static_cast<uint8_t*>(mem_.get()) + size_, link->Get(), link->ValidSize());
+ size_ += link->ValidSize();
+ }
+ }
+
+ uint8_t *Get() { return static_cast<uint8_t*>(mem_.get()); }
+ std::size_t Size() const { return size_; }
+
+ private:
+ util::scoped_malloc mem_;
+ std::size_t size_;
+};
+
+struct Gram4 {
+ WordIndex ids[4];
+ uint64_t count;
+};
+
+class WriteInput {
+ public:
+ void Run(const util::stream::ChainPosition &position) {
+ NGramStream input(position);
+ Gram4 grams[] = {
+ {{0,0,0,0},10},
+ {{0,0,3,0},3},
+ // bos
+ {{1,1,1,2},5},
+ {{0,0,3,2},5},
+ };
+ for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) {
+ memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4);
+ input->Count() = grams[i].count;
+ }
+ input.Poison();
+ }
+};
+
+BOOST_AUTO_TEST_CASE(Simple) {
+ KeepCopy outputs[4];
+ std::vector<uint64_t> counts;
+ std::vector<Discount> discount;
+ {
+ util::stream::ChainConfig config;
+ config.total_memory = 100;
+ config.block_count = 1;
+ Chains chains(4);
+ for (unsigned i = 0; i < 4; ++i) {
+ config.entry_size = NGram::TotalSize(i + 1);
+ chains.push_back(config);
+ }
+
+ chains[3] >> WriteInput();
+ ChainPositions for_adjust(chains);
+ for (unsigned i = 0; i < 4; ++i) {
+ chains[i] >> boost::ref(outputs[i]);
+ }
+ chains >> util::stream::kRecycle;
+ BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException);
+ }
+ BOOST_REQUIRE_EQUAL(4UL, counts.size());
+ BOOST_CHECK_EQUAL(4UL, counts[0]);
+ // These are no longer set because the discounts are bad.
+/* BOOST_CHECK_EQUAL(4UL, counts[1]);
+ BOOST_CHECK_EQUAL(3UL, counts[2]);
+ BOOST_CHECK_EQUAL(3UL, counts[3]);*/
+ BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size());
+ NGram uni(outputs[0].Get(), 1);
+ BOOST_CHECK_EQUAL(kUNK, *uni.begin());
+ BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(kBOS, *uni.begin());
+ BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(0UL, *uni.begin());
+ BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ uni.NextInMemory();
+ BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ BOOST_CHECK_EQUAL(2UL, *uni.begin());
+
+ BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size());
+ NGram bi(outputs[1].Get(), 2);
+ BOOST_CHECK_EQUAL(0UL, *bi.begin());
+ BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1));
+ BOOST_CHECK_EQUAL(1ULL, bi.Count());
+ bi.NextInMemory();
+}
+
+}}} // namespaces
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
new file mode 100644
index 00000000..8c3de57d
--- /dev/null
+++ b/klm/lm/builder/corpus_count.cc
@@ -0,0 +1,223 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/lm_exception.hh"
+#include "lm/word_index.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/murmur_hash.hh"
+#include "util/probing_hash_table.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/timer.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_set.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <functional>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+namespace {
+
+class VocabHandout {
+ public:
+ explicit VocabHandout(int fd) {
+ util::scoped_fd duped(util::DupOrThrow(fd));
+ word_list_.reset(util::FDOpenOrThrow(duped));
+
+ Lookup("<unk>"); // Force 0
+ Lookup("<s>"); // Force 1
+ Lookup("</s>"); // Force 2
+ }
+
+ WordIndex Lookup(const StringPiece &word) {
+ uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
+ std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
+ if (ret.second) {
+ char null_delimit = 0;
+ util::WriteOrThrow(word_list_.get(), word.data(), word.size());
+ util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
+ UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ }
+ return ret.first->second;
+ }
+
+ WordIndex Size() const {
+ return seen_.size();
+ }
+
+ private:
+ typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+
+ Seen seen_;
+
+ util::scoped_FILE word_list_;
+};
+
+class DedupeHash : public std::unary_function<const WordIndex *, bool> {
+ public:
+ explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+ std::size_t operator()(const WordIndex *start) const {
+ return util::MurmurHashNative(start, size_);
+ }
+
+ private:
+ const std::size_t size_;
+};
+
+class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> {
+ public:
+ explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+ bool operator()(const WordIndex *first, const WordIndex *second) const {
+ return !memcmp(first, second, size_);
+ }
+
+ private:
+ const std::size_t size_;
+};
+
+struct DedupeEntry {
+ typedef WordIndex *Key;
+ Key GetKey() const { return key; }
+ Key key;
+ static DedupeEntry Construct(WordIndex *at) {
+ DedupeEntry ret;
+ ret.key = at;
+ return ret;
+ }
+};
+
+typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
+
+const float kProbingMultiplier = 1.5;
+
+class Writer {
+ public:
+ Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
+ : block_(position), gram_(block_->Get(), order),
+ dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
+ dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
+ buffer_(new WordIndex[order - 1]),
+ block_size_(position.GetChain().BlockSize()) {
+ dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
+ if (order == 1) {
+ // Add special words. AdjustCounts is responsible if order != 1.
+ AddUnigramWord(kUNK);
+ AddUnigramWord(kBOS);
+ }
+ }
+
+ ~Writer() {
+ block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
+ (++block_).Poison();
+ }
+
+ // Write context with a bunch of <s>
+ void StartSentence() {
+ for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
+ *i = kBOS;
+ }
+ }
+
+ void Append(WordIndex word) {
+ *(gram_.end() - 1) = word;
+ Dedupe::MutableIterator at;
+ bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
+ if (found) {
+ // Already present.
+ NGram already(at->key, gram_.Order());
+ ++(already.Count());
+ // Shift left by one.
+ memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
+ return;
+ }
+ // Complete the write.
+ gram_.Count() = 1;
+ // Prepare the next n-gram.
+ if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
+ NGram last(gram_);
+ gram_.NextInMemory();
+ std::copy(last.begin() + 1, last.end(), gram_.begin());
+ return;
+ }
+ // Block end. Need to store the context in a temporary buffer.
+ std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
+ dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ block_->SetValidSize(block_size_);
+ gram_.ReBase((++block_)->Get());
+ std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
+ }
+
+ private:
+ void AddUnigramWord(WordIndex index) {
+ *gram_.begin() = index;
+ gram_.Count() = 0;
+ gram_.NextInMemory();
+ if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
+ block_->SetValidSize(block_size_);
+ gram_.ReBase((++block_)->Get());
+ }
+ }
+
+ util::stream::Link block_;
+
+ NGram gram_;
+
+ // This is the memory behind the invalid value in dedupe_.
+ std::vector<WordIndex> dedupe_invalid_;
+ // Hash table combiner implementation.
+ Dedupe dedupe_;
+
+ // Small buffer to hold existing ngrams when shifting across a block boundary.
+ boost::scoped_array<WordIndex> buffer_;
+
+ const std::size_t block_size_;
+};
+
+} // namespace
+
+float CorpusCount::DedupeMultiplier(std::size_t order) {
+ return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
+}
+
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
+ : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
+ dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
+ dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
+ token_count_ = 0;
+ type_count_ = 0;
+}
+
+void CorpusCount::Run(const util::stream::ChainPosition &position) {
+ UTIL_TIMER("(%w s) Counted n-grams\n");
+
+ VocabHandout vocab(vocab_write_);
+ const WordIndex end_sentence = vocab.Lookup("</s>");
+ Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
+ uint64_t count = 0;
+ try {
+ while(true) {
+ StringPiece line(from_.ReadLine());
+ writer.StartSentence();
+ for (util::TokenIter<util::AnyCharacter, true> w(line, " \t"); w; ++w) {
+ WordIndex word = vocab.Lookup(*w);
+ UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
+ writer.Append(word);
+ ++count;
+ }
+ writer.Append(end_sentence);
+ }
+ } catch (const util::EndOfFileException &e) {}
+ token_count_ = count;
+ type_count_ = vocab.Size();
+}
+
+} // namespace builder
+} // namespace lm
diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh
new file mode 100644
index 00000000..e255bad1
--- /dev/null
+++ b/klm/lm/builder/corpus_count.hh
@@ -0,0 +1,42 @@
+#ifndef LM_BUILDER_CORPUS_COUNT__
+#define LM_BUILDER_CORPUS_COUNT__
+
+#include "lm/word_index.hh"
+#include "util/scoped.hh"
+
+#include <cstddef>
+#include <string>
+#include <stdint.h>
+
+namespace util {
+class FilePiece;
+namespace stream {
+class ChainPosition;
+} // namespace stream
+} // namespace util
+
+namespace lm {
+namespace builder {
+
+class CorpusCount {
+ public:
+ // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size
+ static float DedupeMultiplier(std::size_t order);
+
+ CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
+
+ void Run(const util::stream::ChainPosition &position);
+
+ private:
+ util::FilePiece &from_;
+ int vocab_write_;
+ uint64_t &token_count_;
+ WordIndex &type_count_;
+
+ std::size_t dedupe_mem_size_;
+ util::scoped_malloc dedupe_mem_;
+};
+
+} // namespace builder
+} // namespace lm
+#endif // LM_BUILDER_CORPUS_COUNT__
diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc
new file mode 100644
index 00000000..8d53ca9d
--- /dev/null
+++ b/klm/lm/builder/corpus_count_test.cc
@@ -0,0 +1,76 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/ngram_stream.hh"
+
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/tokenize_piece.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/stream.hh"
+
+#define BOOST_TEST_MODULE CorpusCountTest
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+#define Check(str, count) { \
+ BOOST_REQUIRE(stream); \
+ w = stream->begin(); \
+ for (util::TokenIter<util::AnyCharacter, true> t(str, " "); t; ++t, ++w) { \
+ BOOST_CHECK_EQUAL(*t, v[*w]); \
+ } \
+ BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \
+ ++stream; \
+}
+
+BOOST_AUTO_TEST_CASE(Short) {
+ util::scoped_fd input_file(util::MakeTemp("corpus_count_test_temp"));
+ const char input[] = "looking on a little more loin\non a little more loin\non foo little more loin\nbar\n\n";
+ // Blocks of 10 are
+ // looking on a little more loin </s> on a little[duplicate] more[duplicate] loin[duplicate] </s>[duplicate] on[duplicate] foo
+ // little more loin </s> bar </s> </s>
+
+ util::WriteOrThrow(input_file.get(), input, sizeof(input) - 1);
+ util::FilePiece input_piece(input_file.release(), "temp file");
+
+ util::stream::ChainConfig config;
+ config.entry_size = NGram::TotalSize(3);
+ config.total_memory = config.entry_size * 20;
+ config.block_count = 2;
+
+ util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab"));
+
+ util::stream::Chain chain(config);
+ NGramStream stream;
+ uint64_t token_count;
+ WordIndex type_count;
+ CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
+
+ const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"};
+
+ WordIndex *w;
+
+ Check("<s> <s> looking", 1);
+ Check("<s> looking on", 1);
+ Check("looking on a", 1);
+ Check("on a little", 2);
+ Check("a little more", 2);
+ Check("little more loin", 2);
+ Check("more loin </s>", 2);
+ Check("<s> <s> on", 2);
+ Check("<s> on a", 1);
+ Check("<s> on foo", 1);
+ Check("on foo little", 1);
+ Check("foo little more", 1);
+ Check("little more loin", 1);
+ Check("more loin </s>", 1);
+ Check("<s> <s> bar", 1);
+ Check("<s> bar </s>", 1);
+ Check("<s> <s> </s>", 1);
+ BOOST_CHECK(!stream);
+ BOOST_CHECK_EQUAL(sizeof(v) / sizeof(const char*), type_count);
+}
+
+}}} // namespaces
diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh
new file mode 100644
index 00000000..754fb20d
--- /dev/null
+++ b/klm/lm/builder/discount.hh
@@ -0,0 +1,26 @@
+#ifndef BUILDER_DISCOUNT__
+#define BUILDER_DISCOUNT__
+
+#include <algorithm>
+
+#include <inttypes.h>
+
+namespace lm {
+namespace builder {
+
+struct Discount {
+ float amount[4];
+
+ float Get(uint64_t count) const {
+ return amount[std::min<uint64_t>(count, 3)];
+ }
+
+ float Apply(uint64_t count) const {
+ return static_cast<float>(count) - Get(count);
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // BUILDER_DISCOUNT__
diff --git a/klm/lm/builder/header_info.hh b/klm/lm/builder/header_info.hh
new file mode 100644
index 00000000..ccca1456
--- /dev/null
+++ b/klm/lm/builder/header_info.hh
@@ -0,0 +1,20 @@
+#ifndef LM_BUILDER_HEADER_INFO__
+#define LM_BUILDER_HEADER_INFO__
+
+#include <string>
+#include <stdint.h>
+
+// Some configuration info that is used to add
+// comments to the beginning of an ARPA file
+struct HeaderInfo {
+ const std::string input_file;
+ const uint64_t token_count;
+
+ HeaderInfo(const std::string& input_file_in, uint64_t token_count_in)
+ : input_file(input_file_in), token_count(token_count_in) {}
+
+ // TODO: Add smoothing type
+ // TODO: More info if multiple models were interpolated
+};
+
+#endif
diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc
new file mode 100644
index 00000000..58b42a20
--- /dev/null
+++ b/klm/lm/builder/initial_probabilities.cc
@@ -0,0 +1,136 @@
+#include "lm/builder/initial_probabilities.hh"
+
+#include "lm/builder/discount.hh"
+#include "lm/builder/ngram_stream.hh"
+#include "lm/builder/sort.hh"
+#include "util/file.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/io.hh"
+#include "util/stream/stream.hh"
+
+#include <vector>
+
+namespace lm { namespace builder {
+
+namespace {
+struct BufferEntry {
+ // Gamma from page 20 of Chen and Goodman.
+ float gamma;
+ // \sum_w a(c w) for all w.
+ float denominator;
+};
+
+// Extract an array of gamma from an array of BufferEntry.
+class OnlyGamma {
+ public:
+ void Run(const util::stream::ChainPosition &position) {
+ for (util::stream::Link block_it(position); block_it; ++block_it) {
+ float *out = static_cast<float*>(block_it->Get());
+ const float *in = out;
+ const float *end = static_cast<const float*>(block_it->ValidEnd());
+ for (out += 1, in += 2; in < end; out += 1, in += 2) {
+ *out = *in;
+ }
+ block_it->SetValidSize(block_it->ValidSize() / 2);
+ }
+ }
+};
+
+class AddRight {
+ public:
+ AddRight(const Discount &discount, const util::stream::ChainPosition &input)
+ : discount_(discount), input_(input) {}
+
+ void Run(const util::stream::ChainPosition &output) {
+ NGramStream in(input_);
+ util::stream::Stream out(output);
+
+ std::vector<WordIndex> previous(in->Order() - 1);
+ const std::size_t size = sizeof(WordIndex) * previous.size();
+ for(; in; ++out) {
+ memcpy(&previous[0], in->begin(), size);
+ uint64_t denominator = 0;
+ uint64_t counts[4];
+ memset(counts, 0, sizeof(counts));
+ do {
+ denominator += in->Count();
+ ++counts[std::min(in->Count(), static_cast<uint64_t>(3))];
+ } while (++in && !memcmp(&previous[0], in->begin(), size));
+ BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());
+ entry.denominator = static_cast<float>(denominator);
+ entry.gamma = 0.0;
+ for (unsigned i = 1; i <= 3; ++i) {
+ entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);
+ }
+ entry.gamma /= entry.denominator;
+ }
+ out.Poison();
+ }
+
+ private:
+ const Discount &discount_;
+ const util::stream::ChainPosition input_;
+};
+
+class MergeRight {
+ public:
+ MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount)
+ : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {}
+
+ // calculate the initial probability of each n-gram (before order-interpolation)
+ // Run() gets invoked once for each order
+ void Run(const util::stream::ChainPosition &primary) {
+ util::stream::Stream summed(from_adder_);
+
+ NGramStream grams(primary);
+
+ // Without interpolation, the interpolation weight goes to <unk>.
+ if (grams->Order() == 1 && !interpolate_unigrams_) {
+ BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get()));
+ assert(*grams->begin() == kUNK);
+ grams->Value().uninterp.prob = sums.gamma;
+ grams->Value().uninterp.gamma = 0.0;
+ while (++grams) {
+ grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator;
+ grams->Value().uninterp.gamma = 0.0;
+ }
+ ++summed;
+ return;
+ }
+
+ std::vector<WordIndex> previous(grams->Order() - 1);
+ const std::size_t size = sizeof(WordIndex) * previous.size();
+ for (; grams; ++summed) {
+ memcpy(&previous[0], grams->begin(), size);
+ const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
+ do {
+ Payload &pay = grams->Value();
+ pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator;
+ pay.uninterp.gamma = sums.gamma;
+ } while (++grams && !memcmp(&previous[0], grams->begin(), size));
+ }
+ }
+
+ private:
+ bool interpolate_unigrams_;
+ util::stream::ChainPosition from_adder_;
+ Discount discount_;
+};
+
+} // namespace
+
+void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) {
+ util::stream::ChainConfig gamma_config = config.adder_out;
+ gamma_config.entry_size = sizeof(BufferEntry);
+ for (size_t i = 0; i < primary.size(); ++i) {
+ util::stream::ChainPosition second(second_in[i].Add());
+ second_in[i] >> util::stream::kRecycle;
+ gamma_out.push_back(gamma_config);
+ gamma_out[i] >> AddRight(discounts[i], second);
+ primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
+ // Don't bother with the OnlyGamma thread for something to discard.
+ if (i) gamma_out[i] >> OnlyGamma();
+ }
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/initial_probabilities.hh b/klm/lm/builder/initial_probabilities.hh
new file mode 100644
index 00000000..626388eb
--- /dev/null
+++ b/klm/lm/builder/initial_probabilities.hh
@@ -0,0 +1,34 @@
+#ifndef LM_BUILDER_INITIAL_PROBABILITIES__
+#define LM_BUILDER_INITIAL_PROBABILITIES__
+
+#include "lm/builder/discount.hh"
+#include "util/stream/config.hh"
+
+#include <vector>
+
+namespace lm {
+namespace builder {
+class Chains;
+
+struct InitialProbabilitiesConfig {
+ // These should be small buffers to keep the adder from getting too far ahead
+ util::stream::ChainConfig adder_in;
+ util::stream::ChainConfig adder_out;
+ // SRILM doesn't normally interpolate unigrams.
+ bool interpolate_unigrams;
+};
+
+/* Compute initial (uninterpolated) probabilities
+ * primary: the normal chain of n-grams. Incoming is context sorted adjusted
+ * counts. Outgoing has uninterpolated probabilities for use by Interpolate.
+ * second_in: a second copy of the primary input. Discard the output.
+ * gamma_out: Computed gamma values are output on these chains in suffix order.
+ * The values are bare floats and should be buffered for interpolation to
+ * use.
+ */
+void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out);
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_INITIAL_PROBABILITIES__
diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc
new file mode 100644
index 00000000..50026806
--- /dev/null
+++ b/klm/lm/builder/interpolate.cc
@@ -0,0 +1,65 @@
+#include "lm/builder/interpolate.hh"
+
+#include "lm/builder/joint_order.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/sort.hh"
+#include "lm/lm_exception.hh"
+
+#include <assert.h>
+
+namespace lm { namespace builder {
+namespace {
+
+class Callback {
+ public:
+ Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) {
+ probs_[0] = uniform_prob;
+ for (std::size_t i = 0; i < backoffs.size(); ++i) {
+ backoffs_.push_back(backoffs[i]);
+ }
+ }
+
+ ~Callback() {
+ for (std::size_t i = 0; i < backoffs_.size(); ++i) {
+ if (backoffs_[i]) {
+ std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
+ abort();
+ }
+ }
+ }
+
+ void Enter(unsigned order_minus_1, NGram &gram) {
+ Payload &pay = gram.Value();
+ pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
+ probs_[order_minus_1 + 1] = pay.complete.prob;
+ pay.complete.prob = log10(pay.complete.prob);
+ // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
+ pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
+ ++backoffs_[order_minus_1];
+ } else {
+ // Not a context.
+ pay.complete.backoff = 0.0;
+ }
+ }
+
+ void Exit(unsigned, const NGram &) const {}
+
+ private:
+ FixedArray<util::stream::Stream> backoffs_;
+
+ std::vector<float> probs_;
+};
+} // namespace
+
+Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
+ : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {}
+
+// perform order-wise interpolation
+void Interpolate::Run(const ChainPositions &positions) {
+ assert(positions.size() == backoffs_.size() + 1);
+ Callback callback(uniform_prob_, backoffs_);
+ JointOrder<Callback, SuffixOrder>(positions, callback);
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/interpolate.hh b/klm/lm/builder/interpolate.hh
new file mode 100644
index 00000000..9268d404
--- /dev/null
+++ b/klm/lm/builder/interpolate.hh
@@ -0,0 +1,27 @@
+#ifndef LM_BUILDER_INTERPOLATE__
+#define LM_BUILDER_INTERPOLATE__
+
+#include <stdint.h>
+
+#include "lm/builder/multi_stream.hh"
+
+namespace lm { namespace builder {
+
+/* Interpolate step.
+ * Input: suffix sorted n-grams with (p_uninterpolated, gamma) from
+ * InitialProbabilities.
+ * Output: suffix sorted n-grams with complete probability
+ */
+class Interpolate {
+ public:
+ explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs);
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ float uniform_prob_;
+ ChainPositions backoffs_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_INTERPOLATE__
diff --git a/klm/lm/builder/joint_order.hh b/klm/lm/builder/joint_order.hh
new file mode 100644
index 00000000..b5620144
--- /dev/null
+++ b/klm/lm/builder/joint_order.hh
@@ -0,0 +1,43 @@
+#ifndef LM_BUILDER_JOINT_ORDER__
+#define LM_BUILDER_JOINT_ORDER__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/lm_exception.hh"
+
+#include <string.h>
+
+namespace lm { namespace builder {
+
+template <class Callback, class Compare> void JointOrder(const ChainPositions &positions, Callback &callback) {
+ // Allow matching to reference streams[-1].
+ NGramStreams streams_with_dummy;
+ streams_with_dummy.InitWithDummy(positions);
+ NGramStream *streams = streams_with_dummy.begin() + 1;
+
+ unsigned int order;
+ for (order = 0; order < positions.size() && streams[order]; ++order) {}
+ assert(order); // should always have <unk>.
+ unsigned int current = 0;
+ while (true) {
+ // Does the context match the lower one?
+ if (!memcmp(streams[static_cast<int>(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) {
+ callback.Enter(current, *streams[current]);
+ // Transition to looking for extensions.
+ if (++current < order) continue;
+ }
+ // No extension left.
+ while(true) {
+ assert(current > 0);
+ --current;
+ callback.Exit(current, *streams[current]);
+ if (++streams[current]) break;
+ UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix");
+ order = current;
+ if (!order) return;
+ }
+ }
+}
+
+}} // namespaces
+
+#endif // LM_BUILDER_JOINT_ORDER__
diff --git a/klm/lm/builder/main.cc b/klm/lm/builder/main.cc
new file mode 100644
index 00000000..90b9dca2
--- /dev/null
+++ b/klm/lm/builder/main.cc
@@ -0,0 +1,94 @@
+#include "lm/builder/pipeline.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/usage.hh"
+
+#include <iostream>
+
+#include <boost/program_options.hpp>
+
+namespace {
+class SizeNotify {
+ public:
+ SizeNotify(std::size_t &out) : behind_(out) {}
+
+ void operator()(const std::string &from) {
+ behind_ = util::ParseSize(from);
+ }
+
+ private:
+ std::size_t &behind_;
+};
+
+boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) {
+ return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
+}
+
+} // namespace
+
+int main(int argc, char *argv[]) {
+ try {
+ namespace po = boost::program_options;
+ po::options_description options("Language model building options");
+ lm::builder::PipelineConfig pipeline;
+
+ options.add_options()
+ ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model")
+ ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
+ ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
+ ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
+ ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
+ ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
+ ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+ ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
+ ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
+ ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
+ if (argc == 1) {
+ std::cerr <<
+ "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
+ "Please cite:\n"
+ "@inproceedings{kenlm,\n"
+ "author = {Kenneth Heafield},\n"
+ "title = {{KenLM}: Faster and Smaller Language Model Queries},\n"
+ "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
+ "month = {July}, year={2011},\n"
+ "address = {Edinburgh, UK},\n"
+ "publisher = {Association for Computational Linguistics},\n"
+ "}\n\n"
+ "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n"
+ "the model (-o) is the only mandatory option. As this is an on-disk program,\n"
+ "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
+ "Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
+ "Valid units are \% for percentage of memory (supported platforms only) and (in\n"
+ "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n";
+ std::cerr << options << std::endl;
+ return 1;
+ }
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, options), vm);
+ po::notify(vm);
+
+ util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
+
+ lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
+ // TODO: evaluate options for these.
+ initial.adder_in.total_memory = 32768;
+ initial.adder_in.block_count = 2;
+ initial.adder_out.total_memory = 32768;
+ initial.adder_out.block_count = 2;
+ pipeline.read_backoffs = initial.adder_out;
+
+ // Read from stdin
+ try {
+ lm::builder::Pipeline(pipeline, 0, 1);
+ } catch (const util::MallocException &e) {
+ std::cerr << e.what() << std::endl;
+ std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
+ return 1;
+ }
+ util::PrintUsage(std::cerr);
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
+ }
+}
diff --git a/klm/lm/builder/multi_stream.hh b/klm/lm/builder/multi_stream.hh
new file mode 100644
index 00000000..707a98c7
--- /dev/null
+++ b/klm/lm/builder/multi_stream.hh
@@ -0,0 +1,180 @@
+#ifndef LM_BUILDER_MULTI_STREAM__
+#define LM_BUILDER_MULTI_STREAM__
+
+#include "lm/builder/ngram_stream.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+
+#include <cstddef>
+#include <new>
+
+#include <assert.h>
+#include <stdlib.h>
+
+namespace lm { namespace builder {
+
+template <class T> class FixedArray {
+ public:
+ explicit FixedArray(std::size_t count) {
+ Init(count);
+ }
+
+ FixedArray() : newed_end_(NULL) {}
+
+ void Init(std::size_t count) {
+ assert(!block_.get());
+ block_.reset(malloc(sizeof(T) * count));
+ if (!block_.get()) throw std::bad_alloc();
+ newed_end_ = begin();
+ }
+
+ FixedArray(const FixedArray &from) {
+ std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get());
+ Init(size);
+ for (std::size_t i = 0; i < size; ++i) {
+ new(end()) T(from[i]);
+ Constructed();
+ }
+ }
+
+ ~FixedArray() { clear(); }
+
+ T *begin() { return static_cast<T*>(block_.get()); }
+ const T *begin() const { return static_cast<const T*>(block_.get()); }
+ // Always call Constructed after successful completion of new.
+ T *end() { return newed_end_; }
+ const T *end() const { return newed_end_; }
+
+ T &back() { return *(end() - 1); }
+ const T &back() const { return *(end() - 1); }
+
+ std::size_t size() const { return end() - begin(); }
+ bool empty() const { return begin() == end(); }
+
+ T &operator[](std::size_t i) { return begin()[i]; }
+ const T &operator[](std::size_t i) const { return begin()[i]; }
+
+ template <class C> void push_back(const C &c) {
+ new (end()) T(c);
+ Constructed();
+ }
+
+ void clear() {
+ for (T *i = begin(); i != end(); ++i)
+ i->~T();
+ newed_end_ = begin();
+ }
+
+ protected:
+ void Constructed() {
+ ++newed_end_;
+ }
+
+ private:
+ util::scoped_malloc block_;
+
+ T *newed_end_;
+};
+
+class Chains;
+
+class ChainPositions : public FixedArray<util::stream::ChainPosition> {
+ public:
+ ChainPositions() {}
+
+ void Init(Chains &chains);
+
+ explicit ChainPositions(Chains &chains) {
+ Init(chains);
+ }
+};
+
+class Chains : public FixedArray<util::stream::Chain> {
+ private:
+ template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun {
+ typedef Chains type;
+ };
+
+ public:
+ explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {}
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ Chains &operator>>(const util::stream::Recycler &recycler) {
+ for (util::stream::Chain *i = begin(); i != end(); ++i)
+ *i >> recycler;
+ return *this;
+ }
+
+ void Wait(bool release_memory = true) {
+ threads_.clear();
+ for (util::stream::Chain *i = begin(); i != end(); ++i) {
+ i->Wait(release_memory);
+ }
+ }
+
+ private:
+ boost::ptr_vector<util::stream::Thread> threads_;
+
+ Chains(const Chains &);
+ void operator=(const Chains &);
+};
+
+inline void ChainPositions::Init(Chains &chains) {
+ FixedArray<util::stream::ChainPosition>::Init(chains.size());
+ for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) {
+ new (end()) util::stream::ChainPosition(i->Add()); Constructed();
+ }
+}
+
+inline Chains &operator>>(Chains &chains, ChainPositions &positions) {
+ positions.Init(chains);
+ return chains;
+}
+
+class NGramStreams : public FixedArray<NGramStream> {
+ public:
+ NGramStreams() {}
+
+ // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning).
+ void InitWithDummy(const ChainPositions &positions) {
+ FixedArray<NGramStream>::Init(positions.size() + 1);
+ new (end()) NGramStream(); Constructed();
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) {
+ push_back(*i);
+ }
+ }
+
+ // Limit restricts to positions[0,limit)
+ void Init(const ChainPositions &positions, std::size_t limit) {
+ FixedArray<NGramStream>::Init(limit);
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) {
+ push_back(*i);
+ }
+ }
+ void Init(const ChainPositions &positions) {
+ Init(positions, positions.size());
+ }
+
+ NGramStreams(const ChainPositions &positions) {
+ Init(positions);
+ }
+};
+
+inline Chains &operator>>(Chains &chains, NGramStreams &streams) {
+ ChainPositions positions;
+ chains >> positions;
+ streams.Init(positions);
+ return chains;
+}
+
+}} // namespaces
+#endif // LM_BUILDER_MULTI_STREAM__
diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh
new file mode 100644
index 00000000..2984ed0b
--- /dev/null
+++ b/klm/lm/builder/ngram.hh
@@ -0,0 +1,84 @@
+#ifndef LM_BUILDER_NGRAM__
+#define LM_BUILDER_NGRAM__
+
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+
+#include <cstddef>
+
+#include <assert.h>
+#include <stdint.h>
+#include <string.h>
+
+namespace lm {
+namespace builder {
+
+struct Uninterpolated {
+ float prob; // Uninterpolated probability.
+ float gamma; // Interpolation weight for lower order.
+};
+
+union Payload {
+ uint64_t count;
+ Uninterpolated uninterp;
+ ProbBackoff complete;
+};
+
+class NGram {
+ public:
+ NGram(void *begin, std::size_t order)
+ : begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {}
+
+ const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); }
+ uint8_t *Base() { return reinterpret_cast<uint8_t*>(begin_); }
+
+ void ReBase(void *to) {
+ std::size_t difference = end_ - begin_;
+ begin_ = reinterpret_cast<WordIndex*>(to);
+ end_ = begin_ + difference;
+ }
+
+ // Would do operator++ but that can get confusing for a stream.
+ void NextInMemory() {
+ ReBase(&Value() + 1);
+ }
+
+ // Lower-case in deference to STL.
+ const WordIndex *begin() const { return begin_; }
+ WordIndex *begin() { return begin_; }
+ const WordIndex *end() const { return end_; }
+ WordIndex *end() { return end_; }
+
+ const Payload &Value() const { return *reinterpret_cast<const Payload *>(end_); }
+ Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
+
+ uint64_t &Count() { return Value().count; }
+ const uint64_t Count() const { return Value().count; }
+
+ std::size_t Order() const { return end_ - begin_; }
+
+ static std::size_t TotalSize(std::size_t order) {
+ return order * sizeof(WordIndex) + sizeof(Payload);
+ }
+ std::size_t TotalSize() const {
+ // Compiler should optimize this.
+ return TotalSize(Order());
+ }
+ static std::size_t OrderFromSize(std::size_t size) {
+ std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex);
+ assert(size == TotalSize(ret));
+ return ret;
+ }
+
+ private:
+ WordIndex *begin_, *end_;
+};
+
+const WordIndex kUNK = 0;
+const WordIndex kBOS = 1;
+const WordIndex kEOS = 2;
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_NGRAM__
diff --git a/klm/lm/builder/ngram_stream.hh b/klm/lm/builder/ngram_stream.hh
new file mode 100644
index 00000000..3c994664
--- /dev/null
+++ b/klm/lm/builder/ngram_stream.hh
@@ -0,0 +1,55 @@
+#ifndef LM_BUILDER_NGRAM_STREAM__
+#define LM_BUILDER_NGRAM_STREAM__
+
+#include "lm/builder/ngram.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/stream.hh"
+
+#include <cstddef>
+
+namespace lm { namespace builder {
+
+class NGramStream {
+ public:
+ NGramStream() : gram_(NULL, 0) {}
+
+ NGramStream(const util::stream::ChainPosition &position) : gram_(NULL, 0) {
+ Init(position);
+ }
+
+ void Init(const util::stream::ChainPosition &position) {
+ stream_.Init(position);
+ gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize()));
+ }
+
+ NGram &operator*() { return gram_; }
+ const NGram &operator*() const { return gram_; }
+
+ NGram *operator->() { return &gram_; }
+ const NGram *operator->() const { return &gram_; }
+
+ void *Get() { return stream_.Get(); }
+ const void *Get() const { return stream_.Get(); }
+
+ operator bool() const { return stream_; }
+ bool operator!() const { return !stream_; }
+ void Poison() { stream_.Poison(); }
+
+ NGramStream &operator++() {
+ ++stream_;
+ gram_.ReBase(stream_.Get());
+ return *this;
+ }
+
+ private:
+ NGram gram_;
+ util::stream::Stream stream_;
+};
+
+inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) {
+ str.Init(chain.Add());
+ return chain;
+}
+
+}} // namespaces
+#endif // LM_BUILDER_NGRAM_STREAM__
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc
new file mode 100644
index 00000000..14a1f721
--- /dev/null
+++ b/klm/lm/builder/pipeline.cc
@@ -0,0 +1,320 @@
+#include "lm/builder/pipeline.hh"
+
+#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/corpus_count.hh"
+#include "lm/builder/initial_probabilities.hh"
+#include "lm/builder/interpolate.hh"
+#include "lm/builder/print.hh"
+#include "lm/builder/sort.hh"
+
+#include "lm/sizes.hh"
+
+#include "util/exception.hh"
+#include "util/file.hh"
+#include "util/stream/io.hh"
+
+#include <algorithm>
+#include <iostream>
+#include <vector>
+
+namespace lm { namespace builder {
+
+namespace {
+void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) {
+ std::cerr << "Statistics:\n";
+ for (size_t i = 0; i < counts.size(); ++i) {
+ std::cerr << (i + 1) << ' ' << counts[i];
+ for (size_t d = 1; d <= 3; ++d)
+ std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d];
+ std::cerr << '\n';
+ }
+}
+
+class Master {
+ public:
+ explicit Master(const PipelineConfig &config)
+ : config_(config), chains_(config.order), files_(config.order) {
+ config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block);
+ }
+
+ const PipelineConfig &Config() const { return config_; }
+
+ Chains &MutableChains() { return chains_; }
+
+ template <class T> Master &operator>>(const T &worker) {
+ chains_ >> worker;
+ return *this;
+ }
+
+ // This takes the (partially) sorted ngrams and sets up for adjusted counts.
+ void InitForAdjust(util::stream::Sort<SuffixOrder, AddCombiner> &ngrams, WordIndex types) {
+ const std::size_t each_order_min = config_.minimum_block * config_.block_count;
+ // We know how many unigrams there are. Don't allocate more than needed to them.
+ const std::size_t min_chains = (config_.order - 1) * each_order_min +
+ std::min(types * NGram::TotalSize(1), each_order_min);
+ // Do merge sort with calculated laziness.
+ const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy()));
+
+ std::vector<uint64_t> count_bounds(1, types);
+ CreateChains(config_.TotalMemory() - merge_using, count_bounds);
+ ngrams.Output(chains_.back(), merge_using);
+
+ // Setup unigram file.
+ files_.push_back(util::MakeTemp(config_.TempPrefix()));
+ }
+
+ // For initial probabilities, but this is generic.
+ void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) {
+ // Do merge first before allocating chain memory.
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts[i - 1].Merge(0);
+ }
+ // There's no lazy merge, so just divide memory amongst the chains.
+ CreateChains(config_.TotalMemory(), counts);
+ chains_.back().ActivateProgress();
+ chains_[0] >> files_[0].Source();
+ second_config.entry_size = NGram::TotalSize(1);
+ second.push_back(second_config);
+ second.back() >> files_[0].Source();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ util::scoped_fd fd(sorts[i - 1].StealCompleted());
+ chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get()));
+ chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true);
+ second_config.entry_size = NGram::TotalSize(i + 1);
+ second.push_back(second_config);
+ second.back() >> util::stream::PRead(fd.release(), true);
+ }
+ }
+
+ // There is no sort after this, so go for broke on lazy merging.
+ template <class Compare> void MaximumLazyInput(const std::vector<uint64_t> &counts, Sorts<Compare> &sorts) {
+ // Determine the minimum we can use for all the chains.
+ std::size_t min_chains = 0;
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block));
+ }
+ std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains);
+ std::vector<std::size_t> laziness;
+ // Prioritize longer n-grams.
+ for (util::stream::Sort<SuffixOrder> *i = sorts.end() - 1; i >= sorts.begin(); --i) {
+ laziness.push_back(i->Merge(for_merge));
+ assert(for_merge >= laziness.back());
+ for_merge -= laziness.back();
+ }
+ std::reverse(laziness.begin(), laziness.end());
+
+ CreateChains(for_merge + min_chains, counts);
+ chains_.back().ActivateProgress();
+ chains_[0] >> files_[0].Source();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts[i - 1].Output(chains_[i], laziness[i - 1]);
+ }
+ }
+
+ void BufferFinal(const std::vector<uint64_t> &counts) {
+ chains_[0] >> files_[0].Sink();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ files_.push_back(util::MakeTemp(config_.TempPrefix()));
+ chains_[i] >> files_[i].Sink();
+ }
+ chains_.Wait(true);
+ // Use less memory. Because we can.
+ CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts);
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ chains_[i] >> files_[i].Source();
+ }
+ }
+
+ template <class Compare> void SetupSorts(Sorts<Compare> &sorts) {
+ sorts.Init(config_.order - 1);
+ // Unigrams don't get sorted because their order is always the same.
+ chains_[0] >> files_[0].Sink();
+ for (std::size_t i = 1; i < config_.order; ++i) {
+ sorts.push_back(chains_[i], config_.sort, Compare(i + 1));
+ }
+ chains_.Wait(true);
+ }
+
+ private:
+ // Create chains, allocating memory to them. Totally heuristic. Count
+ // bounds are upper bounds on the counts or not present.
+ void CreateChains(std::size_t remaining_mem, const std::vector<uint64_t> &count_bounds) {
+ std::vector<std::size_t> assignments;
+ assignments.reserve(config_.order);
+ // Start by assigning maximum memory usage (to be refined later).
+ for (std::size_t i = 0; i < count_bounds.size(); ++i) {
+ assignments.push_back(static_cast<std::size_t>(std::min(
+ static_cast<uint64_t>(remaining_mem),
+ count_bounds[i] * static_cast<uint64_t>(NGram::TotalSize(i + 1)))));
+ }
+ assignments.resize(config_.order, remaining_mem);
+
+ // Now we know how much memory everybody wants. How much will they get?
+ // Proportional to this.
+ std::vector<float> portions;
+ // Indices of orders that have yet to be assigned.
+ std::vector<std::size_t> unassigned;
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ portions.push_back(static_cast<float>((i+1) * NGram::TotalSize(i+1)));
+ unassigned.push_back(i);
+ }
+ /*If somebody doesn't eat their full dinner, give it to the rest of the
+ * family. Then somebody else might not eat their full dinner etc. Ends
+ * when everybody unassigned is hungry.
+ */
+ float sum;
+ bool found_more;
+ std::vector<std::size_t> block_count(config_.order);
+ do {
+ sum = 0.0;
+ for (std::size_t i = 0; i < unassigned.size(); ++i) {
+ sum += portions[unassigned[i]];
+ }
+ found_more = false;
+ // If the proportional assignment is more than needed, give it just what it needs.
+ for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end();) {
+ if (assignments[*i] <= remaining_mem * (portions[*i] / sum)) {
+ remaining_mem -= assignments[*i];
+ block_count[*i] = 1;
+ i = unassigned.erase(i);
+ found_more = true;
+ } else {
+ ++i;
+ }
+ }
+ } while (found_more);
+ for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end(); ++i) {
+ assignments[*i] = remaining_mem * (portions[*i] / sum);
+ block_count[*i] = config_.block_count;
+ }
+ chains_.clear();
+ std::cerr << "Chain sizes:";
+ for (std::size_t i = 0; i < config_.order; ++i) {
+ std::cerr << ' ' << (i+1) << ":" << assignments[i];
+ chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i]));
+ }
+ std::cerr << std::endl;
+ }
+
+ PipelineConfig config_;
+
+ Chains chains_;
+ // Often only unigrams, but sometimes all orders.
+ FixedArray<util::stream::FileBuffer> files_;
+};
+
+void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) {
+ const PipelineConfig &config = master.Config();
+ std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
+
+ UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory());
+ std::size_t memory_for_chain =
+ // This much memory to work with after vocab hash table.
+ static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) /
+ // Solve for block size including the dedupe multiplier for one block.
+ (static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
+ // Chain likes memory expressed in terms of total memory.
+ static_cast<float>(config.block_count);
+ util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
+
+ WordIndex type_count;
+ util::FilePiece text(text_file, NULL, &std::cerr);
+ text_file_name = text.FileName();
+ CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ chain >> boost::ref(counter);
+
+ util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
+ chain.Wait(true);
+ std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
+ master.InitForAdjust(sorter, type_count);
+}
+
+void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+ const PipelineConfig &config = master.Config();
+ Chains second(config.order);
+
+ {
+ Sorts<ContextOrder> sorts;
+ master.SetupSorts(sorts);
+ PrintStatistics(counts, discounts);
+ lm::ngram::ShowSizes(counts);
+ std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl;
+ master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in);
+ }
+
+ Chains gamma_chains(config.order);
+ InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains);
+ // Don't care about gamma for 0.
+ gamma_chains[0] >> util::stream::kRecycle;
+ gammas.Init(config.order - 1);
+ for (std::size_t i = 1; i < config.order; ++i) {
+ gammas.push_back(util::MakeTemp(config.TempPrefix()));
+ gamma_chains[i] >> gammas[i - 1].Sink();
+ }
+ // Has to be done here due to gamma_chains scope.
+ master.SetupSorts(primary);
+}
+
+void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+ std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
+ const PipelineConfig &config = master.Config();
+ master.MaximumLazyInput(counts, primary);
+
+ Chains gamma_chains(config.order - 1);
+ util::stream::ChainConfig read_backoffs(config.read_backoffs);
+ read_backoffs.entry_size = sizeof(float);
+ for (std::size_t i = 0; i < config.order - 1; ++i) {
+ gamma_chains.push_back(read_backoffs);
+ gamma_chains.back() >> gammas[i].Source();
+ }
+ master >> Interpolate(counts[0], ChainPositions(gamma_chains));
+ gamma_chains >> util::stream::kRecycle;
+ master.BufferFinal(counts);
+}
+
+} // namespace
+
+void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
+ // Some fail-fast sanity checks.
+ if (config.sort.buffer_size * 4 > config.TotalMemory()) {
+ config.sort.buffer_size = config.TotalMemory() / 4;
+ std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl;
+ }
+ if (config.minimum_block < NGram::TotalSize(config.order)) {
+ config.minimum_block = NGram::TotalSize(config.order);
+ std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl;
+ }
+ UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << ".");
+ UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception,
+ "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size.");
+
+ UTIL_TIMER("(%w s) Total wall time elapsed\n");
+ Master master(config);
+
+ util::scoped_fd vocab_file(config.vocab_file.empty() ?
+ util::MakeTemp(config.TempPrefix()) :
+ util::CreateOrThrow(config.vocab_file.c_str()));
+ uint64_t token_count;
+ std::string text_file_name;
+ CountText(text_file, vocab_file.get(), master, token_count, text_file_name);
+
+ std::vector<uint64_t> counts;
+ std::vector<Discount> discounts;
+ master >> AdjustCounts(counts, discounts);
+
+ {
+ FixedArray<util::stream::FileBuffer> gammas;
+ Sorts<SuffixOrder> primary;
+ InitialProbabilities(counts, discounts, master, primary, gammas);
+ InterpolateProbabilities(counts, master, primary, gammas);
+ }
+
+ std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
+ VocabReconstitute vocab(vocab_file.get());
+ UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?");
+ HeaderInfo header_info(text_file_name, token_count);
+ master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle;
+ master.MutableChains().Wait(true);
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh
new file mode 100644
index 00000000..f1d6c5f6
--- /dev/null
+++ b/klm/lm/builder/pipeline.hh
@@ -0,0 +1,40 @@
+#ifndef LM_BUILDER_PIPELINE__
+#define LM_BUILDER_PIPELINE__
+
+#include "lm/builder/initial_probabilities.hh"
+#include "lm/builder/header_info.hh"
+#include "util/stream/config.hh"
+#include "util/file_piece.hh"
+
+#include <string>
+#include <cstddef>
+
+namespace lm { namespace builder {
+
+struct PipelineConfig {
+ std::size_t order;
+ std::string vocab_file;
+ util::stream::SortConfig sort;
+ InitialProbabilitiesConfig initial_probs;
+ util::stream::ChainConfig read_backoffs;
+ bool verbose_header;
+
+ // Amount of memory to assume that the vocabulary hash table will use. This
+ // is subtracted from total memory for CorpusCount.
+ std::size_t assume_vocab_hash_size;
+
+ // Minimum block size to tolerate.
+ std::size_t minimum_block;
+
+ // Number of blocks to use. This will be overridden to 1 if everything fits.
+ std::size_t block_count;
+
+ const std::string &TempPrefix() const { return sort.temp_prefix; }
+ std::size_t TotalMemory() const { return sort.total_memory; }
+};
+
+// Takes ownership of text_file.
+void Pipeline(PipelineConfig config, int text_file, int out_arpa);
+
+}} // namespaces
+#endif // LM_BUILDER_PIPELINE__
diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc
new file mode 100644
index 00000000..b0323221
--- /dev/null
+++ b/klm/lm/builder/print.cc
@@ -0,0 +1,135 @@
+#include "lm/builder/print.hh"
+
+#include "util/double-conversion/double-conversion.h"
+#include "util/double-conversion/utils.h"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/scoped.hh"
+#include "util/stream/timer.hh"
+
+#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
+#include <boost/lexical_cast.hpp>
+
+#include <sstream>
+
+#include <string.h>
+
+namespace lm { namespace builder {
+
+VocabReconstitute::VocabReconstitute(int fd) {
+ uint64_t size = util::SizeOrThrow(fd);
+ util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_);
+ const char *const start = static_cast<const char*>(memory_.get());
+ const char *i;
+ for (i = start; i != start + size; i += strlen(i) + 1) {
+ map_.push_back(i);
+ }
+ // Last one for LookupPiece.
+ map_.push_back(i);
+}
+
+namespace {
+class OutputManager {
+ public:
+ static const std::size_t kOutBuf = 1048576;
+
+ // Does not take ownership of out.
+ explicit OutputManager(int out)
+ : buf_(util::MallocOrThrow(kOutBuf)),
+ builder_(static_cast<char*>(buf_.get()), kOutBuf),
+ // Mostly the default but with inf instead. And no flags.
+ convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
+ fd_(out) {}
+
+ ~OutputManager() {
+ Flush();
+ }
+
+ OutputManager &operator<<(float value) {
+ // Odd, but this is the largest number found in the comments.
+ EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
+ convert_.ToShortestSingle(value, &builder_);
+ return *this;
+ }
+
+ OutputManager &operator<<(StringPiece str) {
+ if (str.size() > kOutBuf) {
+ Flush();
+ util::WriteOrThrow(fd_, str.data(), str.size());
+ } else {
+ EnsureRemaining(str.size());
+ builder_.AddSubstring(str.data(), str.size());
+ }
+ return *this;
+ }
+
+ // Inefficient!
+ OutputManager &operator<<(unsigned val) {
+ return *this << boost::lexical_cast<std::string>(val);
+ }
+
+ OutputManager &operator<<(char c) {
+ EnsureRemaining(1);
+ builder_.AddCharacter(c);
+ return *this;
+ }
+
+ void Flush() {
+ util::WriteOrThrow(fd_, buf_.get(), builder_.position());
+ builder_.Reset();
+ }
+
+ private:
+ void EnsureRemaining(std::size_t amount) {
+ if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) {
+ Flush();
+ }
+ }
+
+ util::scoped_malloc buf_;
+ double_conversion::StringBuilder builder_;
+ double_conversion::DoubleToStringConverter convert_;
+ int fd_;
+};
+} // namespace
+
+PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
+ : vocab_(vocab), out_fd_(out_fd) {
+ std::stringstream stream;
+
+ if (header_info) {
+ stream << "# Input file: " << header_info->input_file << '\n';
+ stream << "# Token count: " << header_info->token_count << '\n';
+ stream << "# Smoothing: Modified Kneser-Ney" << '\n';
+ }
+ stream << "\\data\\\n";
+ for (size_t i = 0; i < counts.size(); ++i) {
+ stream << "ngram " << (i+1) << '=' << counts[i] << '\n';
+ }
+ stream << '\n';
+ std::string as_string(stream.str());
+ util::WriteOrThrow(out_fd, as_string.data(), as_string.size());
+}
+
+void PrintARPA::Run(const ChainPositions &positions) {
+ UTIL_TIMER("(%w s) Wrote ARPA file\n");
+ OutputManager out(out_fd_);
+ for (unsigned order = 1; order <= positions.size(); ++order) {
+ out << "\\" << order << "-grams:" << '\n';
+ for (NGramStream stream(positions[order - 1]); stream; ++stream) {
+ // Correcting for numerical precision issues. Take that IRST.
+ out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin());
+ for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
+ out << ' ' << vocab_.Lookup(*i);
+ }
+ float backoff = stream->Value().complete.backoff;
+ if (backoff != 0.0)
+ out << '\t' << backoff;
+ out << '\n';
+ }
+ out << '\n';
+ }
+ out << "\\end\\\n";
+}
+
+}} // namespaces
diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh
new file mode 100644
index 00000000..aa932e75
--- /dev/null
+++ b/klm/lm/builder/print.hh
@@ -0,0 +1,102 @@
+#ifndef LM_BUILDER_PRINT__
+#define LM_BUILDER_PRINT__
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/header_info.hh"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/string_piece.hh"
+
+#include <ostream>
+
+#include <assert.h>
+
+// Warning: print routines read all unigrams before all bigrams before all
+// trigrams etc. So if other parts of the chain move jointly, you'll have to
+// buffer.
+
+namespace lm { namespace builder {
+
+class VocabReconstitute {
+ public:
+ // fd must be alive for life of this object; does not take ownership.
+ explicit VocabReconstitute(int fd);
+
+ const char *Lookup(WordIndex index) const {
+ assert(index < map_.size() - 1);
+ return map_[index];
+ }
+
+ StringPiece LookupPiece(WordIndex index) const {
+ return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]);
+ }
+
+ std::size_t Size() const {
+ // There's an extra entry to support StringPiece lengths.
+ return map_.size() - 1;
+ }
+
+ private:
+ util::scoped_memory memory_;
+ std::vector<const char*> map_;
+};
+
+// Not defined, only specialized.
+template <class T> void PrintPayload(std::ostream &to, const Payload &payload);
+template <> inline void PrintPayload<uint64_t>(std::ostream &to, const Payload &payload) {
+ to << payload.count;
+}
+template <> inline void PrintPayload<Uninterpolated>(std::ostream &to, const Payload &payload) {
+ to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma);
+}
+template <> inline void PrintPayload<ProbBackoff>(std::ostream &to, const Payload &payload) {
+ to << payload.complete.prob << ' ' << payload.complete.backoff;
+}
+
+// template parameter is the type stored.
+template <class V> class Print {
+ public:
+ explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {}
+
+ void Run(const ChainPositions &chains) {
+ NGramStreams streams(chains);
+ for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
+ DumpStream(*s);
+ }
+ }
+
+ void Run(const util::stream::ChainPosition &position) {
+ NGramStream stream(position);
+ DumpStream(stream);
+ }
+
+ private:
+ void DumpStream(NGramStream &stream) {
+ for (; stream; ++stream) {
+ PrintPayload<V>(to_, stream->Value());
+ for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) {
+ to_ << ' ' << vocab_.Lookup(*w) << '=' << *w;
+ }
+ to_ << '\n';
+ }
+ }
+
+ const VocabReconstitute &vocab_;
+ std::ostream &to_;
+};
+
+class PrintARPA {
+ public:
+ // header_info may be NULL to disable the header
+ explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ const VocabReconstitute &vocab_;
+ int out_fd_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_PRINT__
diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh
new file mode 100644
index 00000000..9989389b
--- /dev/null
+++ b/klm/lm/builder/sort.hh
@@ -0,0 +1,103 @@
+#ifndef LM_BUILDER_SORT__
+#define LM_BUILDER_SORT__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/ngram.hh"
+#include "lm/word_index.hh"
+#include "util/stream/sort.hh"
+
+#include "util/stream/timer.hh"
+
+#include <functional>
+#include <string>
+
+namespace lm {
+namespace builder {
+
+template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> {
+ public:
+ explicit Comparator(std::size_t order) : order_(order) {}
+
+ inline bool operator()(const void *lhs, const void *rhs) const {
+ return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs));
+ }
+
+ std::size_t Order() const { return order_; }
+
+ protected:
+ std::size_t order_;
+};
+
+class SuffixOrder : public Comparator<SuffixOrder> {
+ public:
+ explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = order_ - 1; i != 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[0] < rhs[0];
+ }
+
+ static const unsigned kMatchOffset = 1;
+};
+
+class ContextOrder : public Comparator<ContextOrder> {
+ public:
+ explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (int i = order_ - 2; i >= 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[order_ - 1] < rhs[order_ - 1];
+ }
+};
+
+class PrefixOrder : public Comparator<PrefixOrder> {
+ public:
+ explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = 0; i < order_; ++i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return false;
+ }
+
+ static const unsigned kMatchOffset = 0;
+};
+
+// Sum counts for the same n-gram.
+struct AddCombiner {
+ bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
+ NGram first(first_void, compare.Order());
+ // There isn't a const version of NGram.
+ NGram second(const_cast<void*>(second_void), compare.Order());
+ if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
+ first.Count() += second.Count();
+ return true;
+ }
+};
+
+// The combiner is only used on a single chain, so I didn't bother to allow
+// that template.
+template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > {
+ private:
+ typedef util::stream::Sort<Compare> S;
+ typedef FixedArray<S> P;
+
+ public:
+ void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
+ new (P::end()) S(chain, config, compare);
+ P::Constructed();
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_SORT__
diff --git a/klm/lm/filter/arpa_io.cc b/klm/lm/filter/arpa_io.cc
new file mode 100644
index 00000000..caf8df95
--- /dev/null
+++ b/klm/lm/filter/arpa_io.cc
@@ -0,0 +1,122 @@
+#include "lm/filter/arpa_io.hh"
+#include "util/file_piece.hh"
+
+#include <iostream>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include <ctype.h>
+#include <errno.h>
+#include <string.h>
+
+namespace lm {
+
+ARPAInputException::ARPAInputException(const StringPiece &message) throw() : what_("Error: ") {
+ what_.append(message.data(), message.size());
+}
+
+ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() {
+ what_ = "Error: ";
+ what_.append(message.data(), message.size());
+ what_ += " in line '";
+ what_.append(line.data(), line.size());
+ what_ += "'.";
+}
+
+ARPAOutputException::ARPAOutputException(const char *message, const std::string &file_name) throw()
+ : what_(std::string(message) + " file " + file_name), file_name_(file_name) {
+ if (errno) {
+ char buf[1024];
+ buf[0] = 0;
+#if (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600) && ! _GNU_SOURCE
+ const char *add = buf;
+ if (!strerror_r(errno, buf, 1024)) {
+#else
+ const char *add = strerror_r(errno, buf, 1024);
+ if (add) {
+#endif
+ what_ += " :";
+ what_ += add;
+ }
+ }
+}
+
+// Seeking is the responsibility of the caller.
+void WriteCounts(std::ostream &out, const std::vector<size_t> &number) {
+ out << "\n\\data\\\n";
+ for (unsigned int i = 0; i < number.size(); ++i) {
+ out << "ngram " << i+1 << "=" << number[i] << '\n';
+ }
+ out << '\n';
+}
+
+size_t SizeNeededForCounts(const std::vector<size_t> &number) {
+ std::ostringstream buf;
+ WriteCounts(buf, number);
+ return buf.tellp();
+}
+
+bool IsEntirelyWhiteSpace(const StringPiece &line) {
+ for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
+ if (!isspace(line.data()[i])) return false;
+ }
+ return true;
+}
+
+ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) : file_name_(name), buffer_(new char[buffer_size]) {
+ try {
+ file_.exceptions(std::ostream::eofbit | std::ostream::failbit | std::ostream::badbit);
+ if (!file_.rdbuf()->pubsetbuf(buffer_.get(), buffer_size)) {
+ std::cerr << "Warning: could not enlarge buffer for " << name << std::endl;
+ buffer_.reset();
+ }
+ file_.open(name, std::ios::out | std::ios::binary);
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Opening", file_name_);
+ }
+}
+
+void ARPAOutput::ReserveForCounts(std::streampos reserve) {
+ try {
+ for (std::streampos i = 0; i < reserve; i += std::streampos(1)) {
+ file_ << '\n';
+ }
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing blanks to reserve space for counts to ", file_name_);
+ }
+}
+
+void ARPAOutput::BeginLength(unsigned int length) {
+ fast_counter_ = 0;
+ try {
+ file_ << '\\' << length << "-grams:" << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing n-gram header to ", file_name_);
+ }
+}
+
+void ARPAOutput::EndLength(unsigned int length) {
+ try {
+ file_ << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing blank at end of count list to ", file_name_);
+ }
+ if (length > counts_.size()) {
+ counts_.resize(length);
+ }
+ counts_[length - 1] = fast_counter_;
+}
+
+void ARPAOutput::Finish() {
+ try {
+ file_ << "\\end\\\n";
+ file_.seekp(0);
+ WriteCounts(file_, counts_);
+ file_ << std::flush;
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Finishing including writing counts at beginning to ", file_name_);
+ }
+}
+
+} // namespace lm
diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh
new file mode 100644
index 00000000..90f48447
--- /dev/null
+++ b/klm/lm/filter/arpa_io.hh
@@ -0,0 +1,122 @@
+#ifndef LM_FILTER_ARPA_IO__
+#define LM_FILTER_ARPA_IO__
+/* Input and output for ARPA format language model files.
+ */
+#include "lm/read_arpa.hh"
+#include "util/exception.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/scoped_array.hpp>
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include <err.h>
+#include <string.h>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+
+class ARPAInputException : public util::Exception {
+ public:
+ explicit ARPAInputException(const StringPiece &message) throw();
+ explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw();
+ virtual ~ARPAInputException() throw() {}
+
+ const char *what() const throw() { return what_.c_str(); }
+
+ private:
+ std::string what_;
+};
+
+class ARPAOutputException : public std::exception {
+ public:
+ ARPAOutputException(const char *prefix, const std::string &file_name) throw();
+ virtual ~ARPAOutputException() throw() {}
+
+ const char *what() const throw() { return what_.c_str(); }
+
+ const std::string &File() const throw() { return file_name_; }
+
+ private:
+ std::string what_;
+ const std::string file_name_;
+};
+
+// Handling for the counts of n-grams at the beginning of ARPA files.
+size_t SizeNeededForCounts(const std::vector<size_t> &number);
+
+/* Writes an ARPA file. This has to be seekable so the counts can be written
+ * at the end. Hence, I just have it own a std::fstream instead of accepting
+ * a separately held std::ostream.
+ */
+class ARPAOutput : boost::noncopyable {
+ public:
+ explicit ARPAOutput(const char *name, size_t buffer_size = 65536);
+
+ void ReserveForCounts(std::streampos reserve);
+
+ void BeginLength(unsigned int length);
+
+ void AddNGram(const StringPiece &line) {
+ try {
+ file_ << line << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing an n-gram", file_name_);
+ }
+ ++fast_counter_;
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ void EndLength(unsigned int length);
+
+ void Finish();
+
+ private:
+ const std::string file_name_;
+ boost::scoped_array<char> buffer_;
+ std::fstream file_;
+ size_t fast_counter_;
+ std::vector<size_t> counts_;
+};
+
+
+template <class Output> void ReadNGrams(util::FilePiece &in, unsigned int length, size_t number, Output &out) {
+ ReadNGramHeader(in, length);
+ out.BeginLength(length);
+ for (size_t i = 0; i < number; ++i) {
+ StringPiece line = in.ReadLine();
+ util::TokenIter<util::SingleCharacter> tabber(line, '\t');
+ if (!tabber) throw ARPAInputException("blank line", line);
+ if (!++tabber) throw ARPAInputException("no tab", line);
+
+ out.AddNGram(*tabber, line);
+ }
+ out.EndLength(length);
+}
+
+template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) {
+ std::vector<size_t> number;
+ ReadARPACounts(in_lm, number);
+ out.ReserveForCounts(SizeNeededForCounts(number));
+ for (unsigned int i = 0; i < number.size(); ++i) {
+ ReadNGrams(in_lm, i + 1, number[i], out);
+ }
+ ReadEnd(in_lm);
+ out.Finish();
+}
+
+} // namespace lm
+
+#endif // LM_FILTER_ARPA_IO__
diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh
new file mode 100644
index 00000000..97c0fa25
--- /dev/null
+++ b/klm/lm/filter/count_io.hh
@@ -0,0 +1,91 @@
+#ifndef LM_FILTER_COUNT_IO__
+#define LM_FILTER_COUNT_IO__
+
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include <err.h>
+
+#include "util/file_piece.hh"
+
+namespace lm {
+
+class CountOutput : boost::noncopyable {
+ public:
+ explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
+
+ void AddNGram(const StringPiece &line) {
+ if (!(file_ << line << '\n')) {
+ err(3, "Writing counts file failed");
+ }
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ private:
+ std::fstream file_;
+};
+
+class CountBatch {
+ public:
+ explicit CountBatch(std::streamsize initial_read)
+ : initial_read_(initial_read) {
+ buffer_.reserve(initial_read);
+ }
+
+ void Read(std::istream &in) {
+ buffer_.resize(initial_read_);
+ in.read(&*buffer_.begin(), initial_read_);
+ buffer_.resize(in.gcount());
+ char got;
+ while (in.get(got) && got != '\n')
+ buffer_.push_back(got);
+ }
+
+ template <class Output> void Send(Output &out) {
+ for (util::TokenIter<util::SingleCharacter> line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) {
+ util::TokenIter<util::SingleCharacter> tabber(*line, '\t');
+ if (!tabber) {
+ std::cerr << "Warning: empty n-gram count line being removed\n";
+ continue;
+ }
+ util::TokenIter<util::SingleCharacter, true> words(*tabber, ' ');
+ if (!words) {
+ std::cerr << "Line has a tab but no words.\n";
+ continue;
+ }
+ out.AddNGram(words, util::TokenIter<util::SingleCharacter, true>::end(), *line);
+ }
+ }
+
+ private:
+ std::streamsize initial_read_;
+
+ // This could have been a std::string but that's less happy with raw writes.
+ std::vector<char> buffer_;
+};
+
+template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) {
+ try {
+ while (true) {
+ StringPiece line = in_file.ReadLine();
+ util::TokenIter<util::SingleCharacter> tabber(line, '\t');
+ if (!tabber) {
+ std::cerr << "Warning: empty n-gram count line being removed\n";
+ continue;
+ }
+ out.AddNGram(*tabber, line);
+ }
+ } catch (const util::EndOfFileException &e) {}
+}
+
+} // namespace lm
+
+#endif // LM_FILTER_COUNT_IO__
diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh
new file mode 100644
index 00000000..7f945b0d
--- /dev/null
+++ b/klm/lm/filter/format.hh
@@ -0,0 +1,250 @@
+#ifndef LM_FILTER_FORMAT_H__
+#define LM_FITLER_FORMAT_H__
+
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/count_io.hh"
+
+#include <boost/lexical_cast.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <iosfwd>
+
+namespace lm {
+
+template <class Single> class MultipleOutput {
+ private:
+ typedef boost::ptr_vector<Single> Singles;
+ typedef typename Singles::iterator SinglesIterator;
+
+ public:
+ MultipleOutput(const char *prefix, size_t number) {
+ files_.reserve(number);
+ std::string tmp;
+ for (unsigned int i = 0; i < number; ++i) {
+ tmp = prefix;
+ tmp += boost::lexical_cast<std::string>(i);
+ files_.push_back(new Single(tmp.c_str()));
+ }
+ }
+
+ void AddNGram(const StringPiece &line) {
+ for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+ i->AddNGram(line);
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+ i->AddNGram(begin, end, line);
+ }
+
+ void SingleAddNGram(size_t offset, const StringPiece &line) {
+ files_[offset].AddNGram(line);
+ }
+
+ template <class Iterator> void SingleAddNGram(size_t offset, const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ files_[offset].AddNGram(begin, end, line);
+ }
+
+ protected:
+ Singles files_;
+};
+
+class MultipleARPAOutput : public MultipleOutput<ARPAOutput> {
+ public:
+ MultipleARPAOutput(const char *prefix, size_t number) : MultipleOutput<ARPAOutput>(prefix, number) {}
+
+ void ReserveForCounts(std::streampos reserve) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->ReserveForCounts(reserve);
+ }
+
+ void BeginLength(unsigned int length) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->BeginLength(length);
+ }
+
+ void EndLength(unsigned int length) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->EndLength(length);
+ }
+
+ void Finish() {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->Finish();
+ }
+};
+
+template <class Filter, class Output> class DispatchInput {
+ public:
+ DispatchInput(Filter &filter, Output &output) : filter_(filter), output_(output) {}
+
+/* template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ filter_.AddNGram(begin, end, line, output_);
+ }*/
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ filter_.AddNGram(ngram, line, output_);
+ }
+
+ protected:
+ Filter &filter_;
+ Output &output_;
+};
+
+template <class Filter, class Output> class DispatchARPAInput : public DispatchInput<Filter, Output> {
+ private:
+ typedef DispatchInput<Filter, Output> B;
+
+ public:
+ DispatchARPAInput(Filter &filter, Output &output) : B(filter, output) {}
+
+ void ReserveForCounts(std::streampos reserve) { B::output_.ReserveForCounts(reserve); }
+ void BeginLength(unsigned int length) { B::output_.BeginLength(length); }
+
+ void EndLength(unsigned int length) {
+ B::filter_.Flush();
+ B::output_.EndLength(length);
+ }
+ void Finish() { B::output_.Finish(); }
+};
+
+struct ARPAFormat {
+ typedef ARPAOutput Output;
+ typedef MultipleARPAOutput Multiple;
+ static void Copy(util::FilePiece &in, Output &out) {
+ ReadARPA(in, out);
+ }
+ template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
+ DispatchARPAInput<Filter, Out> dispatcher(filter, output);
+ ReadARPA(in, dispatcher);
+ }
+};
+
+struct CountFormat {
+ typedef CountOutput Output;
+ typedef MultipleOutput<Output> Multiple;
+ static void Copy(util::FilePiece &in, Output &out) {
+ ReadCount(in, out);
+ }
+ template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
+ DispatchInput<Filter, Out> dispatcher(filter, output);
+ ReadCount(in, dispatcher);
+ }
+};
+
+/* For multithreading, the buffer classes hold batches of filter inputs and
+ * outputs in memory. The strings get reused a lot, so keep them around
+ * instead of clearing each time.
+ */
+class InputBuffer {
+ public:
+ InputBuffer() : actual_(0) {}
+
+ void Reserve(size_t size) { lines_.reserve(size); }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ if (lines_.size() == actual_) lines_.resize(lines_.size() + 1);
+ // TODO avoid this copy.
+ std::string &copied = lines_[actual_].line;
+ copied.assign(line.data(), line.size());
+ lines_[actual_].ngram.set(copied.data() + (ngram.data() - line.data()), ngram.size());
+ ++actual_;
+ }
+
+ template <class Filter, class Output> void CallFilter(Filter &filter, Output &output) const {
+ for (std::vector<Line>::const_iterator i = lines_.begin(); i != lines_.begin() + actual_; ++i) {
+ filter.AddNGram(i->ngram, i->line, output);
+ }
+ }
+
+ void Clear() { actual_ = 0; }
+ bool Empty() { return actual_ == 0; }
+ size_t Size() { return actual_; }
+
+ private:
+ struct Line {
+ std::string line;
+ StringPiece ngram;
+ };
+
+ size_t actual_;
+
+ std::vector<Line> lines_;
+};
+
+class BinaryOutputBuffer {
+ public:
+ BinaryOutputBuffer() {}
+
+ void Reserve(size_t size) {
+ lines_.reserve(size);
+ }
+
+ void AddNGram(const StringPiece &line) {
+ lines_.push_back(line);
+ }
+
+ template <class Output> void Flush(Output &output) {
+ for (std::vector<StringPiece>::const_iterator i = lines_.begin(); i != lines_.end(); ++i) {
+ output.AddNGram(*i);
+ }
+ lines_.clear();
+ }
+
+ private:
+ std::vector<StringPiece> lines_;
+};
+
+class MultipleOutputBuffer {
+ public:
+ MultipleOutputBuffer() : last_(NULL) {}
+
+ void Reserve(size_t size) {
+ annotated_.reserve(size);
+ }
+
+ void AddNGram(const StringPiece &line) {
+ annotated_.resize(annotated_.size() + 1);
+ annotated_.back().line = line;
+ }
+
+ void SingleAddNGram(size_t offset, const StringPiece &line) {
+ if ((line.data() == last_.data()) && (line.length() == last_.length())) {
+ annotated_.back().systems.push_back(offset);
+ } else {
+ annotated_.resize(annotated_.size() + 1);
+ annotated_.back().systems.push_back(offset);
+ annotated_.back().line = line;
+ last_ = line;
+ }
+ }
+
+ template <class Output> void Flush(Output &output) {
+ for (std::vector<Annotated>::const_iterator i = annotated_.begin(); i != annotated_.end(); ++i) {
+ if (i->systems.empty()) {
+ output.AddNGram(i->line);
+ } else {
+ for (std::vector<size_t>::const_iterator j = i->systems.begin(); j != i->systems.end(); ++j) {
+ output.SingleAddNGram(*j, i->line);
+ }
+ }
+ }
+ annotated_.clear();
+ }
+
+ private:
+ struct Annotated {
+ // If this is empty, send to all systems.
+ // A filter should never send to all systems and send to a single one.
+ std::vector<size_t> systems;
+ StringPiece line;
+ };
+
+ StringPiece last_;
+
+ std::vector<Annotated> annotated_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_FORMAT_H__
diff --git a/klm/lm/filter/main.cc b/klm/lm/filter/main.cc
new file mode 100644
index 00000000..c42243e2
--- /dev/null
+++ b/klm/lm/filter/main.cc
@@ -0,0 +1,249 @@
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/format.hh"
+#include "lm/filter/phrase.hh"
+#ifndef NTHREAD
+#include "lm/filter/thread.hh"
+#endif
+#include "lm/filter/vocab.hh"
+#include "lm/filter/wrapper.hh"
+#include "util/file_piece.hh"
+
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+
+namespace lm {
+namespace {
+
+void DisplayHelp(const char *name) {
+ std::cerr
+ << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n"
+ "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n"
+ " parser.\n"
+ "single mode treats the entire input as a single sentence.\n"
+ "multiple mode filters to multiple sentences in parallel. Each sentence is on\n"
+ " a separate line. A separate file is created for each file by appending the\n"
+ " 0-indexed line number to the output file name.\n"
+ "union mode produces one filtered model that is the union of models created by\n"
+ " multiple mode.\n\n"
+ "context means only the context (all but last word) has to pass the filter, but\n"
+ " the entire n-gram is output.\n\n"
+ "phrase means that the vocabulary is actually tab-delimited phrases and that the\n"
+ " phrases can generate the n-gram when assembled in arbitrary order and\n"
+ " clipped. Currently works with multiple or union mode.\n\n"
+ "The file format is set by [raw|arpa] with default arpa:\n"
+ "raw means space-separated tokens, optionally followed by a tab and arbitrary\n"
+ " text. This is useful for ngram count files.\n"
+ "arpa means the ARPA file format for n-gram language models.\n\n"
+#ifndef NTHREAD
+ "threads:m sets m threads (default: conccurrency detected by boost)\n"
+ "batch_size:m sets the batch size for threading. Expect memory usage from this\n"
+ " of 2*threads*batch_size n-grams.\n\n"
+#else
+ "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n"
+ " threading, compile without this flag against Boost >=1.42.0.\n\n"
+#endif
+ "There are two inputs: vocabulary and model. Either may be given as a file\n"
+ " while the other is on stdin. Specify the type given as a file using\n"
+ " vocab: or model: before the file name. \n\n"
+ "For ARPA format, the output must be seekable. For raw format, it can be a\n"
+ " stream i.e. /dev/stdout\n";
+}
+
+typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION} FilterMode;
+typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format;
+
+struct Config {
+ Config() :
+#ifndef NTHREAD
+ batch_size(25000),
+ threads(boost::thread::hardware_concurrency()),
+#endif
+ phrase(false),
+ context(false),
+ format(FORMAT_ARPA)
+ {
+#ifndef NTHREAD
+ if (!threads) threads = 1;
+#endif
+ }
+
+#ifndef NTHREAD
+ size_t batch_size;
+ size_t threads;
+#endif
+ bool phrase;
+ bool context;
+ FilterMode mode;
+ Format format;
+};
+
+template <class Format, class Filter, class OutputBuffer, class Output> void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) {
+#ifndef NTHREAD
+ if (config.threads == 1) {
+#endif
+ Format::RunFilter(in_lm, filter, output);
+#ifndef NTHREAD
+ } else {
+ typedef Controller<Filter, OutputBuffer, Output> Threaded;
+ Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output);
+ Format::RunFilter(in_lm, threading, output);
+ }
+#endif
+}
+
+template <class Format, class Filter, class OutputBuffer, class Output> void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) {
+ if (config.context) {
+ ContextFilter<Filter> context_filter(filter);
+ RunThreadedFilter<Format, ContextFilter<Filter>, OutputBuffer, Output>(config, in_lm, context_filter, output);
+ } else {
+ RunThreadedFilter<Format, Filter, OutputBuffer, Output>(config, in_lm, filter, output);
+ }
+}
+
+template <class Format, class Binary> void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) {
+ typedef BinaryFilter<Binary> Filter;
+ RunContextFilter<Format, Filter, BinaryOutputBuffer, typename Format::Output>(config, in_lm, Filter(binary), out);
+}
+
+template <class Format> void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) {
+ if (config.mode == MODE_MULTIPLE) {
+ if (config.phrase) {
+ typedef phrase::Multiple Filter;
+ phrase::Substrings substrings;
+ typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings));
+ RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(substrings), out);
+ } else {
+ typedef vocab::Multiple Filter;
+ boost::unordered_map<std::string, std::vector<unsigned int> > words;
+ typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words));
+ RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(words), out);
+ }
+ return;
+ }
+
+ typename Format::Output out(out_name);
+
+ if (config.mode == MODE_COPY) {
+ Format::Copy(in_lm, out);
+ return;
+ }
+
+ if (config.mode == MODE_SINGLE) {
+ vocab::Single::Words words;
+ vocab::ReadSingle(in_vocab, words);
+ DispatchBinaryFilter<Format, vocab::Single>(config, in_lm, vocab::Single(words), out);
+ return;
+ }
+
+ if (config.mode == MODE_UNION) {
+ if (config.phrase) {
+ phrase::Substrings substrings;
+ phrase::ReadMultiple(in_vocab, substrings);
+ DispatchBinaryFilter<Format, phrase::Union>(config, in_lm, phrase::Union(substrings), out);
+ } else {
+ vocab::Union::Words words;
+ vocab::ReadMultiple(in_vocab, words);
+ DispatchBinaryFilter<Format, vocab::Union>(config, in_lm, vocab::Union(words), out);
+ }
+ return;
+ }
+}
+
+} // namespace
+} // namespace lm
+
+int main(int argc, char *argv[]) {
+ if (argc < 4) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+
+ // I used to have boost::program_options, but some users didn't want to compile boost.
+ lm::Config config;
+ boost::optional<lm::FilterMode> mode;
+ for (int i = 1; i < argc - 2; ++i) {
+ const char *str = argv[i];
+ if (!std::strcmp(str, "copy")) {
+ mode = lm::MODE_COPY;
+ } else if (!std::strcmp(str, "single")) {
+ mode = lm::MODE_SINGLE;
+ } else if (!std::strcmp(str, "multiple")) {
+ mode = lm::MODE_MULTIPLE;
+ } else if (!std::strcmp(str, "union")) {
+ mode = lm::MODE_UNION;
+ } else if (!std::strcmp(str, "phrase")) {
+ config.phrase = true;
+ } else if (!std::strcmp(str, "context")) {
+ config.context = true;
+ } else if (!std::strcmp(str, "arpa")) {
+ config.format = lm::FORMAT_ARPA;
+ } else if (!std::strcmp(str, "raw")) {
+ config.format = lm::FORMAT_COUNT;
+#ifndef NTHREAD
+ } else if (!std::strncmp(str, "threads:", 8)) {
+ config.threads = boost::lexical_cast<size_t>(str + 8);
+ if (!config.threads) {
+ std::cerr << "Specify at least one thread." << std::endl;
+ return 1;
+ }
+ } else if (!std::strncmp(str, "batch_size:", 11)) {
+ config.batch_size = boost::lexical_cast<size_t>(str + 11);
+ if (config.batch_size < 5000) {
+ std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
+ if (!config.batch_size) return 1;
+ }
+#endif
+ } else {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+ }
+
+ if (!mode) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+ config.mode = *mode;
+
+ if (config.phrase && config.mode != lm::MODE_UNION && mode != lm::MODE_MULTIPLE) {
+ std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
+ return 1;
+ }
+
+ bool cmd_is_model = true;
+ const char *cmd_input = argv[argc - 2];
+ if (!strncmp(cmd_input, "vocab:", 6)) {
+ cmd_is_model = false;
+ cmd_input += 6;
+ } else if (!strncmp(cmd_input, "model:", 6)) {
+ cmd_input += 6;
+ } else if (strchr(cmd_input, ':')) {
+ errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input);
+ } else {
+ std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
+ }
+ std::ifstream cmd_file;
+ std::istream *vocab;
+ if (cmd_is_model) {
+ vocab = &std::cin;
+ } else {
+ cmd_file.open(cmd_input, std::ios::in);
+ if (!cmd_file) {
+ err(2, "Could not open input file %s", cmd_input);
+ }
+ vocab = &cmd_file;
+ }
+
+ util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
+
+ if (config.format == lm::FORMAT_ARPA) {
+ lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
+ } else if (config.format == lm::FORMAT_COUNT) {
+ lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ }
+ return 0;
+}
diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc
new file mode 100644
index 00000000..1bef2a3f
--- /dev/null
+++ b/klm/lm/filter/phrase.cc
@@ -0,0 +1,281 @@
+#include "lm/filter/phrase.hh"
+
+#include "lm/filter/format.hh"
+
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include <ctype.h>
+
+namespace lm {
+namespace phrase {
+
+unsigned int ReadMultiple(std::istream &in, Substrings &out) {
+ bool sentence_content = false;
+ unsigned int sentence_id = 0;
+ std::vector<Hash> phrase;
+ std::string word;
+ while (in) {
+ char c;
+ // Gather a word.
+ while (!isspace(c = in.get()) && in) word += c;
+ // Treat EOF like a newline.
+ if (!in) c = '\n';
+ // Add the word to the phrase.
+ if (!word.empty()) {
+ phrase.push_back(util::MurmurHashNative(word.data(), word.size()));
+ word.clear();
+ }
+ if (c == ' ') continue;
+ // It's more than just a space. Close out the phrase.
+ if (!phrase.empty()) {
+ sentence_content = true;
+ out.AddPhrase(sentence_id, phrase.begin(), phrase.end());
+ phrase.clear();
+ }
+ if (c == '\t' || c == '\v') continue;
+ // It's more than a space or tab: a newline.
+ if (sentence_content) {
+ ++sentence_id;
+ sentence_content = false;
+ }
+ }
+ if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit);
+ return sentence_id + sentence_content;
+}
+
+namespace detail { const StringPiece kEndSentence("</s>"); }
+
+namespace {
+
+typedef unsigned int Sentence;
+typedef std::vector<Sentence> Sentences;
+
+class Vertex;
+
+class Arc {
+ public:
+ Arc() {}
+
+ // For arcs from one vertex to another.
+ void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) {
+ Set(to, intersect);
+ from_ = &from;
+ }
+
+ /* For arcs from before the n-gram begins to somewhere in the n-gram (right
+ * aligned). These have no from_ vertex; it implictly matches every
+ * sentence. This also handles when the n-gram is a substring of a phrase.
+ */
+ void SetRight(Vertex &to, const Sentences &complete) {
+ Set(to, complete);
+ from_ = NULL;
+ }
+
+ Sentence Current() const {
+ return *current_;
+ }
+
+ bool Empty() const {
+ return current_ == last_;
+ }
+
+ /* When this function returns:
+ * If Empty() then there's nothing left from this intersection.
+ *
+ * If Current() == to then to is part of the intersection.
+ *
+ * Otherwise, Current() > to. In this case, to is not part of the
+ * intersection and neither is anything < Current(). To determine if
+ * any value >= Current() is in the intersection, call LowerBound again
+ * with the value.
+ */
+ void LowerBound(const Sentence to);
+
+ private:
+ void Set(Vertex &to, const Sentences &sentences);
+
+ const Sentence *current_;
+ const Sentence *last_;
+ Vertex *from_;
+};
+
+struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> {
+ bool operator()(const Arc *first, const Arc *second) const {
+ return first->Current() > second->Current();
+ }
+};
+
+class Vertex {
+ public:
+ Vertex() : current_(0) {}
+
+ Sentence Current() const {
+ return current_;
+ }
+
+ bool Empty() const {
+ return incoming_.empty();
+ }
+
+ void LowerBound(const Sentence to);
+
+ private:
+ friend class Arc;
+
+ void AddIncoming(Arc *arc) {
+ if (!arc->Empty()) incoming_.push(arc);
+ }
+
+ unsigned int current_;
+ std::priority_queue<Arc*, std::vector<Arc*>, ArcGreater> incoming_;
+};
+
+void Arc::LowerBound(const Sentence to) {
+ current_ = std::lower_bound(current_, last_, to);
+ // If *current_ > to, don't advance from_. The intervening values of
+ // from_ may be useful for another one of its outgoing arcs.
+ if (!from_ || Empty() || (Current() > to)) return;
+ assert(Current() == to);
+ from_->LowerBound(to);
+ if (from_->Empty()) {
+ current_ = last_;
+ return;
+ }
+ assert(from_->Current() >= to);
+ if (from_->Current() > to) {
+ current_ = std::lower_bound(current_ + 1, last_, from_->Current());
+ }
+}
+
+void Arc::Set(Vertex &to, const Sentences &sentences) {
+ current_ = &*sentences.begin();
+ last_ = &*sentences.end();
+ to.AddIncoming(this);
+}
+
+void Vertex::LowerBound(const Sentence to) {
+ if (Empty()) return;
+ // Union lower bound.
+ while (true) {
+ Arc *top = incoming_.top();
+ if (top->Current() > to) {
+ current_ = top->Current();
+ return;
+ }
+ // If top->Current() == to, we still need to verify that's an actual
+ // element and not just a bound.
+ incoming_.pop();
+ top->LowerBound(to);
+ if (!top->Empty()) {
+ incoming_.push(top);
+ if (top->Current() == to) {
+ current_ = to;
+ return;
+ }
+ } else if (Empty()) {
+ return;
+ }
+ }
+}
+
+void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) {
+ assert(!hashes.empty());
+
+ const Hash *const first_word = &*hashes.begin();
+ const Hash *const last_word = &*hashes.end() - 1;
+
+ Hash hash = 0;
+ const Sentences *found;
+ // Phrases starting at or before the first word in the n-gram.
+ {
+ Vertex *vertex = vertices;
+ for (const Hash *word = first_word; ; ++word, ++vertex) {
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word);
+ // Now hash is [hashes.begin(), word].
+ if (word == last_word) {
+ if (phrase.FindSubstring(hash, found))
+ (free_arc++)->SetRight(*vertex, *found);
+ break;
+ }
+ if (!phrase.FindRight(hash, found)) break;
+ (free_arc++)->SetRight(*vertex, *found);
+ }
+ }
+
+ // Phrases starting at the second or later word in the n-gram.
+ Vertex *vertex_from = vertices;
+ for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) {
+ hash = 0;
+ Vertex *vertex_to = vertex_from + 1;
+ for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) {
+ // Notice that word_to and vertex_to have the same index.
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to);
+ // Now hash covers [word_from, word_to].
+ if (word_to == last_word) {
+ if (phrase.FindLeft(hash, found))
+ (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
+ break;
+ }
+ if (!phrase.FindPhrase(hash, found)) break;
+ (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
+ }
+ }
+}
+
+} // namespace
+
+namespace detail {
+
+} // namespace detail
+
+bool Union::Evaluate() {
+ assert(!hashes_.empty());
+ // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
+ Vertex vertices[hashes_.size()];
+ // One for every substring.
+ Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
+ BuildGraph(substrings_, hashes_, vertices, arcs);
+ Vertex &last_vertex = vertices[hashes_.size() - 1];
+
+ unsigned int lower = 0;
+ while (true) {
+ last_vertex.LowerBound(lower);
+ if (last_vertex.Empty()) return false;
+ if (last_vertex.Current() == lower) return true;
+ lower = last_vertex.Current();
+ }
+}
+
+template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) {
+ assert(!hashes_.empty());
+ // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
+ Vertex vertices[hashes_.size()];
+ // One for every substring.
+ Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
+ BuildGraph(substrings_, hashes_, vertices, arcs);
+ Vertex &last_vertex = vertices[hashes_.size() - 1];
+
+ unsigned int lower = 0;
+ while (true) {
+ last_vertex.LowerBound(lower);
+ if (last_vertex.Empty()) return;
+ if (last_vertex.Current() == lower) {
+ output.SingleAddNGram(lower, line);
+ ++lower;
+ } else {
+ lower = last_vertex.Current();
+ }
+ }
+}
+
+template void Multiple::Evaluate<CountFormat::Multiple>(const StringPiece &line, CountFormat::Multiple &output);
+template void Multiple::Evaluate<ARPAFormat::Multiple>(const StringPiece &line, ARPAFormat::Multiple &output);
+template void Multiple::Evaluate<MultipleOutputBuffer>(const StringPiece &line, MultipleOutputBuffer &output);
+
+} // namespace phrase
+} // namespace lm
diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh
new file mode 100644
index 00000000..07479dea
--- /dev/null
+++ b/klm/lm/filter/phrase.hh
@@ -0,0 +1,153 @@
+#ifndef LM_FILTER_PHRASE_H__
+#define LM_FILTER_PHRASE_H__
+
+#include "util/murmur_hash.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_map.hpp>
+
+#include <iosfwd>
+#include <vector>
+
+#define LM_FILTER_PHRASE_METHOD(caps, lower) \
+bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\
+ Table::const_iterator i(table_.find(key));\
+ if (i==table_.end()) return false; \
+ out = &i->second.lower; \
+ return true; \
+}
+
+namespace lm {
+namespace phrase {
+
+typedef uint64_t Hash;
+
+class Substrings {
+ private:
+ /* This is the value in a hash table where the key is a string. It indicates
+ * four sets of sentences:
+ * substring is sentences with a phrase containing the key as a substring.
+ * left is sentencess with a phrase that begins with the key (left aligned).
+ * right is sentences with a phrase that ends with the key (right aligned).
+ * phrase is sentences where the key is a phrase.
+ * Each set is encoded as a vector of sentence ids in increasing order.
+ */
+ struct SentenceRelation {
+ std::vector<unsigned int> substring, left, right, phrase;
+ };
+ /* Most of the CPU is hash table lookups, so let's not complicate it with
+ * vector equality comparisons. If a collision happens, the SentenceRelation
+ * structure will contain the union of sentence ids over the colliding strings.
+ * In that case, the filter will be slightly more permissive.
+ * The key here is the same as boost's hash of std::vector<std::string>.
+ */
+ typedef boost::unordered_map<Hash, SentenceRelation> Table;
+
+ public:
+ Substrings() {}
+
+ /* If the string isn't a substring of any phrase, return NULL. Otherwise,
+ * return a pointer to std::vector<unsigned int> listing sentences with
+ * matching phrases. This set may be empty for Left, Right, or Phrase.
+ * Example: const std::vector<unsigned int> *FindSubstring(Hash key)
+ */
+ LM_FILTER_PHRASE_METHOD(Substring, substring)
+ LM_FILTER_PHRASE_METHOD(Left, left)
+ LM_FILTER_PHRASE_METHOD(Right, right)
+ LM_FILTER_PHRASE_METHOD(Phrase, phrase)
+
+ // sentence_id must be non-decreasing. Iterators are over words in the phrase.
+ template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) {
+ // Iterate over all substrings.
+ for (Iterator start = begin; start != end; ++start) {
+ Hash hash = 0;
+ SentenceRelation *relation;
+ for (Iterator finish = start; finish != end; ++finish) {
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish);
+ // Now hash is of [start, finish].
+ relation = &table_[hash];
+ AppendSentence(relation->substring, sentence_id);
+ if (start == begin) AppendSentence(relation->left, sentence_id);
+ }
+ AppendSentence(relation->right, sentence_id);
+ if (start == begin) AppendSentence(relation->phrase, sentence_id);
+ }
+ }
+
+ private:
+ void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) {
+ if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id);
+ }
+
+ Table table_;
+};
+
+// Read a file with one sentence per line containing tab-delimited phrases of
+// space-separated words.
+unsigned int ReadMultiple(std::istream &in, Substrings &out);
+
+namespace detail {
+extern const StringPiece kEndSentence;
+
+template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) {
+ hashes.clear();
+ if (i == end) return;
+ // TODO: check strict phrase boundaries after <s> and before </s>. For now, just skip tags.
+ if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) {
+ ++i;
+ }
+ for (; i != end && (*i != kEndSentence); ++i) {
+ hashes.push_back(util::MurmurHashNative(i->data(), i->size()));
+ }
+}
+
+} // namespace detail
+
+class Union {
+ public:
+ explicit Union(const Substrings &substrings) : substrings_(substrings) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ detail::MakeHashes(begin, end, hashes_);
+ return hashes_.empty() || Evaluate();
+ }
+
+ private:
+ bool Evaluate();
+
+ std::vector<Hash> hashes_;
+
+ const Substrings &substrings_;
+};
+
+class Multiple {
+ public:
+ explicit Multiple(const Substrings &substrings) : substrings_(substrings) {}
+
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ detail::MakeHashes(begin, end, hashes_);
+ if (hashes_.empty()) {
+ output.AddNGram(line);
+ return;
+ }
+ Evaluate(line, output);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ template <class Output> void Evaluate(const StringPiece &line, Output &output);
+
+ std::vector<Hash> hashes_;
+
+ const Substrings &substrings_;
+};
+
+} // namespace phrase
+} // namespace lm
+#endif // LM_FILTER_PHRASE_H__
diff --git a/klm/lm/filter/thread.hh b/klm/lm/filter/thread.hh
new file mode 100644
index 00000000..e785b263
--- /dev/null
+++ b/klm/lm/filter/thread.hh
@@ -0,0 +1,167 @@
+#ifndef LM_FILTER_THREAD_H__
+#define LM_FILTER_THREAD_H__
+
+#include "util/thread_pool.hh"
+
+#include <boost/utility/in_place_factory.hpp>
+
+#include <deque>
+#include <stack>
+
+namespace lm {
+
+template <class OutputBuffer> class ThreadBatch {
+ public:
+ ThreadBatch() {}
+
+ void Reserve(size_t size) {
+ input_.Reserve(size);
+ output_.Reserve(size);
+ }
+
+ // File reading thread.
+ InputBuffer &Fill(uint64_t sequence) {
+ sequence_ = sequence;
+ // Why wait until now to clear instead of after output? free in the same
+ // thread as allocated.
+ input_.Clear();
+ return input_;
+ }
+
+ // Filter worker thread.
+ template <class Filter> void CallFilter(Filter &filter) {
+ input_.CallFilter(filter, output_);
+ }
+
+ uint64_t Sequence() const { return sequence_; }
+
+ // File writing thread.
+ template <class RealOutput> void Flush(RealOutput &output) {
+ output_.Flush(output);
+ }
+
+ private:
+ InputBuffer input_;
+ OutputBuffer output_;
+
+ uint64_t sequence_;
+};
+
+template <class Batch, class Filter> class FilterWorker {
+ public:
+ typedef Batch *Request;
+
+ FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {}
+
+ void operator()(Request request) {
+ request->CallFilter(filter_);
+ done_.Produce(request);
+ }
+
+ private:
+ Filter filter_;
+
+ util::PCQueue<Request> &done_;
+};
+
+// There should only be one OutputWorker.
+template <class Batch, class Output> class OutputWorker {
+ public:
+ typedef Batch *Request;
+
+ OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {}
+
+ void operator()(Request request) {
+ assert(request->Sequence() >= base_sequence_);
+ // Assemble the output in order.
+ uint64_t pos = request->Sequence() - base_sequence_;
+ if (pos >= ordering_.size()) {
+ ordering_.resize(pos + 1, NULL);
+ }
+ ordering_[pos] = request;
+ while (!ordering_.empty() && ordering_.front()) {
+ ordering_.front()->Flush(output_);
+ done_.Produce(ordering_.front());
+ ordering_.pop_front();
+ ++base_sequence_;
+ }
+ }
+
+ private:
+ Output &output_;
+
+ util::PCQueue<Request> &done_;
+
+ std::deque<Request> ordering_;
+
+ uint64_t base_sequence_;
+};
+
+template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
+ private:
+ typedef ThreadBatch<OutputBuffer> Batch;
+
+ public:
+ Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
+ : batch_size_(batch_size), queue_size_(queue),
+ batches_(queue),
+ to_read_(queue),
+ output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
+ filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
+ sequence_(0) {
+ for (size_t i = 0; i < queue; ++i) {
+ batches_[i].Reserve(batch_size);
+ local_read_.push(&batches_[i]);
+ }
+ NewInput();
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
+ input_->AddNGram(ngram, line, output);
+ if (input_->Size() == batch_size_) {
+ FlushInput();
+ NewInput();
+ }
+ }
+
+ void Flush() {
+ FlushInput();
+ while (local_read_.size() < queue_size_) {
+ MoveRead();
+ }
+ NewInput();
+ }
+
+ private:
+ void FlushInput() {
+ if (input_->Empty()) return;
+ filter_.Produce(local_read_.top());
+ local_read_.pop();
+ if (local_read_.empty()) MoveRead();
+ }
+
+ void NewInput() {
+ input_ = &local_read_.top()->Fill(sequence_++);
+ }
+
+ void MoveRead() {
+ local_read_.push(to_read_.Consume());
+ }
+
+ const size_t batch_size_;
+ const size_t queue_size_;
+
+ std::vector<Batch> batches_;
+
+ util::PCQueue<Batch*> to_read_;
+ std::stack<Batch*> local_read_;
+ util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
+ util::ThreadPool<FilterWorker<Batch, Filter> > filter_;
+
+ uint64_t sequence_;
+ InputBuffer *input_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_THREAD_H__
diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc
new file mode 100644
index 00000000..7ee4e84b
--- /dev/null
+++ b/klm/lm/filter/vocab.cc
@@ -0,0 +1,54 @@
+#include "lm/filter/vocab.hh"
+
+#include <istream>
+#include <iostream>
+
+#include <ctype.h>
+#include <err.h>
+
+namespace lm {
+namespace vocab {
+
+void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out) {
+ in.exceptions(std::istream::badbit);
+ std::string word;
+ while (in >> word) {
+ out.insert(word);
+ }
+}
+
+namespace {
+bool IsLineEnd(std::istream &in) {
+ int got;
+ do {
+ got = in.get();
+ if (!in) return true;
+ if (got == '\n') return true;
+ } while (isspace(got));
+ in.unget();
+ return false;
+}
+}// namespace
+
+// Read space separated words in enter separated lines. These lines can be
+// very long, so don't read an entire line at a time.
+unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out) {
+ in.exceptions(std::istream::badbit);
+ unsigned int sentence = 0;
+ bool used_id = false;
+ std::string word;
+ while (in >> word) {
+ used_id = true;
+ std::vector<unsigned int> &posting = out[word];
+ if (posting.empty() || (posting.back() != sentence))
+ posting.push_back(sentence);
+ if (IsLineEnd(in)) {
+ ++sentence;
+ used_id = false;
+ }
+ }
+ return sentence + used_id;
+}
+
+} // namespace vocab
+} // namespace lm
diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh
new file mode 100644
index 00000000..e2b6adff
--- /dev/null
+++ b/klm/lm/filter/vocab.hh
@@ -0,0 +1,132 @@
+#ifndef LM_FILTER_VOCAB_H__
+#define LM_FILTER_VOCAB_H__
+
+// Vocabulary-based filters for language models.
+
+#include "util/multi_intersection.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/range/iterator_range.hpp>
+#include <boost/unordered/unordered_map.hpp>
+#include <boost/unordered/unordered_set.hpp>
+
+#include <string>
+#include <vector>
+
+namespace lm {
+namespace vocab {
+
+void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out);
+
+// Read one sentence vocabulary per line. Return the number of sentences.
+unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out);
+
+/* Is this a special tag like <s> or <UNK>? This actually includes anything
+ * surrounded with < and >, which most tokenizers separate for real words, so
+ * this should not catch real words as it looks at a single token.
+ */
+inline bool IsTag(const StringPiece &value) {
+ // The parser should never give an empty string.
+ assert(!value.empty());
+ return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>');
+}
+
+class Single {
+ public:
+ typedef boost::unordered_set<std::string> Words;
+
+ explicit Single(const Words &vocab) : vocab_(vocab) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ for (Iterator i = begin; i != end; ++i) {
+ if (IsTag(*i)) continue;
+ if (FindStringPiece(vocab_, *i) == vocab_.end()) return false;
+ }
+ return true;
+ }
+
+ private:
+ const Words &vocab_;
+};
+
+class Union {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ explicit Union(const Words &vocabs) : vocabs_(vocabs) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ sets_.clear();
+
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return false;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ return (sets_.empty() || util::FirstIntersection(sets_));
+ }
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+class Multiple {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ Multiple(const Words &vocabs) : vocabs_(vocabs) {}
+
+ private:
+ // Callback from AllIntersection that does AddNGram.
+ template <class Output> class Callback {
+ public:
+ Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {}
+
+ void operator()(unsigned int index) {
+ out_.SingleAddNGram(index, line_);
+ }
+
+ private:
+ Output &out_;
+ const StringPiece &line_;
+ };
+
+ public:
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ sets_.clear();
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ if (sets_.empty()) {
+ output.AddNGram(line);
+ return;
+ }
+
+ Callback<Output> cb(output, line);
+ util::AllIntersection(sets_, cb);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+} // namespace vocab
+} // namespace lm
+
+#endif // LM_FILTER_VOCAB_H__
diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh
new file mode 100644
index 00000000..90b07a08
--- /dev/null
+++ b/klm/lm/filter/wrapper.hh
@@ -0,0 +1,58 @@
+#ifndef LM_FILTER_WRAPPER_H__
+#define LM_FILTER_WRAPPER_H__
+
+#include "util/string_piece.hh"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+namespace lm {
+
+// Provide a single-output filter with the same interface as a
+// multiple-output filter so clients code against one interface.
+template <class Binary> class BinaryFilter {
+ public:
+ // Binary modes are just references (and a set) and it makes the API cleaner to copy them.
+ explicit BinaryFilter(Binary binary) : binary_(binary) {}
+
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ if (binary_.PassNGram(begin, end))
+ output.AddNGram(line);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ Binary binary_;
+};
+
+// Wrap another filter to pay attention only to context words
+template <class FilterT> class ContextFilter {
+ public:
+ typedef FilterT Filter;
+
+ explicit ContextFilter(Filter &backend) : backend_(backend) {}
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ pieces_.clear();
+ // TODO: this copy could be avoided by a lookahead iterator.
+ std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_));
+ backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ std::vector<StringPiece> pieces_;
+
+ Filter backend_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_WRAPPER_H__
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 32084b5b..eb159094 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -1,6 +1,7 @@
#include "lm/model.hh"
#include <stdlib.h>
+#include <string.h>
#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>
@@ -22,17 +23,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) {
namespace {
+// Stupid bjam reverses the command line arguments randomly.
const char *TestLocation() {
- if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ if (boost::unit_test::framework::master_test_suite().argc < 3) {
return "test.arpa";
}
- return boost::unit_test::framework::master_test_suite().argv[1];
+ char **argv = boost::unit_test::framework::master_test_suite().argv;
+ return argv[strstr(argv[1], "nounk") ? 2 : 1];
}
const char *TestNoUnkLocation() {
if (boost::unit_test::framework::master_test_suite().argc < 3) {
return "test_nounk.arpa";
}
- return boost::unit_test::framework::master_test_suite().argv[2];
+ char **argv = boost::unit_test::framework::master_test_suite().argv;
+ return argv[strstr(argv[1], "nounk") ? 1 : 2];
}
template <class Model> State GetState(const Model &model, const char *word, const State &in) {
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index b709fef9..9ea08798 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -1,6 +1,7 @@
#include "lm/read_arpa.hh"
#include "lm/blank.hh"
+#include "util/file.hh"
#include <cmath>
#include <cstdlib>
@@ -45,8 +46,14 @@ uint64_t ReadCount(const std::string &from) {
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
number.clear();
- StringPiece line;
- while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
+ StringPiece line = in.ReadLine();
+ // In general, ARPA files can have arbitrary text before "\data\"
+ // But in KenLM, we require such lines to start with "#", so that
+ // we can do stricter error checking
+ while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) {
+ line = in.ReadLine();
+ }
+
if (line != "\\data\\") {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
diff --git a/klm/lm/sizes.cc b/klm/lm/sizes.cc
new file mode 100644
index 00000000..55ad586c
--- /dev/null
+++ b/klm/lm/sizes.cc
@@ -0,0 +1,63 @@
+#include "lm/sizes.hh"
+#include "lm/model.hh"
+#include "util/file_piece.hh"
+
+#include <vector>
+#include <iomanip>
+
+namespace lm {
+namespace ngram {
+
+void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config) {
+ uint64_t sizes[6];
+ sizes[0] = ProbingModel::Size(counts, config);
+ sizes[1] = RestProbingModel::Size(counts, config);
+ sizes[2] = TrieModel::Size(counts, config);
+ sizes[3] = QuantTrieModel::Size(counts, config);
+ sizes[4] = ArrayTrieModel::Size(counts, config);
+ sizes[5] = QuantArrayTrieModel::Size(counts, config);
+ uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
+ uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
+ uint64_t divide;
+ char prefix;
+ if (min_length < (1 << 10) * 10) {
+ prefix = ' ';
+ divide = 1;
+ } else if (min_length < (1 << 20) * 10) {
+ prefix = 'k';
+ divide = 1 << 10;
+ } else if (min_length < (1ULL << 30) * 10) {
+ prefix = 'M';
+ divide = 1 << 20;
+ } else {
+ prefix = 'G';
+ divide = 1 << 30;
+ }
+ long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide))));
+ std::cerr << "Memory estimate for binary LM:\ntype ";
+
+ // right align bytes.
+ for (long int i = 0; i < length - 2; ++i) std::cerr << ' ';
+
+ std::cerr << prefix << "B\n"
+ "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
+ "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n"
+ "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
+ "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
+ "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
+ "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
+}
+
+void ShowSizes(const std::vector<uint64_t> &counts) {
+ lm::ngram::Config config;
+ ShowSizes(counts, config);
+}
+
+void ShowSizes(const char *file, const lm::ngram::Config &config) {
+ std::vector<uint64_t> counts;
+ util::FilePiece f(file);
+ lm::ReadARPACounts(f, counts);
+ ShowSizes(counts, config);
+}
+
+}} //namespaces
diff --git a/klm/lm/sizes.hh b/klm/lm/sizes.hh
new file mode 100644
index 00000000..85abade7
--- /dev/null
+++ b/klm/lm/sizes.hh
@@ -0,0 +1,17 @@
+#ifndef LM_SIZES__
+#define LM_SIZES__
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace lm { namespace ngram {
+
+struct Config;
+
+void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config);
+void ShowSizes(const std::vector<uint64_t> &counts);
+void ShowSizes(const char *file, const lm::ngram::Config &config);
+
+}} // namespaces
+#endif // LM_SIZES__
diff --git a/klm/lm/state.hh b/klm/lm/state.hh
index 551510a8..d8e6c132 100644
--- a/klm/lm/state.hh
+++ b/klm/lm/state.hh
@@ -56,14 +56,14 @@ inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
struct Left {
bool operator==(const Left &other) const {
return
- (length == other.length) &&
- pointers[length - 1] == other.pointers[length - 1] &&
- full == other.full;
+ length == other.length &&
+ (!length || (pointers[length - 1] == other.pointers[length - 1] && full == other.full));
}
int Compare(const Left &other) const {
if (length < other.length) return -1;
if (length > other.length) return 1;
+ if (length == 0) return 0; // Must be full.
if (pointers[length - 1] > other.pointers[length - 1]) return 1;
if (pointers[length - 1] < other.pointers[length - 1]) return -1;
return (int)full - (int)other.full;
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index 8663e94e..dc542bb3 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -65,13 +65,13 @@ class PartialViewProxy {
typedef util::ProxyIterator<PartialViewProxy> PartialIter;
-FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) {
- util::scoped_fd file(maker.Make());
+FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) {
+ util::scoped_fd file(util::MakeTemp(temp_prefix));
util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
return util::FDOpenOrThrow(file);
}
-FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) {
+FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) {
const size_t context_size = sizeof(WordIndex) * (order - 1);
// Sort just the contexts using the same memory.
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
@@ -84,7 +84,7 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make
#endif
(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
- util::scoped_FILE out(maker.MakeFile());
+ util::scoped_FILE out(util::FMakeTemp(temp_prefix));
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
if (context_begin == context_end) return out.release();
@@ -114,12 +114,12 @@ struct FirstCombine {
}
};
-template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) {
+template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) {
std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
RecordReader first, second;
first.Init(first_file, entry_size);
second.Init(second_file, entry_size);
- util::scoped_FILE out_file(maker.MakeFile());
+ util::scoped_FILE out_file(util::FMakeTemp(temp_prefix));
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
@@ -177,9 +177,8 @@ void RecordReader::Rewind() {
}
SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
- util::TempMaker maker(file_prefix);
PositiveProbWarn warn(config.positive_log_probability);
- unigram_.reset(maker.Make());
+ unigram_.reset(util::MakeTemp(file_prefix));
{
// In case <unk> appears.
size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff);
@@ -202,7 +201,7 @@ SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<u
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
for (unsigned char order = 2; order <= counts.size(); ++order) {
- ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer);
+ ConvertToSorted(f, vocab, counts, file_prefix, order, warn, mem.get(), buffer);
}
ReadEnd(f);
}
@@ -227,7 +226,7 @@ class Closer {
};
} // namespace
-void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
+void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of weights. Does it include backoff?
@@ -261,8 +260,8 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
std::sort
#endif
(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
- files.push_back(DiskFlush(begin, out_end, maker));
- contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order));
+ files.push_back(DiskFlush(begin, out_end, file_prefix));
+ contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order));
done += (out_end - begin) / entry_size;
}
@@ -270,10 +269,10 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
// All individual files created. Merge them.
while (files.size() > 1) {
- files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine()));
+ files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine()));
files_closer.PopFront();
files_closer.PopFront();
- contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine()));
+ contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine()));
contexts_closer.PopFront();
contexts_closer.PopFront();
}
diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh
index 2197b80c..1afd9562 100644
--- a/klm/lm/trie_sort.hh
+++ b/klm/lm/trie_sort.hh
@@ -18,7 +18,6 @@
namespace util {
class FilePiece;
-class TempMaker;
} // namespace util
namespace lm {
@@ -101,7 +100,7 @@ class SortedFiles {
}
private:
- void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size);
+ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size);
util::scoped_fd unigram_;