| field | value | date |
|---|---|---|
| author | Paul Baltescu <pauldb89@gmail.com> | 2013-04-24 17:18:10 +0100 |
| committer | Paul Baltescu <pauldb89@gmail.com> | 2013-04-24 17:18:10 +0100 |
| commit | e8b412577b9d3fe2090b9f48443f919cd268c809 | |
| tree | b46a7b51d365519dfb5170d71bac33be6d3e29b9 /klm/lm/builder | |
| parent | d189426a7ea56b71eb6e25ed02a7b0993cfb56a8 | |
| parent | 5aee54869aa19cfe9be965e67a472e94449d16da | |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'klm/lm/builder')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | klm/lm/builder/corpus_count.cc | 82 |
| -rw-r--r-- | klm/lm/builder/corpus_count.hh | 5 |
| -rw-r--r-- | klm/lm/builder/corpus_count_test.cc | 2 |
| -rw-r--r-- | klm/lm/builder/lmplz_main.cc | 17 |
| -rw-r--r-- | klm/lm/builder/pipeline.cc | 7 |
| -rw-r--r-- | klm/lm/builder/pipeline.hh | 9 |
| -rw-r--r-- | klm/lm/builder/print.cc | 74 |
| -rw-r--r-- | klm/lm/builder/print.hh | 3 |
8 files changed, 94 insertions, 105 deletions
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
index abea4ed0..aea93ad1 100644
--- a/klm/lm/builder/corpus_count.cc
+++ b/klm/lm/builder/corpus_count.cc
@@ -3,6 +3,7 @@
 #include "lm/builder/ngram.hh"
 #include "lm/lm_exception.hh"
 #include "lm/word_index.hh"
+#include "util/fake_ofstream.hh"
 #include "util/file.hh"
 #include "util/file_piece.hh"
 #include "util/murmur_hash.hh"
@@ -23,39 +24,71 @@ namespace lm {
 namespace builder {
 namespace {

+#pragma pack(push)
+#pragma pack(4)
+struct VocabEntry {
+  typedef uint64_t Key;
+
+  uint64_t GetKey() const { return key; }
+  void SetKey(uint64_t to) { key = to; }
+
+  uint64_t key;
+  lm::WordIndex value;
+};
+#pragma pack(pop)
+
+const float kProbingMultiplier = 1.5;
+
 class VocabHandout {
   public:
-    explicit VocabHandout(int fd) {
-      util::scoped_fd duped(util::DupOrThrow(fd));
-      word_list_.reset(util::FDOpenOrThrow(duped));
-
+    static std::size_t MemUsage(WordIndex initial_guess) {
+      if (initial_guess < 2) initial_guess = 2;
+      return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
+    }
+
+    explicit VocabHandout(int fd, WordIndex initial_guess) :
+        table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
+        table_(table_backing_.get(), MemUsage(initial_guess)),
+        double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)),
+        word_list_(fd) {
       Lookup("<unk>"); // Force 0
       Lookup("<s>"); // Force 1
       Lookup("</s>"); // Force 2
     }

     WordIndex Lookup(const StringPiece &word) {
-      uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
-      std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
-      if (ret.second) {
-        char null_delimit = 0;
-        util::WriteOrThrow(word_list_.get(), word.data(), word.size());
-        util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
-        UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words.  Change WordIndex to uint64_t in lm/word_index.hh.");
+      VocabEntry entry;
+      entry.key = util::MurmurHashNative(word.data(), word.size());
+      entry.value = table_.SizeNoSerialization();
+
+      Table::MutableIterator it;
+      if (table_.FindOrInsert(entry, it))
+        return it->value;
+      word_list_ << word << '\0';
+      UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words.  Change WordIndex to uint64_t in lm/word_index.hh.");
+      if (Size() >= double_cutoff_) {
+        table_backing_.call_realloc(table_.DoubleTo());
+        table_.Double(table_backing_.get());
+        double_cutoff_ *= 2;
       }
-      return ret.first->second;
+      return entry.value;
     }

     WordIndex Size() const {
-      return seen_.size();
+      return table_.SizeNoSerialization();
     }

   private:
-    typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+    // TODO: factor out a resizable probing hash table.
+    // TODO: use mremap on linux to get all zeros on resizes.
+    util::scoped_malloc table_backing_;
-    Seen seen_;
+    typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
+    Table table_;
-    util::scoped_FILE word_list_;
+    std::size_t double_cutoff_;
+
+    util::FakeOFStream word_list_;
 };

 class DedupeHash : public std::unary_function<const WordIndex *, bool> {
@@ -85,6 +118,7 @@ class DedupeEquals : public std::binary_function<const WordIndex *, const WordIn
 struct DedupeEntry {
   typedef WordIndex *Key;
   Key GetKey() const { return key; }
+  void SetKey(WordIndex *to) { key = to; }
   Key key;
   static DedupeEntry Construct(WordIndex *at) {
     DedupeEntry ret;
@@ -95,8 +129,6 @@ struct DedupeEntry {

 typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;

-const float kProbingMultiplier = 1.5;
-
 class Writer {
   public:
     Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
@@ -105,7 +137,7 @@ class Writer {
         dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
         buffer_(new WordIndex[order - 1]),
         block_size_(position.GetChain().BlockSize()) {
-      dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+      dedupe_.Clear();
       assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
       if (order == 1) {
         // Add special words.  AdjustCounts is responsible if order != 1.
@@ -149,7 +181,7 @@ class Writer {
       }
       // Block end.  Need to store the context in a temporary buffer.
       std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
-      dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+      dedupe_.Clear();
       block_->SetValidSize(block_size_);
       gram_.ReBase((++block_)->Get());
       std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
@@ -187,18 +219,22 @@ float CorpusCount::DedupeMultiplier(std::size_t order) {
   return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
 }

+std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
+  return VocabHandout::MemUsage(vocab_estimate);
+}
+
 CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
   : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
     dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
     dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
-  token_count_ = 0;
-  type_count_ = 0;
 }

 void CorpusCount::Run(const util::stream::ChainPosition &position) {
   UTIL_TIMER("(%w s) Counted n-grams\n");
-  VocabHandout vocab(vocab_write_);
+  VocabHandout vocab(vocab_write_, type_count_);
+  token_count_ = 0;
+  type_count_ = 0;
   const WordIndex end_sentence = vocab.Lookup("</s>");
   Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
   uint64_t count = 0;
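Note: the VocabHandout rewrite above replaces boost::unordered_map with a packed linear-probing table that doubles once occupancy crosses double_cutoff_, with each word's value being its insertion index. Below is a minimal sketch of that growth policy; the names (SketchTable, kEmptyKey) are illustrative stand-ins, not KenLM's util::ProbingHashTable API, and unlike the real code (which realloc()s one backing block via DoubleTo()/Double(), hence the mremap TODO) it copies into fresh vectors.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// 0 marks an empty bucket; assumes real keys (Murmur hashes) are nonzero.
const uint64_t kEmptyKey = 0;

class SketchTable {
 public:
  explicit SketchTable(std::size_t initial)
      : keys_(initial < 2 ? 2 : initial, kEmptyKey),
        values_(keys_.size(), 0),
        entries_(0),
        cutoff_(keys_.size() / 2) {}  // double at 50% load

  // Return the stored value for key, inserting fresh_value if absent.
  uint32_t FindOrInsert(uint64_t key, uint32_t fresh_value) {
    std::size_t i = key % keys_.size();
    while (keys_[i] != kEmptyKey) {
      if (keys_[i] == key) return values_[i];
      i = (i + 1) % keys_.size();  // linear probing
    }
    keys_[i] = key;
    values_[i] = fresh_value;
    if (++entries_ >= cutoff_) Double();
    return fresh_value;
  }

  std::size_t Size() const { return entries_; }

 private:
  void Double() {
    std::vector<uint64_t> old_keys;
    std::vector<uint32_t> old_values;
    old_keys.swap(keys_);
    old_values.swap(values_);
    keys_.assign(old_keys.size() * 2, kEmptyKey);
    values_.assign(old_values.size() * 2, 0);
    cutoff_ *= 2;
    entries_ = 0;
    // Rehash survivors into the doubled table.
    for (std::size_t i = 0; i < old_keys.size(); ++i)
      if (old_keys[i] != kEmptyKey) FindOrInsert(old_keys[i], old_values[i]);
  }

  std::vector<uint64_t> keys_;
  std::vector<uint32_t> values_;
  std::size_t entries_;
  std::size_t cutoff_;
};

int main() {
  SketchTable table(2);
  for (uint64_t k = 1; k <= 100; ++k)  // value = insertion index, as above
    table.FindOrInsert(k * 2654435761u, static_cast<uint32_t>(table.Size()));
  std::printf("%zu entries\n", table.Size());
  return 0;
}
```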
diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh
index e255bad1..aa0ed8ed 100644
--- a/klm/lm/builder/corpus_count.hh
+++ b/klm/lm/builder/corpus_count.hh
@@ -23,6 +23,11 @@ class CorpusCount {
     // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size
     static float DedupeMultiplier(std::size_t order);

+    // How much memory vocabulary will use based on estimated size of the vocab.
+    static std::size_t VocabUsage(std::size_t vocab_estimate);
+
+    // token_count: out.
+    // type_count aka vocabulary size.  Initialize to an estimate.  It is set to the exact value.
     CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);

     void Run(const util::stream::ChainPosition &position);
diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc
index 8d53ca9d..6d325ef5 100644
--- a/klm/lm/builder/corpus_count_test.cc
+++ b/klm/lm/builder/corpus_count_test.cc
@@ -44,7 +44,7 @@ BOOST_AUTO_TEST_CASE(Short) {
   util::stream::Chain chain(config);
   NGramStream stream;
   uint64_t token_count;
-  WordIndex type_count;
+  WordIndex type_count = 10;
   CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
   chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
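Note: VocabUsage() ties the count step's memory budget to sizeof(VocabEntry), which the #pragma pack(4) in corpus_count.cc keeps at 12 bytes on a typical 64-bit ABI. A self-contained illustration; the 1.5 buckets-per-entry figure is kProbingMultiplier from the diff, and the exact sizing formula lives in util::ProbingHashTable::Size(), so the final line is only a ballpark.

```cpp
#include <cstdint>
#include <cstdio>

#pragma pack(push)
#pragma pack(4)
struct PackedEntry {   // same layout as VocabEntry above (WordIndex is 32-bit)
  uint64_t key;
  uint32_t value;
};
#pragma pack(pop)

struct DefaultEntry {  // identical fields, default alignment
  uint64_t key;
  uint32_t value;
};

int main() {
  // On a typical 64-bit ABI this prints 12 and 16: pack(4) drops the
  // trailing padding, shrinking every vocabulary bucket by a quarter.
  std::printf("packed=%zu default=%zu\n", sizeof(PackedEntry), sizeof(DefaultEntry));

  // Rough spirit of VocabUsage(): ~1.5 buckets per expected entry.
  const std::size_t vocab_estimate = 1000000;  // lmplz --vocab_estimate default
  std::printf("~%zu bytes for the default estimate\n",
              static_cast<std::size_t>(vocab_estimate * 1.5) * sizeof(PackedEntry));
  return 0;
}
```

At the defaults that is roughly 18 MB, well under the flat 50M the removed --vocab_memory flag used to assume.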
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
index 90b9dca2..1e086dcc 100644
--- a/klm/lm/builder/lmplz_main.cc
+++ b/klm/lm/builder/lmplz_main.cc
@@ -6,6 +6,7 @@
 #include <iostream>

 #include <boost/program_options.hpp>
+#include <boost/version.hpp>

 namespace {
 class SizeNotify {
@@ -33,13 +34,17 @@ int main(int argc, char *argv[]) {
     lm::builder::PipelineConfig pipeline;

     options.add_options()
-      ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model")
+      ("order,o", po::value<std::size_t>(&pipeline.order)
+#if BOOST_VERSION >= 104200
+         ->required()
+#endif
+         , "Order of the model")
       ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
       ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
       ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
-      ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
       ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
       ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+      ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
       ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
       ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
       ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
@@ -68,6 +73,14 @@ int main(int argc, char *argv[]) {
     po::store(po::parse_command_line(argc, argv, options), vm);
     po::notify(vm);

+    // required() appeared in Boost 1.42.0.
+#if BOOST_VERSION < 104200
+    if (!vm.count("order")) {
+      std::cerr << "the option '--order' is required but missing" << std::endl;
+      return 1;
+    }
+#endif
+
     util::NormalizeTempPrefix(pipeline.sort.temp_prefix);

     lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
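Note: po::value<T>()->required() only exists from Boost 1.42.0 onward, hence the pair of #if BOOST_VERSION guards above. A self-contained sketch of the same pattern, reduced to the one option; option name and messages mirror the diff:

```cpp
#include <boost/program_options.hpp>
#include <boost/version.hpp>
#include <iostream>

namespace po = boost::program_options;

int main(int argc, char *argv[]) {
  std::size_t order;
  po::options_description options("Options");
  options.add_options()
    ("order,o", po::value<std::size_t>(&order)
#if BOOST_VERSION >= 104200
       ->required()  // Boost enforces the option itself
#endif
       , "Order of the model");
  po::variables_map vm;
  try {
    po::store(po::parse_command_line(argc, argv, options), vm);
    po::notify(vm);  // throws po::required_option on Boost >= 1.42
  } catch (const po::error &e) {
    std::cerr << e.what() << std::endl;
    return 1;
  }
#if BOOST_VERSION < 104200
  if (!vm.count("order")) {  // emulate required() on older Boost
    std::cerr << "the option '--order' is required but missing" << std::endl;
    return 1;
  }
#endif
  std::cout << "order=" << order << std::endl;
  return 0;
}
```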
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc
index 14a1f721..b89ea6ba 100644
--- a/klm/lm/builder/pipeline.cc
+++ b/klm/lm/builder/pipeline.cc
@@ -207,17 +207,18 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
   const PipelineConfig &config = master.Config();
   std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;

-  UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory());
+  const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate);
+  UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory());
   std::size_t memory_for_chain =
     // This much memory to work with after vocab hash table.
-    static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) /
+    static_cast<float>(config.TotalMemory() - vocab_usage) /
     // Solve for block size including the dedupe multiplier for one block.
     (static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
     // Chain likes memory expressed in terms of total memory.
     static_cast<float>(config.block_count);
   util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));

-  WordIndex type_count;
+  WordIndex type_count = config.vocab_estimate;
   util::FilePiece text(text_file, NULL, &std::cerr);
   text_file_name = text.FileName();
   CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh
index f1d6c5f6..845e5481 100644
--- a/klm/lm/builder/pipeline.hh
+++ b/klm/lm/builder/pipeline.hh
@@ -3,6 +3,7 @@
 #include "lm/builder/initial_probabilities.hh"
 #include "lm/builder/header_info.hh"
+#include "lm/word_index.hh"
 #include "util/stream/config.hh"
 #include "util/file_piece.hh"
@@ -19,9 +20,9 @@ struct PipelineConfig {
   util::stream::ChainConfig read_backoffs;
   bool verbose_header;

-  // Amount of memory to assume that the vocabulary hash table will use.  This
-  // is subtracted from total memory for CorpusCount.
-  std::size_t assume_vocab_hash_size;
+  // Estimated vocabulary size.  Used for sizing CorpusCount memory and
+  // initial probing hash table sizing, also in CorpusCount.
+  lm::WordIndex vocab_estimate;

   // Minimum block size to tolerate.
   std::size_t minimum_block;
@@ -33,7 +34,7 @@ struct PipelineConfig {
   std::size_t TotalMemory() const { return sort.total_memory; }
 };

-// Takes ownership of text_file.
+// Takes ownership of text_file and out_arpa.
 void Pipeline(PipelineConfig config, int text_file, int out_arpa);

 }} // namespaces
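Note: the chain sizing in CountText() above solves total - vocab = block_size * (block_count + dedupe_multiplier) for block_size, because the single dedupe hash table is sized against one block. A worked instance with hypothetical numbers (1 GB budget, 18 MB vocabulary table, a made-up multiplier of 0.8):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const double total_memory = 1024.0 * 1024.0 * 1024.0;  // --memory 1G
  const double vocab_usage  = 18000000.0;  // CorpusCount::VocabUsage(estimate)
  const double block_count  = 2.0;         // --block_count default
  const double dedupe_multiplier = 0.8;    // CorpusCount::DedupeMultiplier(order)

  // Each block costs its own bytes, plus one dedupe table costing
  // dedupe_multiplier * block_size, so:
  //   total - vocab = block_size * block_count + block_size * dedupe_multiplier
  const double block_size =
      (total_memory - vocab_usage) / (block_count + dedupe_multiplier);
  const double memory_for_chain = block_size * block_count;

  std::printf("block_size ~ %.0f bytes, chain gets ~ %.0f bytes\n",
              block_size, memory_for_chain);
  return 0;
}
```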
diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc
index b0323221..84bd81ca 100644
--- a/klm/lm/builder/print.cc
+++ b/klm/lm/builder/print.cc
@@ -1,15 +1,11 @@
 #include "lm/builder/print.hh"

-#include "util/double-conversion/double-conversion.h"
-#include "util/double-conversion/utils.h"
+#include "util/fake_ofstream.hh"
 #include "util/file.hh"
 #include "util/mmap.hh"
 #include "util/scoped.hh"
 #include "util/stream/timer.hh"

-#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
-#include <boost/lexical_cast.hpp>
-
 #include <sstream>

 #include <string.h>
@@ -28,71 +24,6 @@ VocabReconstitute::VocabReconstitute(int fd) {
   map_.push_back(i);
 }

-namespace {
-class OutputManager {
-  public:
-    static const std::size_t kOutBuf = 1048576;
-
-    // Does not take ownership of out.
-    explicit OutputManager(int out)
-      : buf_(util::MallocOrThrow(kOutBuf)),
-        builder_(static_cast<char*>(buf_.get()), kOutBuf),
-        // Mostly the default but with inf instead.  And no flags.
-        convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
-        fd_(out) {}
-
-    ~OutputManager() {
-      Flush();
-    }
-
-    OutputManager &operator<<(float value) {
-      // Odd, but this is the largest number found in the comments.
-      EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
-      convert_.ToShortestSingle(value, &builder_);
-      return *this;
-    }
-
-    OutputManager &operator<<(StringPiece str) {
-      if (str.size() > kOutBuf) {
-        Flush();
-        util::WriteOrThrow(fd_, str.data(), str.size());
-      } else {
-        EnsureRemaining(str.size());
-        builder_.AddSubstring(str.data(), str.size());
-      }
-      return *this;
-    }
-
-    // Inefficient!
-    OutputManager &operator<<(unsigned val) {
-      return *this << boost::lexical_cast<std::string>(val);
-    }
-
-    OutputManager &operator<<(char c) {
-      EnsureRemaining(1);
-      builder_.AddCharacter(c);
-      return *this;
-    }
-
-    void Flush() {
-      util::WriteOrThrow(fd_, buf_.get(), builder_.position());
-      builder_.Reset();
-    }
-
-  private:
-    void EnsureRemaining(std::size_t amount) {
-      if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) {
-        Flush();
-      }
-    }
-
-    util::scoped_malloc buf_;
-    double_conversion::StringBuilder builder_;
-    double_conversion::DoubleToStringConverter convert_;
-    int fd_;
-};
-} // namespace
-
 PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
   : vocab_(vocab), out_fd_(out_fd) {
   std::stringstream stream;
@@ -112,8 +43,9 @@ PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t>
 }

 void PrintARPA::Run(const ChainPositions &positions) {
+  util::scoped_fd closer(out_fd_);
   UTIL_TIMER("(%w s) Wrote ARPA file\n");
-  OutputManager out(out_fd_);
+  util::FakeOFStream out(out_fd_);
   for (unsigned order = 1; order <= positions.size(); ++order) {
     out << "\\" << order << "-grams:" << '\n';
     for (NGramStream stream(positions[order - 1]); stream; ++stream) {
diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh
index aa932e75..adbbb94a 100644
--- a/klm/lm/builder/print.hh
+++ b/klm/lm/builder/print.hh
@@ -88,7 +88,8 @@ template <class V> class Print {
 class PrintARPA {
   public:
-    // header_info may be NULL to disable the header
+    // header_info may be NULL to disable the header.
+    // Takes ownership of out_fd upon Run().
     explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);

     void Run(const ChainPositions &positions);
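Note: the deleted OutputManager explicitly did not take ownership of the descriptor; PrintARPA::Run() now closes out_fd_ through util::scoped_fd, matching the new header comment. A sketch of that RAII hand-off, with ScopedFd and WriteReport as hypothetical stand-ins for util::scoped_fd and the real writer:

```cpp
#include <unistd.h>
#include <cstdio>

// Close the descriptor on every exit path, including exceptions.
class ScopedFd {
 public:
  explicit ScopedFd(int fd) : fd_(fd) {}
  ~ScopedFd() { if (fd_ != -1) ::close(fd_); }
  int get() const { return fd_; }
 private:
  ScopedFd(const ScopedFd &);             // non-copyable
  ScopedFd &operator=(const ScopedFd &);
  int fd_;
};

// Mirrors the shape of PrintARPA::Run(): take ownership at the top, then
// write everything through the raw fd.
void WriteReport(int out_fd) {
  ScopedFd closer(out_fd);
  const char line[] = "\\1-grams:\n";
  if (::write(closer.get(), line, sizeof(line) - 1) < 0) std::perror("write");
}

int main() {
  WriteReport(::dup(1));  // hand over a duplicate of stdout
  return 0;
}
```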
