diff options
author | Kenneth Heafield <github@kheafield.com> | 2013-04-24 10:12:41 +0100 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2013-04-24 10:12:41 +0100 |
commit | db960a8bba81df3217660ec5a96d73e0d6baa01b (patch) | |
tree | 7d84cff7fc47fda4ce28ca5164ab74ebf7f6ece8 /klm/lm | |
parent | bf10ad9d1d3a17ae82804f947616db89f41d4f28 (diff) |
KenLM 0831569c3137536165b107c6841603c725dfa2b1
Diffstat (limited to 'klm/lm')
-rw-r--r-- | klm/lm/builder/corpus_count.cc | 82 | ||||
-rw-r--r-- | klm/lm/builder/corpus_count.hh | 5 | ||||
-rw-r--r-- | klm/lm/builder/corpus_count_test.cc | 2 | ||||
-rw-r--r-- | klm/lm/builder/lmplz_main.cc | 17 | ||||
-rw-r--r-- | klm/lm/builder/pipeline.cc | 7 | ||||
-rw-r--r-- | klm/lm/builder/pipeline.hh | 9 | ||||
-rw-r--r-- | klm/lm/builder/print.cc | 74 | ||||
-rw-r--r-- | klm/lm/builder/print.hh | 3 | ||||
-rw-r--r-- | klm/lm/filter/filter_main.cc | 4 | ||||
-rw-r--r-- | klm/lm/kenlm_max_order_main.cc | 6 | ||||
-rw-r--r-- | klm/lm/query_main.cc | 1 |
11 files changed, 97 insertions, 113 deletions
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc index abea4ed0..aea93ad1 100644 --- a/klm/lm/builder/corpus_count.cc +++ b/klm/lm/builder/corpus_count.cc @@ -3,6 +3,7 @@ #include "lm/builder/ngram.hh" #include "lm/lm_exception.hh" #include "lm/word_index.hh" +#include "util/fake_ofstream.hh" #include "util/file.hh" #include "util/file_piece.hh" #include "util/murmur_hash.hh" @@ -23,39 +24,71 @@ namespace lm { namespace builder { namespace { +#pragma pack(push) +#pragma pack(4) +struct VocabEntry { + typedef uint64_t Key; + + uint64_t GetKey() const { return key; } + void SetKey(uint64_t to) { key = to; } + + uint64_t key; + lm::WordIndex value; +}; +#pragma pack(pop) + +const float kProbingMultiplier = 1.5; + class VocabHandout { public: - explicit VocabHandout(int fd) { - util::scoped_fd duped(util::DupOrThrow(fd)); - word_list_.reset(util::FDOpenOrThrow(duped)); - + static std::size_t MemUsage(WordIndex initial_guess) { + if (initial_guess < 2) initial_guess = 2; + return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier)); + } + + explicit VocabHandout(int fd, WordIndex initial_guess) : + table_backing_(util::CallocOrThrow(MemUsage(initial_guess))), + table_(table_backing_.get(), MemUsage(initial_guess)), + double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)), + word_list_(fd) { Lookup("<unk>"); // Force 0 Lookup("<s>"); // Force 1 Lookup("</s>"); // Force 2 } WordIndex Lookup(const StringPiece &word) { - uint64_t hashed = util::MurmurHashNative(word.data(), word.size()); - std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size()))); - if (ret.second) { - char null_delimit = 0; - util::WriteOrThrow(word_list_.get(), word.data(), word.size()); - util::WriteOrThrow(word_list_.get(), &null_delimit, 1); - UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); + VocabEntry entry; + entry.key = util::MurmurHashNative(word.data(), word.size()); + entry.value = table_.SizeNoSerialization(); + + Table::MutableIterator it; + if (table_.FindOrInsert(entry, it)) + return it->value; + word_list_ << word << '\0'; + UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); + if (Size() >= double_cutoff_) { + table_backing_.call_realloc(table_.DoubleTo()); + table_.Double(table_backing_.get()); + double_cutoff_ *= 2; } - return ret.first->second; + return entry.value; } WordIndex Size() const { - return seen_.size(); + return table_.SizeNoSerialization(); } private: - typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen; + // TODO: factor out a resizable probing hash table. + // TODO: use mremap on linux to get all zeros on resizes. + util::scoped_malloc table_backing_; - Seen seen_; + typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table; + Table table_; - util::scoped_FILE word_list_; + std::size_t double_cutoff_; + + util::FakeOFStream word_list_; }; class DedupeHash : public std::unary_function<const WordIndex *, bool> { @@ -85,6 +118,7 @@ class DedupeEquals : public std::binary_function<const WordIndex *, const WordIn struct DedupeEntry { typedef WordIndex *Key; Key GetKey() const { return key; } + void SetKey(WordIndex *to) { key = to; } Key key; static DedupeEntry Construct(WordIndex *at) { DedupeEntry ret; @@ -95,8 +129,6 @@ struct DedupeEntry { typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe; -const float kProbingMultiplier = 1.5; - class Writer { public: Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size) @@ -105,7 +137,7 @@ class Writer { dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)), buffer_(new WordIndex[order - 1]), block_size_(position.GetChain().BlockSize()) { - dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + dedupe_.Clear(); assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size); if (order == 1) { // Add special words. AdjustCounts is responsible if order != 1. @@ -149,7 +181,7 @@ class Writer { } // Block end. Need to store the context in a temporary buffer. std::copy(gram_.begin() + 1, gram_.end(), buffer_.get()); - dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + dedupe_.Clear(); block_->SetValidSize(block_size_); gram_.ReBase((++block_)->Get()); std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin()); @@ -187,18 +219,22 @@ float CorpusCount::DedupeMultiplier(std::size_t order) { return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order)); } +std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { + return VocabHandout::MemUsage(vocab_estimate); +} + CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { - token_count_ = 0; - type_count_ = 0; } void CorpusCount::Run(const util::stream::ChainPosition &position) { UTIL_TIMER("(%w s) Counted n-grams\n"); - VocabHandout vocab(vocab_write_); + VocabHandout vocab(vocab_write_, type_count_); + token_count_ = 0; + type_count_ = 0; const WordIndex end_sentence = vocab.Lookup("</s>"); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh index e255bad1..aa0ed8ed 100644 --- a/klm/lm/builder/corpus_count.hh +++ b/klm/lm/builder/corpus_count.hh @@ -23,6 +23,11 @@ class CorpusCount { // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size static float DedupeMultiplier(std::size_t order); + // How much memory vocabulary will use based on estimated size of the vocab. + static std::size_t VocabUsage(std::size_t vocab_estimate); + + // token_count: out. + // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value. CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block); void Run(const util::stream::ChainPosition &position); diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc index 8d53ca9d..6d325ef5 100644 --- a/klm/lm/builder/corpus_count_test.cc +++ b/klm/lm/builder/corpus_count_test.cc @@ -44,7 +44,7 @@ BOOST_AUTO_TEST_CASE(Short) { util::stream::Chain chain(config); NGramStream stream; uint64_t token_count; - WordIndex type_count; + WordIndex type_count = 10; CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize()); chain >> boost::ref(counter) >> stream >> util::stream::kRecycle; diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc index 90b9dca2..1e086dcc 100644 --- a/klm/lm/builder/lmplz_main.cc +++ b/klm/lm/builder/lmplz_main.cc @@ -6,6 +6,7 @@ #include <iostream> #include <boost/program_options.hpp> +#include <boost/version.hpp> namespace { class SizeNotify { @@ -33,13 +34,17 @@ int main(int argc, char *argv[]) { lm::builder::PipelineConfig pipeline; options.add_options() - ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model") + ("order,o", po::value<std::size_t>(&pipeline.order) +#if BOOST_VERSION >= 104200 + ->required() +#endif + , "Order of the model") ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") - ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step") ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") + ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table") ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)") ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc."); @@ -68,6 +73,14 @@ int main(int argc, char *argv[]) { po::store(po::parse_command_line(argc, argv, options), vm); po::notify(vm); + // required() appeared in Boost 1.42.0. +#if BOOST_VERSION < 104200 + if (!vm.count("order")) { + std::cerr << "the option '--order' is required but missing" << std::endl; + return 1; + } +#endif + util::NormalizeTempPrefix(pipeline.sort.temp_prefix); lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc index 14a1f721..b89ea6ba 100644 --- a/klm/lm/builder/pipeline.cc +++ b/klm/lm/builder/pipeline.cc @@ -207,17 +207,18 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m const PipelineConfig &config = master.Config(); std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl; - UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory()); + const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate); + UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory()); std::size_t memory_for_chain = // This much memory to work with after vocab hash table. - static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) / + static_cast<float>(config.TotalMemory() - vocab_usage) / // Solve for block size including the dedupe multiplier for one block. (static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) * // Chain likes memory expressed in terms of total memory. static_cast<float>(config.block_count); util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain)); - WordIndex type_count; + WordIndex type_count = config.vocab_estimate; util::FilePiece text(text_file, NULL, &std::cerr); text_file_name = text.FileName(); CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh index f1d6c5f6..845e5481 100644 --- a/klm/lm/builder/pipeline.hh +++ b/klm/lm/builder/pipeline.hh @@ -3,6 +3,7 @@ #include "lm/builder/initial_probabilities.hh" #include "lm/builder/header_info.hh" +#include "lm/word_index.hh" #include "util/stream/config.hh" #include "util/file_piece.hh" @@ -19,9 +20,9 @@ struct PipelineConfig { util::stream::ChainConfig read_backoffs; bool verbose_header; - // Amount of memory to assume that the vocabulary hash table will use. This - // is subtracted from total memory for CorpusCount. - std::size_t assume_vocab_hash_size; + // Estimated vocabulary size. Used for sizing CorpusCount memory and + // initial probing hash table sizing, also in CorpusCount. + lm::WordIndex vocab_estimate; // Minimum block size to tolerate. std::size_t minimum_block; @@ -33,7 +34,7 @@ struct PipelineConfig { std::size_t TotalMemory() const { return sort.total_memory; } }; -// Takes ownership of text_file. +// Takes ownership of text_file and out_arpa. void Pipeline(PipelineConfig config, int text_file, int out_arpa); }} // namespaces diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc index b0323221..84bd81ca 100644 --- a/klm/lm/builder/print.cc +++ b/klm/lm/builder/print.cc @@ -1,15 +1,11 @@ #include "lm/builder/print.hh" -#include "util/double-conversion/double-conversion.h" -#include "util/double-conversion/utils.h" +#include "util/fake_ofstream.hh" #include "util/file.hh" #include "util/mmap.hh" #include "util/scoped.hh" #include "util/stream/timer.hh" -#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE -#include <boost/lexical_cast.hpp> - #include <sstream> #include <string.h> @@ -28,71 +24,6 @@ VocabReconstitute::VocabReconstitute(int fd) { map_.push_back(i); } -namespace { -class OutputManager { - public: - static const std::size_t kOutBuf = 1048576; - - // Does not take ownership of out. - explicit OutputManager(int out) - : buf_(util::MallocOrThrow(kOutBuf)), - builder_(static_cast<char*>(buf_.get()), kOutBuf), - // Mostly the default but with inf instead. And no flags. - convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0), - fd_(out) {} - - ~OutputManager() { - Flush(); - } - - OutputManager &operator<<(float value) { - // Odd, but this is the largest number found in the comments. - EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); - convert_.ToShortestSingle(value, &builder_); - return *this; - } - - OutputManager &operator<<(StringPiece str) { - if (str.size() > kOutBuf) { - Flush(); - util::WriteOrThrow(fd_, str.data(), str.size()); - } else { - EnsureRemaining(str.size()); - builder_.AddSubstring(str.data(), str.size()); - } - return *this; - } - - // Inefficient! - OutputManager &operator<<(unsigned val) { - return *this << boost::lexical_cast<std::string>(val); - } - - OutputManager &operator<<(char c) { - EnsureRemaining(1); - builder_.AddCharacter(c); - return *this; - } - - void Flush() { - util::WriteOrThrow(fd_, buf_.get(), builder_.position()); - builder_.Reset(); - } - - private: - void EnsureRemaining(std::size_t amount) { - if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) { - Flush(); - } - } - - util::scoped_malloc buf_; - double_conversion::StringBuilder builder_; - double_conversion::DoubleToStringConverter convert_; - int fd_; -}; -} // namespace - PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd) : vocab_(vocab), out_fd_(out_fd) { std::stringstream stream; @@ -112,8 +43,9 @@ PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> } void PrintARPA::Run(const ChainPositions &positions) { + util::scoped_fd closer(out_fd_); UTIL_TIMER("(%w s) Wrote ARPA file\n"); - OutputManager out(out_fd_); + util::FakeOFStream out(out_fd_); for (unsigned order = 1; order <= positions.size(); ++order) { out << "\\" << order << "-grams:" << '\n'; for (NGramStream stream(positions[order - 1]); stream; ++stream) { diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh index aa932e75..adbbb94a 100644 --- a/klm/lm/builder/print.hh +++ b/klm/lm/builder/print.hh @@ -88,7 +88,8 @@ template <class V> class Print { class PrintARPA { public: - // header_info may be NULL to disable the header + // header_info may be NULL to disable the header. + // Takes ownership of out_fd upon Run(). explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd); void Run(const ChainPositions &positions); diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc index 1a4ba84f..1736bc40 100644 --- a/klm/lm/filter/filter_main.cc +++ b/klm/lm/filter/filter_main.cc @@ -25,8 +25,8 @@ void DisplayHelp(const char *name) { " parser.\n" "single mode treats the entire input as a single sentence.\n" "multiple mode filters to multiple sentences in parallel. Each sentence is on\n" - " a separate line. A separate file is created for each file by appending the\n" - " 0-indexed line number to the output file name.\n" + " a separate line. A separate file is created for each sentence by appending\n" + " the 0-indexed line number to the output file name.\n" "union mode produces one filtered model that is the union of models created by\n" " multiple mode.\n\n" "context means only the context (all but last word) has to pass the filter, but\n" diff --git a/klm/lm/kenlm_max_order_main.cc b/klm/lm/kenlm_max_order_main.cc deleted file mode 100644 index 94221201..00000000 --- a/klm/lm/kenlm_max_order_main.cc +++ /dev/null @@ -1,6 +0,0 @@ -#include "lm/max_order.hh" -#include <iostream> - -int main(int argc, char *argv[]) { - std::cerr << "KenLM was compiled with a maximum supported n-gram order set to " << KENLM_MAX_ORDER << "." << std::endl; -} diff --git a/klm/lm/query_main.cc b/klm/lm/query_main.cc index 49757d9a..27d3a1a5 100644 --- a/klm/lm/query_main.cc +++ b/klm/lm/query_main.cc @@ -2,6 +2,7 @@ int main(int argc, char *argv[]) { if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) { + std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl; std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl; return 1; |