summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2013-04-24 10:12:41 +0100
committerKenneth Heafield <github@kheafield.com>2013-04-24 10:12:41 +0100
commitdb960a8bba81df3217660ec5a96d73e0d6baa01b (patch)
tree7d84cff7fc47fda4ce28ca5164ab74ebf7f6ece8 /klm/lm
parentbf10ad9d1d3a17ae82804f947616db89f41d4f28 (diff)
KenLM 0831569c3137536165b107c6841603c725dfa2b1
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/builder/corpus_count.cc82
-rw-r--r--klm/lm/builder/corpus_count.hh5
-rw-r--r--klm/lm/builder/corpus_count_test.cc2
-rw-r--r--klm/lm/builder/lmplz_main.cc17
-rw-r--r--klm/lm/builder/pipeline.cc7
-rw-r--r--klm/lm/builder/pipeline.hh9
-rw-r--r--klm/lm/builder/print.cc74
-rw-r--r--klm/lm/builder/print.hh3
-rw-r--r--klm/lm/filter/filter_main.cc4
-rw-r--r--klm/lm/kenlm_max_order_main.cc6
-rw-r--r--klm/lm/query_main.cc1
11 files changed, 97 insertions, 113 deletions
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
index abea4ed0..aea93ad1 100644
--- a/klm/lm/builder/corpus_count.cc
+++ b/klm/lm/builder/corpus_count.cc
@@ -3,6 +3,7 @@
#include "lm/builder/ngram.hh"
#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
+#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/murmur_hash.hh"
@@ -23,39 +24,71 @@ namespace lm {
namespace builder {
namespace {
+#pragma pack(push)
+#pragma pack(4)
+struct VocabEntry {
+ typedef uint64_t Key;
+
+ uint64_t GetKey() const { return key; }
+ void SetKey(uint64_t to) { key = to; }
+
+ uint64_t key;
+ lm::WordIndex value;
+};
+#pragma pack(pop)
+
+const float kProbingMultiplier = 1.5;
+
class VocabHandout {
public:
- explicit VocabHandout(int fd) {
- util::scoped_fd duped(util::DupOrThrow(fd));
- word_list_.reset(util::FDOpenOrThrow(duped));
-
+ static std::size_t MemUsage(WordIndex initial_guess) {
+ if (initial_guess < 2) initial_guess = 2;
+ return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
+ }
+
+ explicit VocabHandout(int fd, WordIndex initial_guess) :
+ table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
+ table_(table_backing_.get(), MemUsage(initial_guess)),
+ double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)),
+ word_list_(fd) {
Lookup("<unk>"); // Force 0
Lookup("<s>"); // Force 1
Lookup("</s>"); // Force 2
}
WordIndex Lookup(const StringPiece &word) {
- uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
- std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
- if (ret.second) {
- char null_delimit = 0;
- util::WriteOrThrow(word_list_.get(), word.data(), word.size());
- util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
- UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ VocabEntry entry;
+ entry.key = util::MurmurHashNative(word.data(), word.size());
+ entry.value = table_.SizeNoSerialization();
+
+ Table::MutableIterator it;
+ if (table_.FindOrInsert(entry, it))
+ return it->value;
+ word_list_ << word << '\0';
+ UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ if (Size() >= double_cutoff_) {
+ table_backing_.call_realloc(table_.DoubleTo());
+ table_.Double(table_backing_.get());
+ double_cutoff_ *= 2;
}
- return ret.first->second;
+ return entry.value;
}
WordIndex Size() const {
- return seen_.size();
+ return table_.SizeNoSerialization();
}
private:
- typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+ // TODO: factor out a resizable probing hash table.
+ // TODO: use mremap on linux to get all zeros on resizes.
+ util::scoped_malloc table_backing_;
- Seen seen_;
+ typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
+ Table table_;
- util::scoped_FILE word_list_;
+ std::size_t double_cutoff_;
+
+ util::FakeOFStream word_list_;
};
class DedupeHash : public std::unary_function<const WordIndex *, bool> {
@@ -85,6 +118,7 @@ class DedupeEquals : public std::binary_function<const WordIndex *, const WordIn
struct DedupeEntry {
typedef WordIndex *Key;
Key GetKey() const { return key; }
+ void SetKey(WordIndex *to) { key = to; }
Key key;
static DedupeEntry Construct(WordIndex *at) {
DedupeEntry ret;
@@ -95,8 +129,6 @@ struct DedupeEntry {
typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
-const float kProbingMultiplier = 1.5;
-
class Writer {
public:
Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
@@ -105,7 +137,7 @@ class Writer {
dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
buffer_(new WordIndex[order - 1]),
block_size_(position.GetChain().BlockSize()) {
- dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ dedupe_.Clear();
assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
if (order == 1) {
// Add special words. AdjustCounts is responsible if order != 1.
@@ -149,7 +181,7 @@ class Writer {
}
// Block end. Need to store the context in a temporary buffer.
std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
- dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ dedupe_.Clear();
block_->SetValidSize(block_size_);
gram_.ReBase((++block_)->Get());
std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
@@ -187,18 +219,22 @@ float CorpusCount::DedupeMultiplier(std::size_t order) {
return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
}
+std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
+ return VocabHandout::MemUsage(vocab_estimate);
+}
+
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
- token_count_ = 0;
- type_count_ = 0;
}
void CorpusCount::Run(const util::stream::ChainPosition &position) {
UTIL_TIMER("(%w s) Counted n-grams\n");
- VocabHandout vocab(vocab_write_);
+ VocabHandout vocab(vocab_write_, type_count_);
+ token_count_ = 0;
+ type_count_ = 0;
const WordIndex end_sentence = vocab.Lookup("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh
index e255bad1..aa0ed8ed 100644
--- a/klm/lm/builder/corpus_count.hh
+++ b/klm/lm/builder/corpus_count.hh
@@ -23,6 +23,11 @@ class CorpusCount {
// Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size
static float DedupeMultiplier(std::size_t order);
+ // How much memory vocabulary will use based on estimated size of the vocab.
+ static std::size_t VocabUsage(std::size_t vocab_estimate);
+
+ // token_count: out.
+ // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value.
CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
void Run(const util::stream::ChainPosition &position);
diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc
index 8d53ca9d..6d325ef5 100644
--- a/klm/lm/builder/corpus_count_test.cc
+++ b/klm/lm/builder/corpus_count_test.cc
@@ -44,7 +44,7 @@ BOOST_AUTO_TEST_CASE(Short) {
util::stream::Chain chain(config);
NGramStream stream;
uint64_t token_count;
- WordIndex type_count;
+ WordIndex type_count = 10;
CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
index 90b9dca2..1e086dcc 100644
--- a/klm/lm/builder/lmplz_main.cc
+++ b/klm/lm/builder/lmplz_main.cc
@@ -6,6 +6,7 @@
#include <iostream>
#include <boost/program_options.hpp>
+#include <boost/version.hpp>
namespace {
class SizeNotify {
@@ -33,13 +34,17 @@ int main(int argc, char *argv[]) {
lm::builder::PipelineConfig pipeline;
options.add_options()
- ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model")
+ ("order,o", po::value<std::size_t>(&pipeline.order)
+#if BOOST_VERSION >= 104200
+ ->required()
+#endif
+ , "Order of the model")
("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
- ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+ ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
@@ -68,6 +73,14 @@ int main(int argc, char *argv[]) {
po::store(po::parse_command_line(argc, argv, options), vm);
po::notify(vm);
+ // required() appeared in Boost 1.42.0.
+#if BOOST_VERSION < 104200
+ if (!vm.count("order")) {
+ std::cerr << "the option '--order' is required but missing" << std::endl;
+ return 1;
+ }
+#endif
+
util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc
index 14a1f721..b89ea6ba 100644
--- a/klm/lm/builder/pipeline.cc
+++ b/klm/lm/builder/pipeline.cc
@@ -207,17 +207,18 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
const PipelineConfig &config = master.Config();
std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
- UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory());
+ const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate);
+ UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory());
std::size_t memory_for_chain =
// This much memory to work with after vocab hash table.
- static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) /
+ static_cast<float>(config.TotalMemory() - vocab_usage) /
// Solve for block size including the dedupe multiplier for one block.
(static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
// Chain likes memory expressed in terms of total memory.
static_cast<float>(config.block_count);
util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
- WordIndex type_count;
+ WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh
index f1d6c5f6..845e5481 100644
--- a/klm/lm/builder/pipeline.hh
+++ b/klm/lm/builder/pipeline.hh
@@ -3,6 +3,7 @@
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/header_info.hh"
+#include "lm/word_index.hh"
#include "util/stream/config.hh"
#include "util/file_piece.hh"
@@ -19,9 +20,9 @@ struct PipelineConfig {
util::stream::ChainConfig read_backoffs;
bool verbose_header;
- // Amount of memory to assume that the vocabulary hash table will use. This
- // is subtracted from total memory for CorpusCount.
- std::size_t assume_vocab_hash_size;
+ // Estimated vocabulary size. Used for sizing CorpusCount memory and
+ // initial probing hash table sizing, also in CorpusCount.
+ lm::WordIndex vocab_estimate;
// Minimum block size to tolerate.
std::size_t minimum_block;
@@ -33,7 +34,7 @@ struct PipelineConfig {
std::size_t TotalMemory() const { return sort.total_memory; }
};
-// Takes ownership of text_file.
+// Takes ownership of text_file and out_arpa.
void Pipeline(PipelineConfig config, int text_file, int out_arpa);
}} // namespaces
diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc
index b0323221..84bd81ca 100644
--- a/klm/lm/builder/print.cc
+++ b/klm/lm/builder/print.cc
@@ -1,15 +1,11 @@
#include "lm/builder/print.hh"
-#include "util/double-conversion/double-conversion.h"
-#include "util/double-conversion/utils.h"
+#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/mmap.hh"
#include "util/scoped.hh"
#include "util/stream/timer.hh"
-#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
-#include <boost/lexical_cast.hpp>
-
#include <sstream>
#include <string.h>
@@ -28,71 +24,6 @@ VocabReconstitute::VocabReconstitute(int fd) {
map_.push_back(i);
}
-namespace {
-class OutputManager {
- public:
- static const std::size_t kOutBuf = 1048576;
-
- // Does not take ownership of out.
- explicit OutputManager(int out)
- : buf_(util::MallocOrThrow(kOutBuf)),
- builder_(static_cast<char*>(buf_.get()), kOutBuf),
- // Mostly the default but with inf instead. And no flags.
- convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
- fd_(out) {}
-
- ~OutputManager() {
- Flush();
- }
-
- OutputManager &operator<<(float value) {
- // Odd, but this is the largest number found in the comments.
- EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
- convert_.ToShortestSingle(value, &builder_);
- return *this;
- }
-
- OutputManager &operator<<(StringPiece str) {
- if (str.size() > kOutBuf) {
- Flush();
- util::WriteOrThrow(fd_, str.data(), str.size());
- } else {
- EnsureRemaining(str.size());
- builder_.AddSubstring(str.data(), str.size());
- }
- return *this;
- }
-
- // Inefficient!
- OutputManager &operator<<(unsigned val) {
- return *this << boost::lexical_cast<std::string>(val);
- }
-
- OutputManager &operator<<(char c) {
- EnsureRemaining(1);
- builder_.AddCharacter(c);
- return *this;
- }
-
- void Flush() {
- util::WriteOrThrow(fd_, buf_.get(), builder_.position());
- builder_.Reset();
- }
-
- private:
- void EnsureRemaining(std::size_t amount) {
- if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) {
- Flush();
- }
- }
-
- util::scoped_malloc buf_;
- double_conversion::StringBuilder builder_;
- double_conversion::DoubleToStringConverter convert_;
- int fd_;
-};
-} // namespace
-
PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
: vocab_(vocab), out_fd_(out_fd) {
std::stringstream stream;
@@ -112,8 +43,9 @@ PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t>
}
void PrintARPA::Run(const ChainPositions &positions) {
+ util::scoped_fd closer(out_fd_);
UTIL_TIMER("(%w s) Wrote ARPA file\n");
- OutputManager out(out_fd_);
+ util::FakeOFStream out(out_fd_);
for (unsigned order = 1; order <= positions.size(); ++order) {
out << "\\" << order << "-grams:" << '\n';
for (NGramStream stream(positions[order - 1]); stream; ++stream) {
diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh
index aa932e75..adbbb94a 100644
--- a/klm/lm/builder/print.hh
+++ b/klm/lm/builder/print.hh
@@ -88,7 +88,8 @@ template <class V> class Print {
class PrintARPA {
public:
- // header_info may be NULL to disable the header
+ // header_info may be NULL to disable the header.
+ // Takes ownership of out_fd upon Run().
explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);
void Run(const ChainPositions &positions);
diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc
index 1a4ba84f..1736bc40 100644
--- a/klm/lm/filter/filter_main.cc
+++ b/klm/lm/filter/filter_main.cc
@@ -25,8 +25,8 @@ void DisplayHelp(const char *name) {
" parser.\n"
"single mode treats the entire input as a single sentence.\n"
"multiple mode filters to multiple sentences in parallel. Each sentence is on\n"
- " a separate line. A separate file is created for each file by appending the\n"
- " 0-indexed line number to the output file name.\n"
+ " a separate line. A separate file is created for each sentence by appending\n"
+ " the 0-indexed line number to the output file name.\n"
"union mode produces one filtered model that is the union of models created by\n"
" multiple mode.\n\n"
"context means only the context (all but last word) has to pass the filter, but\n"
diff --git a/klm/lm/kenlm_max_order_main.cc b/klm/lm/kenlm_max_order_main.cc
deleted file mode 100644
index 94221201..00000000
--- a/klm/lm/kenlm_max_order_main.cc
+++ /dev/null
@@ -1,6 +0,0 @@
-#include "lm/max_order.hh"
-#include <iostream>
-
-int main(int argc, char *argv[]) {
- std::cerr << "KenLM was compiled with a maximum supported n-gram order set to " << KENLM_MAX_ORDER << "." << std::endl;
-}
diff --git a/klm/lm/query_main.cc b/klm/lm/query_main.cc
index 49757d9a..27d3a1a5 100644
--- a/klm/lm/query_main.cc
+++ b/klm/lm/query_main.cc
@@ -2,6 +2,7 @@
int main(int argc, char *argv[]) {
if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) {
+ std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl;
std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl;
std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl;
return 1;