diff options
Diffstat (limited to 'klm/lm/builder')
25 files changed, 2252 insertions, 0 deletions
| diff --git a/klm/lm/builder/Makefile.am b/klm/lm/builder/Makefile.am new file mode 100644 index 00000000..b5c147fd --- /dev/null +++ b/klm/lm/builder/Makefile.am @@ -0,0 +1,28 @@ +bin_PROGRAMS = builder + +builder_SOURCES = \ +  main.cc \ +  adjust_counts.cc \ +  adjust_counts.hh \ +  corpus_count.cc \ +  corpus_count.hh \ +  discount.hh \ +  header_info.hh \ +  initial_probabilities.cc \ +  initial_probabilities.hh \ +  interpolate.cc \ +  interpolate.hh \ +  joint_order.hh \ +  multi_stream.hh \ +  ngram.hh \ +  ngram_stream.hh \ +  pipeline.cc \ +  pipeline.hh \ +  print.cc \ +  print.hh \ +  sort.hh + +builder_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_THREAD_LIBS) + +AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm + diff --git a/klm/lm/builder/README.md b/klm/lm/builder/README.md new file mode 100644 index 00000000..be0d35e2 --- /dev/null +++ b/klm/lm/builder/README.md @@ -0,0 +1,47 @@ +Dependencies +============ + +Boost >= 1.42.0 is required.   + +For Ubuntu, +```bash +sudo apt-get install libboost1.48-all-dev +``` + +Alternatively, you can download, compile, and install it yourself: + +```bash +wget http://sourceforge.net/projects/boost/files/boost/1.52.0/boost_1_52_0.tar.gz/download -O boost_1_52_0.tar.gz +tar -xvzf boost_1_52_0.tar.gz +cd boost_1_52_0 +./bootstrap.sh +./b2 +sudo ./b2 install +``` + +Local install options (in a user-space prefix directory) are also possible. See http://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html. + + +Building +======== + +```bash +bjam +``` +Your distribution might package bjam and boost-build separately from Boost.  Both are required.    + +Usage +===== + +Run +```bash +$ bin/lmplz +``` +to see command line arguments + +Running +======= + +```bash +bin/lmplz -o 5 <text >text.arpa +``` diff --git a/klm/lm/builder/TODO b/klm/lm/builder/TODO new file mode 100644 index 00000000..cb5aef3a --- /dev/null +++ b/klm/lm/builder/TODO @@ -0,0 +1,5 @@ +More tests! +Sharding. +Some way to manage all the crazy config options. +Option to build the binary file directly.   +Interpolation of different orders.   diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc new file mode 100644 index 00000000..a6f48011 --- /dev/null +++ b/klm/lm/builder/adjust_counts.cc @@ -0,0 +1,216 @@ +#include "lm/builder/adjust_counts.hh" +#include "lm/builder/multi_stream.hh" +#include "util/stream/timer.hh" + +#include <algorithm> + +namespace lm { namespace builder { + +BadDiscountException::BadDiscountException() throw() {} +BadDiscountException::~BadDiscountException() throw() {} + +namespace { +// Return last word in full that is different.   +const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) { +  const WordIndex *cur_word = full.end() - 1; +  const WordIndex *pre_word = lower_last.end() - 1; +  // Find last difference.   +  for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {} +  return cur_word; +} + +class StatCollector { +  public: +    StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts)  +      : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) { +      memset(&orders_[0], 0, sizeof(OrderStat) * order); +    } + +    ~StatCollector() {} + +    void CalculateDiscounts() { +      counts_.resize(orders_.size()); +      discounts_.resize(orders_.size()); +      for (std::size_t i = 0; i < orders_.size(); ++i) { +        const OrderStat &s = orders_[i]; +        counts_[i] = s.count; + +        for (unsigned j = 1; j < 4; ++j) { +          // TODO: Specialize error message for j == 3, meaning 3+ +          UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for " +              << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any " +              << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?"); +        } + +        // See equation (26) in Chen and Goodman. +        discounts_[i].amount[0] = 0.0; +        float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]); +        for (unsigned j = 1; j < 4; ++j) { +          discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]); +          UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]); +        } +      } +    } + +    void Add(std::size_t order_minus_1, uint64_t count) { +      OrderStat &stat = orders_[order_minus_1]; +      ++stat.count; +      if (count < 5) ++stat.n[count]; +    } + +    void AddFull(uint64_t count) { +      ++full_.count; +      if (count < 5) ++full_.n[count]; +    } + +  private: +    struct OrderStat { +      // n_1 in equation 26 of Chen and Goodman etc +      uint64_t n[5]; +      uint64_t count; +    }; + +    std::vector<OrderStat> orders_; +    OrderStat &full_; + +    std::vector<uint64_t> &counts_; +    std::vector<Discount> &discounts_; +}; + +// Reads all entries in order like NGramStream does.   +// But deletes any entries that have <s> in the 1st (not 0th) position on the +// way out by putting other entries in their place.  This disrupts the sort +// order but we don't care because the data is going to be sorted again.   +class CollapseStream { +  public: +    CollapseStream(const util::stream::ChainPosition &position) : +      current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), +      block_(position) { +      StartBlock(); +    } + +    const NGram &operator*() const { return current_; } +    const NGram *operator->() const { return ¤t_; } + +    operator bool() const { return block_; } + +    CollapseStream &operator++() { +      assert(block_); +      if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) { +        memcpy(current_.Base(), copy_from_, current_.TotalSize()); +        UpdateCopyFrom(); +      } +      current_.NextInMemory(); +      uint8_t *block_base = static_cast<uint8_t*>(block_->Get()); +      if (current_.Base() == block_base + block_->ValidSize()) { +        block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base); +        ++block_; +        StartBlock(); +      } +      return *this; +    } + +  private: +    void StartBlock() { +      for (; ; ++block_) { +        if (!block_) return; +        if (block_->ValidSize()) break; +      } +      current_.ReBase(block_->Get()); +      copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize(); +      UpdateCopyFrom(); +    } + +    // Find last without bos.   +    void UpdateCopyFrom() { +      for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) { +        if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break; +      } +    } + +    NGram current_; + +    // Goes backwards in the block +    uint8_t *copy_from_; + +    util::stream::Link block_; +}; + +} // namespace + +void AdjustCounts::Run(const ChainPositions &positions) { +  UTIL_TIMER("(%w s) Adjusted counts\n"); + +  const std::size_t order = positions.size(); +  StatCollector stats(order, counts_, discounts_); +  if (order == 1) { +    // Only unigrams.  Just collect stats.   +    for (NGramStream full(positions[0]); full; ++full)  +      stats.AddFull(full->Count()); +    stats.CalculateDiscounts(); +    return; +  } + +  NGramStreams streams; +  streams.Init(positions, positions.size() - 1); +  CollapseStream full(positions[positions.size() - 1]); + +  // Initialization: <unk> has count 0 and so does <s>.   +  NGramStream *lower_valid = streams.begin(); +  streams[0]->Count() = 0; +  *streams[0]->begin() = kUNK; +  stats.Add(0, 0); +  (++streams[0])->Count() = 0; +  *streams[0]->begin() = kBOS; +  // not in stats because it will get put in later.   + +  // iterate over full (the stream of the highest order ngrams) +  for (; full; ++full) { +    const WordIndex *different = FindDifference(*full, **lower_valid); +    std::size_t same = full->end() - 1 - different; +    // Increment the adjusted count.   +    if (same) ++streams[same - 1]->Count(); + +    // Output all the valid ones that changed.   +    for (; lower_valid >= &streams[same]; --lower_valid) { +      stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count()); +      ++*lower_valid; +    } + +    // This is here because bos is also const WordIndex *, so copy gets +    // consistent argument types.   +    const WordIndex *full_end = full->end(); +    // Initialize and mark as valid up to bos.   +    const WordIndex *bos; +    for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) { +      ++lower_valid; +      std::copy(bos, full_end, (*lower_valid)->begin()); +      (*lower_valid)->Count() = 1; +    } +    // Now bos indicates where <s> is or is the 0th word of full.   +    if (bos != full->begin()) { +      // There is an <s> beyond the 0th word.   +      NGramStream &to = *++lower_valid; +      std::copy(bos, full_end, to->begin()); +      to->Count() = full->Count();   +    } else { +      stats.AddFull(full->Count()); +    } +    assert(lower_valid >= &streams[0]); +  } + +  // Output everything valid. +  for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { +    stats.Add(s - streams.begin(), (*s)->Count()); +    ++*s; +  } +  // Poison everyone!  Except the N-grams which were already poisoned by the input.    +  for (NGramStream *s = streams.begin(); s != streams.end(); ++s) +    s->Poison(); + +  stats.CalculateDiscounts(); + +  // NOTE: See special early-return case for unigrams near the top of this function +} + +}} // namespaces diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh new file mode 100644 index 00000000..f38ff79d --- /dev/null +++ b/klm/lm/builder/adjust_counts.hh @@ -0,0 +1,44 @@ +#ifndef LM_BUILDER_ADJUST_COUNTS__ +#define LM_BUILDER_ADJUST_COUNTS__ + +#include "lm/builder/discount.hh" +#include "util/exception.hh" + +#include <vector> + +#include <stdint.h> + +namespace lm { +namespace builder { + +class ChainPositions; + +class BadDiscountException : public util::Exception { +  public: +    BadDiscountException() throw(); +    ~BadDiscountException() throw(); +}; + +/* Compute adjusted counts.   + * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. + * Output: [1,N]-grams with adjusted counts.   + * [1,N)-grams are in suffix order + * N-grams are in undefined order (they're going to be sorted anyway). + */ +class AdjustCounts { +  public: +    AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts) +      : counts_(counts), discounts_(discounts) {} + +    void Run(const ChainPositions &positions); + +  private: +    std::vector<uint64_t> &counts_; +    std::vector<Discount> &discounts_; +}; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_ADJUST_COUNTS__ + diff --git a/klm/lm/builder/adjust_counts_test.cc b/klm/lm/builder/adjust_counts_test.cc new file mode 100644 index 00000000..68b5f33e --- /dev/null +++ b/klm/lm/builder/adjust_counts_test.cc @@ -0,0 +1,106 @@ +#include "lm/builder/adjust_counts.hh" + +#include "lm/builder/multi_stream.hh" +#include "util/scoped.hh" + +#include <boost/thread/thread.hpp> +#define BOOST_TEST_MODULE AdjustCounts +#include <boost/test/unit_test.hpp> + +namespace lm { namespace builder { namespace { + +class KeepCopy { +  public: +    KeepCopy() : size_(0) {} + +    void Run(const util::stream::ChainPosition &position) { +      for (util::stream::Link link(position); link; ++link) { +        mem_.call_realloc(size_ + link->ValidSize()); +        memcpy(static_cast<uint8_t*>(mem_.get()) + size_, link->Get(), link->ValidSize()); +        size_ += link->ValidSize(); +      } +    } + +    uint8_t *Get() { return static_cast<uint8_t*>(mem_.get()); } +    std::size_t Size() const { return size_; } + +  private: +    util::scoped_malloc mem_; +    std::size_t size_; +}; + +struct Gram4 { +  WordIndex ids[4]; +  uint64_t count; +}; + +class WriteInput { +  public: +    void Run(const util::stream::ChainPosition &position) { +      NGramStream input(position); +      Gram4 grams[] = { +        {{0,0,0,0},10}, +        {{0,0,3,0},3}, +        // bos +        {{1,1,1,2},5}, +        {{0,0,3,2},5}, +      }; +      for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) { +        memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4); +        input->Count() = grams[i].count; +      } +      input.Poison(); +    } +}; + +BOOST_AUTO_TEST_CASE(Simple) { +  KeepCopy outputs[4]; +  std::vector<uint64_t> counts; +  std::vector<Discount> discount; +  { +    util::stream::ChainConfig config; +    config.total_memory = 100; +    config.block_count = 1; +    Chains chains(4); +    for (unsigned i = 0; i < 4; ++i) { +      config.entry_size = NGram::TotalSize(i + 1); +      chains.push_back(config); +    } + +    chains[3] >> WriteInput(); +    ChainPositions for_adjust(chains); +    for (unsigned i = 0; i < 4; ++i) { +      chains[i] >> boost::ref(outputs[i]); +    } +    chains >> util::stream::kRecycle; +    BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException); +  } +  BOOST_REQUIRE_EQUAL(4UL, counts.size()); +  BOOST_CHECK_EQUAL(4UL, counts[0]); +  // These are no longer set because the discounts are bad.   +/*  BOOST_CHECK_EQUAL(4UL, counts[1]); +  BOOST_CHECK_EQUAL(3UL, counts[2]); +  BOOST_CHECK_EQUAL(3UL, counts[3]);*/ +  BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size()); +  NGram uni(outputs[0].Get(), 1); +  BOOST_CHECK_EQUAL(kUNK, *uni.begin()); +  BOOST_CHECK_EQUAL(0ULL, uni.Count()); +  uni.NextInMemory(); +  BOOST_CHECK_EQUAL(kBOS, *uni.begin()); +  BOOST_CHECK_EQUAL(0ULL, uni.Count()); +  uni.NextInMemory(); +  BOOST_CHECK_EQUAL(0UL, *uni.begin()); +  BOOST_CHECK_EQUAL(2ULL, uni.Count()); +  uni.NextInMemory(); +  BOOST_CHECK_EQUAL(2ULL, uni.Count()); +  BOOST_CHECK_EQUAL(2UL, *uni.begin()); + +  BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size()); +  NGram bi(outputs[1].Get(), 2); +  BOOST_CHECK_EQUAL(0UL, *bi.begin()); +  BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1)); +  BOOST_CHECK_EQUAL(1ULL, bi.Count()); +  bi.NextInMemory(); +} + +}}} // namespaces diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc new file mode 100644 index 00000000..abea4ed0 --- /dev/null +++ b/klm/lm/builder/corpus_count.cc @@ -0,0 +1,224 @@ +#include "lm/builder/corpus_count.hh" + +#include "lm/builder/ngram.hh" +#include "lm/lm_exception.hh" +#include "lm/word_index.hh" +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/murmur_hash.hh" +#include "util/probing_hash_table.hh" +#include "util/scoped.hh" +#include "util/stream/chain.hh" +#include "util/stream/timer.hh" +#include "util/tokenize_piece.hh" + +#include <boost/unordered_set.hpp> +#include <boost/unordered_map.hpp> + +#include <functional> + +#include <stdint.h> + +namespace lm { +namespace builder { +namespace { + +class VocabHandout { +  public: +    explicit VocabHandout(int fd) { +      util::scoped_fd duped(util::DupOrThrow(fd)); +      word_list_.reset(util::FDOpenOrThrow(duped)); +       +      Lookup("<unk>"); // Force 0 +      Lookup("<s>"); // Force 1 +      Lookup("</s>"); // Force 2 +    } + +    WordIndex Lookup(const StringPiece &word) { +      uint64_t hashed = util::MurmurHashNative(word.data(), word.size()); +      std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size()))); +      if (ret.second) { +        char null_delimit = 0; +        util::WriteOrThrow(word_list_.get(), word.data(), word.size()); +        util::WriteOrThrow(word_list_.get(), &null_delimit, 1); +        UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words.  Change WordIndex to uint64_t in lm/word_index.hh."); +      } +      return ret.first->second; +    } + +    WordIndex Size() const { +      return seen_.size(); +    } + +  private: +    typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen; + +    Seen seen_; + +    util::scoped_FILE word_list_; +}; + +class DedupeHash : public std::unary_function<const WordIndex *, bool> { +  public: +    explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} + +    std::size_t operator()(const WordIndex *start) const { +      return util::MurmurHashNative(start, size_); +    } +     +  private: +    const std::size_t size_; +}; + +class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> { +  public: +    explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {} +     +    bool operator()(const WordIndex *first, const WordIndex *second) const { +      return !memcmp(first, second, size_); +    }  +     +  private: +    const std::size_t size_; +}; + +struct DedupeEntry { +  typedef WordIndex *Key; +  Key GetKey() const { return key; } +  Key key; +  static DedupeEntry Construct(WordIndex *at) { +    DedupeEntry ret; +    ret.key = at; +    return ret; +  } +}; + +typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe; + +const float kProbingMultiplier = 1.5; + +class Writer { +  public: +    Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)  +      : block_(position), gram_(block_->Get(), order), +        dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()), +        dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)), +        buffer_(new WordIndex[order - 1]), +        block_size_(position.GetChain().BlockSize()) { +      dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); +      assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size); +      if (order == 1) { +        // Add special words.  AdjustCounts is responsible if order != 1.     +        AddUnigramWord(kUNK); +        AddUnigramWord(kBOS); +      } +    } + +    ~Writer() { +      block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get())); +      (++block_).Poison(); +    } + +    // Write context with a bunch of <s> +    void StartSentence() { +      for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) { +        *i = kBOS; +      } +    } + +    void Append(WordIndex word) { +      *(gram_.end() - 1) = word; +      Dedupe::MutableIterator at; +      bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at); +      if (found) { +        // Already present. +        NGram already(at->key, gram_.Order()); +        ++(already.Count()); +        // Shift left by one. +        memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1)); +        return; +      } +      // Complete the write.   +      gram_.Count() = 1; +      // Prepare the next n-gram.   +      if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) { +        NGram last(gram_); +        gram_.NextInMemory(); +        std::copy(last.begin() + 1, last.end(), gram_.begin()); +        return; +      } +      // Block end.  Need to store the context in a temporary buffer.   +      std::copy(gram_.begin() + 1, gram_.end(), buffer_.get()); +      dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); +      block_->SetValidSize(block_size_); +      gram_.ReBase((++block_)->Get()); +      std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin()); +    } + +  private: +    void AddUnigramWord(WordIndex index) { +      *gram_.begin() = index; +      gram_.Count() = 0; +      gram_.NextInMemory(); +      if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) { +        block_->SetValidSize(block_size_); +        gram_.ReBase((++block_)->Get()); +      } +    } + +    util::stream::Link block_; + +    NGram gram_; + +    // This is the memory behind the invalid value in dedupe_. +    std::vector<WordIndex> dedupe_invalid_; +    // Hash table combiner implementation. +    Dedupe dedupe_; + +    // Small buffer to hold existing ngrams when shifting across a block boundary.   +    boost::scoped_array<WordIndex> buffer_; + +    const std::size_t block_size_; +}; + +} // namespace + +float CorpusCount::DedupeMultiplier(std::size_t order) { +  return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order)); +} + +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)  +  : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), +    dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), +    dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { +  token_count_ = 0; +  type_count_ = 0; +} + +void CorpusCount::Run(const util::stream::ChainPosition &position) { +  UTIL_TIMER("(%w s) Counted n-grams\n"); + +  VocabHandout vocab(vocab_write_); +  const WordIndex end_sentence = vocab.Lookup("</s>"); +  Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); +  uint64_t count = 0; +  StringPiece delimiters("\0\t\r ", 4); +  try { +    while(true) { +      StringPiece line(from_.ReadLine()); +      writer.StartSentence(); +      for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) { +        WordIndex word = vocab.Lookup(*w); +        UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus.  I plan to support models containing <unk> in the future."); +        writer.Append(word); +        ++count; +      } +      writer.Append(end_sentence); +    } +  } catch (const util::EndOfFileException &e) {} +  token_count_ = count; +  type_count_ = vocab.Size(); +} + +} // namespace builder +} // namespace lm diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh new file mode 100644 index 00000000..e255bad1 --- /dev/null +++ b/klm/lm/builder/corpus_count.hh @@ -0,0 +1,42 @@ +#ifndef LM_BUILDER_CORPUS_COUNT__ +#define LM_BUILDER_CORPUS_COUNT__ + +#include "lm/word_index.hh" +#include "util/scoped.hh" + +#include <cstddef> +#include <string> +#include <stdint.h> + +namespace util { +class FilePiece; +namespace stream { +class ChainPosition; +} // namespace stream +} // namespace util + +namespace lm { +namespace builder { + +class CorpusCount { +  public: +    // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size +    static float DedupeMultiplier(std::size_t order); + +    CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block); + +    void Run(const util::stream::ChainPosition &position); + +  private: +    util::FilePiece &from_; +    int vocab_write_; +    uint64_t &token_count_; +    WordIndex &type_count_; + +    std::size_t dedupe_mem_size_; +    util::scoped_malloc dedupe_mem_; +}; + +} // namespace builder +} // namespace lm +#endif // LM_BUILDER_CORPUS_COUNT__ diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc new file mode 100644 index 00000000..8d53ca9d --- /dev/null +++ b/klm/lm/builder/corpus_count_test.cc @@ -0,0 +1,76 @@ +#include "lm/builder/corpus_count.hh" + +#include "lm/builder/ngram.hh" +#include "lm/builder/ngram_stream.hh" + +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/tokenize_piece.hh" +#include "util/stream/chain.hh" +#include "util/stream/stream.hh" + +#define BOOST_TEST_MODULE CorpusCountTest +#include <boost/test/unit_test.hpp> + +namespace lm { namespace builder { namespace { + +#define Check(str, count) { \ +  BOOST_REQUIRE(stream); \ +  w = stream->begin(); \ +  for (util::TokenIter<util::AnyCharacter, true> t(str, " "); t; ++t, ++w) { \ +    BOOST_CHECK_EQUAL(*t, v[*w]); \ +  } \ +  BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \ +  ++stream; \ +} + +BOOST_AUTO_TEST_CASE(Short) { +  util::scoped_fd input_file(util::MakeTemp("corpus_count_test_temp")); +  const char input[] = "looking on a little more loin\non a little more loin\non foo little more loin\nbar\n\n"; +  // Blocks of 10 are +  // looking on a little more loin </s> on a little[duplicate] more[duplicate] loin[duplicate] </s>[duplicate] on[duplicate] foo +  // little more loin </s> bar </s> </s> + +  util::WriteOrThrow(input_file.get(), input, sizeof(input) - 1); +  util::FilePiece input_piece(input_file.release(), "temp file"); + +  util::stream::ChainConfig config; +  config.entry_size = NGram::TotalSize(3); +  config.total_memory = config.entry_size * 20; +  config.block_count = 2; + +  util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab")); + +  util::stream::Chain chain(config); +  NGramStream stream; +  uint64_t token_count; +  WordIndex type_count; +  CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize()); +  chain >> boost::ref(counter) >> stream >> util::stream::kRecycle; + +  const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"}; + +  WordIndex *w; + +  Check("<s> <s> looking", 1); +  Check("<s> looking on", 1); +  Check("looking on a", 1); +  Check("on a little", 2); +  Check("a little more", 2); +  Check("little more loin", 2); +  Check("more loin </s>", 2); +  Check("<s> <s> on", 2); +  Check("<s> on a", 1); +  Check("<s> on foo", 1); +  Check("on foo little", 1); +  Check("foo little more", 1); +  Check("little more loin", 1); +  Check("more loin </s>", 1); +  Check("<s> <s> bar", 1); +  Check("<s> bar </s>", 1); +  Check("<s> <s> </s>", 1); +  BOOST_CHECK(!stream); +  BOOST_CHECK_EQUAL(sizeof(v) / sizeof(const char*), type_count); +} + +}}} // namespaces diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh new file mode 100644 index 00000000..754fb20d --- /dev/null +++ b/klm/lm/builder/discount.hh @@ -0,0 +1,26 @@ +#ifndef BUILDER_DISCOUNT__ +#define BUILDER_DISCOUNT__ + +#include <algorithm> + +#include <inttypes.h> + +namespace lm { +namespace builder { + +struct Discount { +  float amount[4]; + +  float Get(uint64_t count) const { +    return amount[std::min<uint64_t>(count, 3)]; +  } + +  float Apply(uint64_t count) const { +    return static_cast<float>(count) - Get(count); +  } +}; + +} // namespace builder +} // namespace lm + +#endif // BUILDER_DISCOUNT__ diff --git a/klm/lm/builder/header_info.hh b/klm/lm/builder/header_info.hh new file mode 100644 index 00000000..ccca1456 --- /dev/null +++ b/klm/lm/builder/header_info.hh @@ -0,0 +1,20 @@ +#ifndef LM_BUILDER_HEADER_INFO__ +#define LM_BUILDER_HEADER_INFO__ + +#include <string> +#include <stdint.h> + +// Some configuration info that is used to add +// comments to the beginning of an ARPA file +struct HeaderInfo { +  const std::string input_file; +  const uint64_t token_count; + +  HeaderInfo(const std::string& input_file_in, uint64_t token_count_in) +    : input_file(input_file_in), token_count(token_count_in) {} + +  // TODO: Add smoothing type +  // TODO: More info if multiple models were interpolated +}; + +#endif diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc new file mode 100644 index 00000000..58b42a20 --- /dev/null +++ b/klm/lm/builder/initial_probabilities.cc @@ -0,0 +1,136 @@ +#include "lm/builder/initial_probabilities.hh" + +#include "lm/builder/discount.hh" +#include "lm/builder/ngram_stream.hh" +#include "lm/builder/sort.hh" +#include "util/file.hh" +#include "util/stream/chain.hh" +#include "util/stream/io.hh" +#include "util/stream/stream.hh" + +#include <vector> + +namespace lm { namespace builder { + +namespace { +struct BufferEntry { +  // Gamma from page 20 of Chen and Goodman.   +  float gamma; +  // \sum_w a(c w) for all w.   +  float denominator; +}; + +// Extract an array of gamma from an array of BufferEntry.   +class OnlyGamma { +  public: +    void Run(const util::stream::ChainPosition &position) { +      for (util::stream::Link block_it(position); block_it; ++block_it) { +        float *out = static_cast<float*>(block_it->Get()); +        const float *in = out; +        const float *end = static_cast<const float*>(block_it->ValidEnd()); +        for (out += 1, in += 2; in < end; out += 1, in += 2) { +          *out = *in; +        } +        block_it->SetValidSize(block_it->ValidSize() / 2); +      } +    } +}; + +class AddRight { +  public: +    AddRight(const Discount &discount, const util::stream::ChainPosition &input)  +      : discount_(discount), input_(input) {} + +    void Run(const util::stream::ChainPosition &output) { +      NGramStream in(input_); +      util::stream::Stream out(output); + +      std::vector<WordIndex> previous(in->Order() - 1); +      const std::size_t size = sizeof(WordIndex) * previous.size(); +      for(; in; ++out) { +        memcpy(&previous[0], in->begin(), size); +        uint64_t denominator = 0; +        uint64_t counts[4]; +        memset(counts, 0, sizeof(counts)); +        do { +          denominator += in->Count(); +          ++counts[std::min(in->Count(), static_cast<uint64_t>(3))]; +        } while (++in && !memcmp(&previous[0], in->begin(), size)); +        BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get()); +        entry.denominator = static_cast<float>(denominator); +        entry.gamma = 0.0; +        for (unsigned i = 1; i <= 3; ++i) { +          entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]); +        } +        entry.gamma /= entry.denominator; +      } +      out.Poison(); +    } + +  private: +    const Discount &discount_; +    const util::stream::ChainPosition input_; +}; + +class MergeRight { +  public: +    MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount) +      : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {} + +    // calculate the initial probability of each n-gram (before order-interpolation) +    // Run() gets invoked once for each order +    void Run(const util::stream::ChainPosition &primary) { +      util::stream::Stream summed(from_adder_); + +      NGramStream grams(primary); + +      // Without interpolation, the interpolation weight goes to <unk>. +      if (grams->Order() == 1 && !interpolate_unigrams_) { +        BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get())); +        assert(*grams->begin() == kUNK); +        grams->Value().uninterp.prob = sums.gamma; +        grams->Value().uninterp.gamma = 0.0; +        while (++grams) { +          grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator; +          grams->Value().uninterp.gamma = 0.0; +        } +        ++summed; +        return; +      } + +      std::vector<WordIndex> previous(grams->Order() - 1); +      const std::size_t size = sizeof(WordIndex) * previous.size(); +      for (; grams; ++summed) { +        memcpy(&previous[0], grams->begin(), size); +        const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get()); +        do { +          Payload &pay = grams->Value(); +          pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator; +          pay.uninterp.gamma = sums.gamma; +        } while (++grams && !memcmp(&previous[0], grams->begin(), size)); +      } +    } + +  private: +    bool interpolate_unigrams_; +    util::stream::ChainPosition from_adder_; +    Discount discount_; +}; + +} // namespace + +void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) { +  util::stream::ChainConfig gamma_config = config.adder_out; +  gamma_config.entry_size = sizeof(BufferEntry); +  for (size_t i = 0; i < primary.size(); ++i) { +    util::stream::ChainPosition second(second_in[i].Add()); +    second_in[i] >> util::stream::kRecycle; +    gamma_out.push_back(gamma_config); +    gamma_out[i] >> AddRight(discounts[i], second); +    primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); +    // Don't bother with the OnlyGamma thread for something to discard.   +    if (i) gamma_out[i] >> OnlyGamma(); +  } +} + +}} // namespaces diff --git a/klm/lm/builder/initial_probabilities.hh b/klm/lm/builder/initial_probabilities.hh new file mode 100644 index 00000000..626388eb --- /dev/null +++ b/klm/lm/builder/initial_probabilities.hh @@ -0,0 +1,34 @@ +#ifndef LM_BUILDER_INITIAL_PROBABILITIES__ +#define LM_BUILDER_INITIAL_PROBABILITIES__ + +#include "lm/builder/discount.hh" +#include "util/stream/config.hh" + +#include <vector> + +namespace lm { +namespace builder { +class Chains; + +struct InitialProbabilitiesConfig { +  // These should be small buffers to keep the adder from getting too far ahead +  util::stream::ChainConfig adder_in; +  util::stream::ChainConfig adder_out; +  // SRILM doesn't normally interpolate unigrams.   +  bool interpolate_unigrams; +}; + +/* Compute initial (uninterpolated) probabilities + * primary: the normal chain of n-grams.  Incoming is context sorted adjusted + *   counts.  Outgoing has uninterpolated probabilities for use by Interpolate. + * second_in: a second copy of the primary input.  Discard the output.   + * gamma_out: Computed gamma values are output on these chains in suffix order. + *   The values are bare floats and should be buffered for interpolation to + *   use.   + */ +void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out); + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_INITIAL_PROBABILITIES__ diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc new file mode 100644 index 00000000..50026806 --- /dev/null +++ b/klm/lm/builder/interpolate.cc @@ -0,0 +1,65 @@ +#include "lm/builder/interpolate.hh" + +#include "lm/builder/joint_order.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/sort.hh" +#include "lm/lm_exception.hh" + +#include <assert.h> + +namespace lm { namespace builder { +namespace { + +class Callback { +  public: +    Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { +      probs_[0] = uniform_prob; +      for (std::size_t i = 0; i < backoffs.size(); ++i) { +        backoffs_.push_back(backoffs[i]); +      } +    } + +    ~Callback() { +      for (std::size_t i = 0; i < backoffs_.size(); ++i) { +        if (backoffs_[i]) { +          std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl; +          abort(); +        } +      } +    } + +    void Enter(unsigned order_minus_1, NGram &gram) { +      Payload &pay = gram.Value(); +      pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; +      probs_[order_minus_1 + 1] = pay.complete.prob; +      pay.complete.prob = log10(pay.complete.prob); +      // TODO: this is a hack to skip n-grams that don't appear as context.  Pruning will require some different handling.   +      if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { +        pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get())); +        ++backoffs_[order_minus_1]; +      } else { +        // Not a context.   +        pay.complete.backoff = 0.0; +      } +    } + +    void Exit(unsigned, const NGram &) const {} + +  private: +    FixedArray<util::stream::Stream> backoffs_; + +    std::vector<float> probs_; +}; +} // namespace + +Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)  +  : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {} + +// perform order-wise interpolation +void Interpolate::Run(const ChainPositions &positions) { +  assert(positions.size() == backoffs_.size() + 1); +  Callback callback(uniform_prob_, backoffs_); +  JointOrder<Callback, SuffixOrder>(positions, callback); +} + +}} // namespaces diff --git a/klm/lm/builder/interpolate.hh b/klm/lm/builder/interpolate.hh new file mode 100644 index 00000000..9268d404 --- /dev/null +++ b/klm/lm/builder/interpolate.hh @@ -0,0 +1,27 @@ +#ifndef LM_BUILDER_INTERPOLATE__ +#define LM_BUILDER_INTERPOLATE__ + +#include <stdint.h> + +#include "lm/builder/multi_stream.hh" + +namespace lm { namespace builder { +  +/* Interpolate step.   + * Input: suffix sorted n-grams with (p_uninterpolated, gamma) from + * InitialProbabilities. + * Output: suffix sorted n-grams with complete probability + */ +class Interpolate { +  public: +    explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs); + +    void Run(const ChainPositions &positions); + +  private: +    float uniform_prob_; +    ChainPositions backoffs_; +}; + +}} // namespaces +#endif // LM_BUILDER_INTERPOLATE__ diff --git a/klm/lm/builder/joint_order.hh b/klm/lm/builder/joint_order.hh new file mode 100644 index 00000000..b5620144 --- /dev/null +++ b/klm/lm/builder/joint_order.hh @@ -0,0 +1,43 @@ +#ifndef LM_BUILDER_JOINT_ORDER__ +#define LM_BUILDER_JOINT_ORDER__ + +#include "lm/builder/multi_stream.hh" +#include "lm/lm_exception.hh" + +#include <string.h> + +namespace lm { namespace builder { + +template <class Callback, class Compare> void JointOrder(const ChainPositions &positions, Callback &callback) { +  // Allow matching to reference streams[-1]. +  NGramStreams streams_with_dummy; +  streams_with_dummy.InitWithDummy(positions); +  NGramStream *streams = streams_with_dummy.begin() + 1; + +  unsigned int order; +  for (order = 0; order < positions.size() && streams[order]; ++order) {} +  assert(order); // should always have <unk>. +  unsigned int current = 0; +  while (true) { +    // Does the context match the lower one?   +    if (!memcmp(streams[static_cast<int>(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) { +      callback.Enter(current, *streams[current]); +      // Transition to looking for extensions.   +      if (++current < order) continue; +    } +    // No extension left.   +    while(true) { +      assert(current > 0); +      --current; +      callback.Exit(current, *streams[current]); +      if (++streams[current]) break; +      UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix"); +      order = current; +      if (!order) return; +    } +  } +} + +}} // namespaces + +#endif // LM_BUILDER_JOINT_ORDER__ diff --git a/klm/lm/builder/main.cc b/klm/lm/builder/main.cc new file mode 100644 index 00000000..90b9dca2 --- /dev/null +++ b/klm/lm/builder/main.cc @@ -0,0 +1,94 @@ +#include "lm/builder/pipeline.hh" +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/usage.hh" + +#include <iostream> + +#include <boost/program_options.hpp> + +namespace { +class SizeNotify { +  public: +    SizeNotify(std::size_t &out) : behind_(out) {} + +    void operator()(const std::string &from) { +      behind_ = util::ParseSize(from); +    } + +  private: +    std::size_t &behind_; +}; + +boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) { +  return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value); +} + +} // namespace + +int main(int argc, char *argv[]) { +  try { +    namespace po = boost::program_options; +    po::options_description options("Language model building options"); +    lm::builder::PipelineConfig pipeline; + +    options.add_options() +      ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model") +      ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") +      ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") +      ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") +      ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step") +      ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") +      ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") +      ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)") +      ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") +      ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc."); +    if (argc == 1) { +      std::cerr <<  +        "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n" +        "Please cite:\n" +        "@inproceedings{kenlm,\n" +        "author    = {Kenneth Heafield},\n" +        "title     = {{KenLM}: Faster and Smaller Language Model Queries},\n" +        "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n" +        "month     = {July}, year={2011},\n" +        "address   = {Edinburgh, UK},\n" +        "publisher = {Association for Computational Linguistics},\n" +        "}\n\n" +        "Provide the corpus on stdin.  The ARPA file will be written to stdout.  Order of\n" +        "the model (-o) is the only mandatory option.  As this is an on-disk program,\n" +        "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n" +        "Memory sizes are specified like GNU sort: a number followed by a unit character.\n" +        "Valid units are \% for percentage of memory (supported platforms only) and (in\n" +        "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y.  Default is K (*1024).\n\n"; +      std::cerr << options << std::endl; +      return 1; +    } +    po::variables_map vm; +    po::store(po::parse_command_line(argc, argv, options), vm); +    po::notify(vm); + +    util::NormalizeTempPrefix(pipeline.sort.temp_prefix); + +    lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; +    // TODO: evaluate options for these.   +    initial.adder_in.total_memory = 32768; +    initial.adder_in.block_count = 2; +    initial.adder_out.total_memory = 32768; +    initial.adder_out.block_count = 2; +    pipeline.read_backoffs = initial.adder_out; + +    // Read from stdin +    try { +      lm::builder::Pipeline(pipeline, 0, 1); +    } catch (const util::MallocException &e) { +      std::cerr << e.what() << std::endl; +      std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl; +      return 1; +    } +    util::PrintUsage(std::cerr); +  } catch (const std::exception &e) { +    std::cerr << e.what() << std::endl; +    return 1; +  } +} diff --git a/klm/lm/builder/multi_stream.hh b/klm/lm/builder/multi_stream.hh new file mode 100644 index 00000000..707a98c7 --- /dev/null +++ b/klm/lm/builder/multi_stream.hh @@ -0,0 +1,180 @@ +#ifndef LM_BUILDER_MULTI_STREAM__ +#define LM_BUILDER_MULTI_STREAM__ + +#include "lm/builder/ngram_stream.hh" +#include "util/scoped.hh" +#include "util/stream/chain.hh" + +#include <cstddef> +#include <new> + +#include <assert.h> +#include <stdlib.h> + +namespace lm { namespace builder { + +template <class T> class FixedArray { +  public: +    explicit FixedArray(std::size_t count) { +      Init(count); +    } + +    FixedArray() : newed_end_(NULL) {} + +    void Init(std::size_t count) { +      assert(!block_.get()); +      block_.reset(malloc(sizeof(T) * count)); +      if (!block_.get()) throw std::bad_alloc(); +      newed_end_ = begin(); +    } + +    FixedArray(const FixedArray &from) { +      std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get()); +      Init(size); +      for (std::size_t i = 0; i < size; ++i) { +        new(end()) T(from[i]); +        Constructed(); +      } +    } + +    ~FixedArray() { clear(); } + +    T *begin() { return static_cast<T*>(block_.get()); } +    const T *begin() const { return static_cast<const T*>(block_.get()); } +    // Always call Constructed after successful completion of new.   +    T *end() { return newed_end_; } +    const T *end() const { return newed_end_; } + +    T &back() { return *(end() - 1); } +    const T &back() const { return *(end() - 1); } + +    std::size_t size() const { return end() - begin(); } +    bool empty() const { return begin() == end(); } + +    T &operator[](std::size_t i) { return begin()[i]; } +    const T &operator[](std::size_t i) const { return begin()[i]; } + +    template <class C> void push_back(const C &c) { +      new (end()) T(c); +      Constructed(); +    } + +    void clear() { +      for (T *i = begin(); i != end(); ++i) +        i->~T(); +      newed_end_ = begin(); +    } + +  protected: +    void Constructed() { +      ++newed_end_; +    } + +  private: +    util::scoped_malloc block_; + +    T *newed_end_; +}; + +class Chains; + +class ChainPositions : public FixedArray<util::stream::ChainPosition> { +  public: +    ChainPositions() {} + +    void Init(Chains &chains); + +    explicit ChainPositions(Chains &chains) { +      Init(chains); +    } +}; + +class Chains : public FixedArray<util::stream::Chain> { +  private: +    template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun { +      typedef Chains type; +    }; + +  public: +    explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {} + +    template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) { +      threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); +      return *this; +    } + +    template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) { +      threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); +      return *this; +    } + +    Chains &operator>>(const util::stream::Recycler &recycler) { +      for (util::stream::Chain *i = begin(); i != end(); ++i)  +        *i >> recycler; +      return *this; +    } + +    void Wait(bool release_memory = true) { +      threads_.clear(); +      for (util::stream::Chain *i = begin(); i != end(); ++i) { +        i->Wait(release_memory); +      } +    } + +  private: +    boost::ptr_vector<util::stream::Thread> threads_; + +    Chains(const Chains &); +    void operator=(const Chains &); +}; + +inline void ChainPositions::Init(Chains &chains) { +  FixedArray<util::stream::ChainPosition>::Init(chains.size()); +  for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) { +    new (end()) util::stream::ChainPosition(i->Add()); Constructed(); +  } +} + +inline Chains &operator>>(Chains &chains, ChainPositions &positions) { +  positions.Init(chains); +  return chains; +} + +class NGramStreams : public FixedArray<NGramStream> { +  public: +    NGramStreams() {} + +    // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning). +    void InitWithDummy(const ChainPositions &positions) { +      FixedArray<NGramStream>::Init(positions.size() + 1); +      new (end()) NGramStream(); Constructed(); +      for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) { +        push_back(*i); +      } +    } + +    // Limit restricts to positions[0,limit) +    void Init(const ChainPositions &positions, std::size_t limit) { +      FixedArray<NGramStream>::Init(limit); +      for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) { +        push_back(*i); +      } +    } +    void Init(const ChainPositions &positions) { +      Init(positions, positions.size()); +    } + +    NGramStreams(const ChainPositions &positions) { +      Init(positions); +    } +}; + +inline Chains &operator>>(Chains &chains, NGramStreams &streams) { +  ChainPositions positions; +  chains >> positions; +  streams.Init(positions); +  return chains; +} + +}} // namespaces +#endif // LM_BUILDER_MULTI_STREAM__ diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh new file mode 100644 index 00000000..2984ed0b --- /dev/null +++ b/klm/lm/builder/ngram.hh @@ -0,0 +1,84 @@ +#ifndef LM_BUILDER_NGRAM__ +#define LM_BUILDER_NGRAM__ + +#include "lm/weights.hh" +#include "lm/word_index.hh" + +#include <cstddef> + +#include <assert.h> +#include <stdint.h> +#include <string.h> + +namespace lm { +namespace builder { + +struct Uninterpolated { +  float prob;  // Uninterpolated probability. +  float gamma; // Interpolation weight for lower order. +}; + +union Payload { +  uint64_t count; +  Uninterpolated uninterp; +  ProbBackoff complete; +}; + +class NGram { +  public: +    NGram(void *begin, std::size_t order)  +      : begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {} + +    const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); } +    uint8_t *Base() { return reinterpret_cast<uint8_t*>(begin_); } + +    void ReBase(void *to) { +      std::size_t difference = end_ - begin_; +      begin_ = reinterpret_cast<WordIndex*>(to); +      end_ = begin_ + difference; +    } + +    // Would do operator++ but that can get confusing for a stream.   +    void NextInMemory() { +      ReBase(&Value() + 1); +    } + +    // Lower-case in deference to STL.   +    const WordIndex *begin() const { return begin_; } +    WordIndex *begin() { return begin_; } +    const WordIndex *end() const { return end_; } +    WordIndex *end() { return end_; } + +    const Payload &Value() const { return *reinterpret_cast<const Payload *>(end_); } +    Payload &Value() { return *reinterpret_cast<Payload *>(end_); } + +    uint64_t &Count() { return Value().count; } +    const uint64_t Count() const { return Value().count; } + +    std::size_t Order() const { return end_ - begin_; } + +    static std::size_t TotalSize(std::size_t order) { +      return order * sizeof(WordIndex) + sizeof(Payload); +    } +    std::size_t TotalSize() const { +      // Compiler should optimize this.   +      return TotalSize(Order()); +    } +    static std::size_t OrderFromSize(std::size_t size) { +      std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex); +      assert(size == TotalSize(ret)); +      return ret; +    } + +  private: +    WordIndex *begin_, *end_; +}; + +const WordIndex kUNK = 0; +const WordIndex kBOS = 1; +const WordIndex kEOS = 2; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_NGRAM__ diff --git a/klm/lm/builder/ngram_stream.hh b/klm/lm/builder/ngram_stream.hh new file mode 100644 index 00000000..3c994664 --- /dev/null +++ b/klm/lm/builder/ngram_stream.hh @@ -0,0 +1,55 @@ +#ifndef LM_BUILDER_NGRAM_STREAM__ +#define LM_BUILDER_NGRAM_STREAM__ + +#include "lm/builder/ngram.hh" +#include "util/stream/chain.hh" +#include "util/stream/stream.hh" + +#include <cstddef> + +namespace lm { namespace builder { + +class NGramStream { +  public: +    NGramStream() : gram_(NULL, 0) {} + +    NGramStream(const util::stream::ChainPosition &position) : gram_(NULL, 0) { +      Init(position); +    } + +    void Init(const util::stream::ChainPosition &position) { +      stream_.Init(position); +      gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize())); +    } + +    NGram &operator*() { return gram_; } +    const NGram &operator*() const { return gram_; } + +    NGram *operator->() { return &gram_; } +    const NGram *operator->() const { return &gram_; } + +    void *Get() { return stream_.Get(); } +    const void *Get() const { return stream_.Get(); } + +    operator bool() const { return stream_; } +    bool operator!() const { return !stream_; } +    void Poison() { stream_.Poison(); } + +    NGramStream &operator++() { +      ++stream_; +      gram_.ReBase(stream_.Get()); +      return *this; +    } + +  private: +    NGram gram_; +    util::stream::Stream stream_; +}; + +inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) { +  str.Init(chain.Add()); +  return chain; +} + +}} // namespaces +#endif // LM_BUILDER_NGRAM_STREAM__ diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc new file mode 100644 index 00000000..14a1f721 --- /dev/null +++ b/klm/lm/builder/pipeline.cc @@ -0,0 +1,320 @@ +#include "lm/builder/pipeline.hh" + +#include "lm/builder/adjust_counts.hh" +#include "lm/builder/corpus_count.hh" +#include "lm/builder/initial_probabilities.hh" +#include "lm/builder/interpolate.hh" +#include "lm/builder/print.hh" +#include "lm/builder/sort.hh" + +#include "lm/sizes.hh" + +#include "util/exception.hh" +#include "util/file.hh" +#include "util/stream/io.hh" + +#include <algorithm> +#include <iostream> +#include <vector> + +namespace lm { namespace builder { + +namespace { +void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) { +  std::cerr << "Statistics:\n"; +  for (size_t i = 0; i < counts.size(); ++i) { +    std::cerr << (i + 1) << ' ' << counts[i]; +    for (size_t d = 1; d <= 3; ++d) +      std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d]; +    std::cerr << '\n'; +  } +} + +class Master { +  public: +    explicit Master(const PipelineConfig &config)  +      : config_(config), chains_(config.order), files_(config.order) { +      config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block); +    } + +    const PipelineConfig &Config() const { return config_; } + +    Chains &MutableChains() { return chains_; } + +    template <class T> Master &operator>>(const T &worker) { +      chains_ >> worker; +      return *this; +    } + +    // This takes the (partially) sorted ngrams and sets up for adjusted counts. +    void InitForAdjust(util::stream::Sort<SuffixOrder, AddCombiner> &ngrams, WordIndex types) { +      const std::size_t each_order_min = config_.minimum_block * config_.block_count; +      // We know how many unigrams there are.  Don't allocate more than needed to them. +      const std::size_t min_chains = (config_.order - 1) * each_order_min + +        std::min(types * NGram::TotalSize(1), each_order_min); +      // Do merge sort with calculated laziness. +      const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy())); + +      std::vector<uint64_t> count_bounds(1, types); +      CreateChains(config_.TotalMemory() - merge_using, count_bounds); +      ngrams.Output(chains_.back(), merge_using); + +      // Setup unigram file.   +      files_.push_back(util::MakeTemp(config_.TempPrefix())); +    } + +    // For initial probabilities, but this is generic. +    void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) { +      // Do merge first before allocating chain memory. +      for (std::size_t i = 1; i < config_.order; ++i) { +        sorts[i - 1].Merge(0); +      } +      // There's no lazy merge, so just divide memory amongst the chains. +      CreateChains(config_.TotalMemory(), counts); +      chains_.back().ActivateProgress(); +      chains_[0] >> files_[0].Source(); +      second_config.entry_size = NGram::TotalSize(1); +      second.push_back(second_config); +      second.back() >> files_[0].Source(); +      for (std::size_t i = 1; i < config_.order; ++i) { +        util::scoped_fd fd(sorts[i - 1].StealCompleted()); +        chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get())); +        chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true); +        second_config.entry_size = NGram::TotalSize(i + 1); +        second.push_back(second_config); +        second.back() >> util::stream::PRead(fd.release(), true); +      } +    } + +    // There is no sort after this, so go for broke on lazy merging. +    template <class Compare> void MaximumLazyInput(const std::vector<uint64_t> &counts, Sorts<Compare> &sorts) { +      // Determine the minimum we can use for all the chains. +      std::size_t min_chains = 0; +      for (std::size_t i = 0; i < config_.order; ++i) { +        min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block)); +      } +      std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains); +      std::vector<std::size_t> laziness; +      // Prioritize longer n-grams. +      for (util::stream::Sort<SuffixOrder> *i = sorts.end() - 1; i >= sorts.begin(); --i) { +        laziness.push_back(i->Merge(for_merge)); +        assert(for_merge >= laziness.back()); +        for_merge -= laziness.back(); +      } +      std::reverse(laziness.begin(), laziness.end()); + +      CreateChains(for_merge + min_chains, counts); +      chains_.back().ActivateProgress(); +      chains_[0] >> files_[0].Source(); +      for (std::size_t i = 1; i < config_.order; ++i) { +        sorts[i - 1].Output(chains_[i], laziness[i - 1]); +      } +    } + +    void BufferFinal(const std::vector<uint64_t> &counts) { +      chains_[0] >> files_[0].Sink(); +      for (std::size_t i = 1; i < config_.order; ++i) { +        files_.push_back(util::MakeTemp(config_.TempPrefix())); +        chains_[i] >> files_[i].Sink(); +      } +      chains_.Wait(true); +      // Use less memory.  Because we can. +      CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts); +      for (std::size_t i = 0; i < config_.order; ++i) { +        chains_[i] >> files_[i].Source(); +      } +    } + +    template <class Compare> void SetupSorts(Sorts<Compare> &sorts) { +      sorts.Init(config_.order - 1); +      // Unigrams don't get sorted because their order is always the same. +      chains_[0] >> files_[0].Sink(); +      for (std::size_t i = 1; i < config_.order; ++i) { +        sorts.push_back(chains_[i], config_.sort, Compare(i + 1)); +      } +      chains_.Wait(true); +    } + +  private: +    // Create chains, allocating memory to them.  Totally heuristic.  Count +    // bounds are upper bounds on the counts or not present. +    void CreateChains(std::size_t remaining_mem, const std::vector<uint64_t> &count_bounds) { +      std::vector<std::size_t> assignments; +      assignments.reserve(config_.order); +      // Start by assigning maximum memory usage (to be refined later). +      for (std::size_t i = 0; i < count_bounds.size(); ++i) { +        assignments.push_back(static_cast<std::size_t>(std::min( +            static_cast<uint64_t>(remaining_mem), +            count_bounds[i] * static_cast<uint64_t>(NGram::TotalSize(i + 1))))); +      } +      assignments.resize(config_.order, remaining_mem); + +      // Now we know how much memory everybody wants.  How much will they get? +      // Proportional to this. +      std::vector<float> portions; +      // Indices of orders that have yet to be assigned. +      std::vector<std::size_t> unassigned; +      for (std::size_t i = 0; i < config_.order; ++i) { +        portions.push_back(static_cast<float>((i+1) * NGram::TotalSize(i+1))); +        unassigned.push_back(i); +      } +      /*If somebody doesn't eat their full dinner, give it to the rest of the +       * family.  Then somebody else might not eat their full dinner etc.  Ends +       * when everybody unassigned is hungry. +       */ +      float sum; +      bool found_more; +      std::vector<std::size_t> block_count(config_.order); +      do { +        sum = 0.0; +        for (std::size_t i = 0; i < unassigned.size(); ++i) { +          sum += portions[unassigned[i]]; +        } +        found_more = false; +        // If the proportional assignment is more than needed, give it just what it needs. +        for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end();) { +          if (assignments[*i] <= remaining_mem * (portions[*i] / sum)) { +            remaining_mem -= assignments[*i]; +            block_count[*i] = 1; +            i = unassigned.erase(i); +            found_more = true; +          } else { +            ++i; +          } +        } +      } while (found_more); +      for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end(); ++i) { +        assignments[*i] = remaining_mem * (portions[*i] / sum); +        block_count[*i] = config_.block_count; +      } +      chains_.clear(); +      std::cerr << "Chain sizes:"; +      for (std::size_t i = 0; i < config_.order; ++i) { +        std::cerr << ' ' << (i+1) << ":" << assignments[i]; +        chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i])); +      } +      std::cerr << std::endl; +    } + +    PipelineConfig config_; + +    Chains chains_; +    // Often only unigrams, but sometimes all orders.   +    FixedArray<util::stream::FileBuffer> files_; +}; + +void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) { +  const PipelineConfig &config = master.Config(); +  std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl; + +  UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory()); +  std::size_t memory_for_chain =  +    // This much memory to work with after vocab hash table. +    static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) / +    // Solve for block size including the dedupe multiplier for one block. +    (static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) * +    // Chain likes memory expressed in terms of total memory. +    static_cast<float>(config.block_count); +  util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain)); + +  WordIndex type_count; +  util::FilePiece text(text_file, NULL, &std::cerr); +  text_file_name = text.FileName(); +  CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); +  chain >> boost::ref(counter); + +  util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); +  chain.Wait(true); +  std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl; +  master.InitForAdjust(sorter, type_count); +} + +void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { +  const PipelineConfig &config = master.Config(); +  Chains second(config.order); + +  { +    Sorts<ContextOrder> sorts; +    master.SetupSorts(sorts); +    PrintStatistics(counts, discounts); +    lm::ngram::ShowSizes(counts); +    std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl; +    master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in); +  } + +  Chains gamma_chains(config.order); +  InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains); +  // Don't care about gamma for 0.   +  gamma_chains[0] >> util::stream::kRecycle; +  gammas.Init(config.order - 1); +  for (std::size_t i = 1; i < config.order; ++i) { +    gammas.push_back(util::MakeTemp(config.TempPrefix())); +    gamma_chains[i] >> gammas[i - 1].Sink(); +  } +  // Has to be done here due to gamma_chains scope. +  master.SetupSorts(primary); +} + +void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { +  std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; +  const PipelineConfig &config = master.Config(); +  master.MaximumLazyInput(counts, primary); + +  Chains gamma_chains(config.order - 1); +  util::stream::ChainConfig read_backoffs(config.read_backoffs); +  read_backoffs.entry_size = sizeof(float); +  for (std::size_t i = 0; i < config.order - 1; ++i) { +    gamma_chains.push_back(read_backoffs); +    gamma_chains.back() >> gammas[i].Source(); +  } +  master >> Interpolate(counts[0], ChainPositions(gamma_chains)); +  gamma_chains >> util::stream::kRecycle; +  master.BufferFinal(counts); +} + +} // namespace + +void Pipeline(PipelineConfig config, int text_file, int out_arpa) { +  // Some fail-fast sanity checks. +  if (config.sort.buffer_size * 4 > config.TotalMemory()) { +    config.sort.buffer_size = config.TotalMemory() / 4; +    std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl; +  } +  if (config.minimum_block < NGram::TotalSize(config.order)) { +    config.minimum_block = NGram::TotalSize(config.order); +    std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl; +  } +  UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << "."); +  UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception, +      "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ".  Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size."); + +  UTIL_TIMER("(%w s) Total wall time elapsed\n"); +  Master master(config); + +  util::scoped_fd vocab_file(config.vocab_file.empty() ?  +      util::MakeTemp(config.TempPrefix()) :  +      util::CreateOrThrow(config.vocab_file.c_str())); +  uint64_t token_count; +  std::string text_file_name; +  CountText(text_file, vocab_file.get(), master, token_count, text_file_name); + +  std::vector<uint64_t> counts; +  std::vector<Discount> discounts; +  master >> AdjustCounts(counts, discounts); + +  { +    FixedArray<util::stream::FileBuffer> gammas; +    Sorts<SuffixOrder> primary; +    InitialProbabilities(counts, discounts, master, primary, gammas); +    InterpolateProbabilities(counts, master, primary, gammas); +  } + +  std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; +  VocabReconstitute vocab(vocab_file.get()); +  UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up.  Is there a null byte in the input?"); +  HeaderInfo header_info(text_file_name, token_count); +  master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; +  master.MutableChains().Wait(true); +} + +}} // namespaces diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh new file mode 100644 index 00000000..f1d6c5f6 --- /dev/null +++ b/klm/lm/builder/pipeline.hh @@ -0,0 +1,40 @@ +#ifndef LM_BUILDER_PIPELINE__ +#define LM_BUILDER_PIPELINE__ + +#include "lm/builder/initial_probabilities.hh" +#include "lm/builder/header_info.hh" +#include "util/stream/config.hh" +#include "util/file_piece.hh" + +#include <string> +#include <cstddef> + +namespace lm { namespace builder { + +struct PipelineConfig { +  std::size_t order; +  std::string vocab_file; +  util::stream::SortConfig sort; +  InitialProbabilitiesConfig initial_probs; +  util::stream::ChainConfig read_backoffs; +  bool verbose_header; + +  // Amount of memory to assume that the vocabulary hash table will use.  This +  // is subtracted from total memory for CorpusCount. +  std::size_t assume_vocab_hash_size; + +  // Minimum block size to tolerate. +  std::size_t minimum_block; + +  // Number of blocks to use.  This will be overridden to 1 if everything fits. +  std::size_t block_count; + +  const std::string &TempPrefix() const { return sort.temp_prefix; } +  std::size_t TotalMemory() const { return sort.total_memory; } +}; + +// Takes ownership of text_file. +void Pipeline(PipelineConfig config, int text_file, int out_arpa); + +}} // namespaces +#endif // LM_BUILDER_PIPELINE__ diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc new file mode 100644 index 00000000..b0323221 --- /dev/null +++ b/klm/lm/builder/print.cc @@ -0,0 +1,135 @@ +#include "lm/builder/print.hh" + +#include "util/double-conversion/double-conversion.h" +#include "util/double-conversion/utils.h" +#include "util/file.hh" +#include "util/mmap.hh" +#include "util/scoped.hh" +#include "util/stream/timer.hh" + +#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE +#include <boost/lexical_cast.hpp> + +#include <sstream> + +#include <string.h> + +namespace lm { namespace builder { + +VocabReconstitute::VocabReconstitute(int fd) { +  uint64_t size = util::SizeOrThrow(fd); +  util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_); +  const char *const start = static_cast<const char*>(memory_.get()); +  const char *i; +  for (i = start; i != start + size; i += strlen(i) + 1) { +    map_.push_back(i); +  } +  // Last one for LookupPiece. +  map_.push_back(i); +} + +namespace { +class OutputManager { +  public: +    static const std::size_t kOutBuf = 1048576; + +    // Does not take ownership of out. +    explicit OutputManager(int out) +      : buf_(util::MallocOrThrow(kOutBuf)), +        builder_(static_cast<char*>(buf_.get()), kOutBuf), +        // Mostly the default but with inf instead.  And no flags. +        convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0), +        fd_(out) {} + +    ~OutputManager() { +      Flush(); +    } + +    OutputManager &operator<<(float value) { +      // Odd, but this is the largest number found in the comments. +      EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); +      convert_.ToShortestSingle(value, &builder_); +      return *this; +    } + +    OutputManager &operator<<(StringPiece str) { +      if (str.size() > kOutBuf) { +        Flush(); +        util::WriteOrThrow(fd_, str.data(), str.size()); +      } else { +        EnsureRemaining(str.size()); +        builder_.AddSubstring(str.data(), str.size()); +      } +      return *this; +    } + +    // Inefficient! +    OutputManager &operator<<(unsigned val) { +      return *this << boost::lexical_cast<std::string>(val); +    } + +    OutputManager &operator<<(char c) { +      EnsureRemaining(1); +      builder_.AddCharacter(c); +      return *this; +    } + +    void Flush() { +      util::WriteOrThrow(fd_, buf_.get(), builder_.position()); +      builder_.Reset(); +    } + +  private: +    void EnsureRemaining(std::size_t amount) { +      if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) { +        Flush(); +      } +    } + +    util::scoped_malloc buf_; +    double_conversion::StringBuilder builder_; +    double_conversion::DoubleToStringConverter convert_; +    int fd_; +}; +} // namespace + +PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)  +  : vocab_(vocab), out_fd_(out_fd) { +  std::stringstream stream; + +  if (header_info) { +    stream << "# Input file: " << header_info->input_file << '\n'; +    stream << "# Token count: " << header_info->token_count << '\n'; +    stream << "# Smoothing: Modified Kneser-Ney" << '\n'; +  } +  stream << "\\data\\\n"; +  for (size_t i = 0; i < counts.size(); ++i) { +    stream << "ngram " << (i+1) << '=' << counts[i] << '\n'; +  } +  stream << '\n'; +  std::string as_string(stream.str()); +  util::WriteOrThrow(out_fd, as_string.data(), as_string.size()); +} + +void PrintARPA::Run(const ChainPositions &positions) { +  UTIL_TIMER("(%w s) Wrote ARPA file\n"); +  OutputManager out(out_fd_); +  for (unsigned order = 1; order <= positions.size(); ++order) { +    out << "\\" << order << "-grams:" << '\n'; +    for (NGramStream stream(positions[order - 1]); stream; ++stream) { +      // Correcting for numerical precision issues.  Take that IRST.   +      out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin()); +      for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) { +        out << ' ' << vocab_.Lookup(*i); +      } +      float backoff = stream->Value().complete.backoff; +      if (backoff != 0.0) +        out << '\t' << backoff; +      out << '\n'; +    } +    out << '\n'; +  } +  out << "\\end\\\n"; +} + +}} // namespaces diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh new file mode 100644 index 00000000..aa932e75 --- /dev/null +++ b/klm/lm/builder/print.hh @@ -0,0 +1,102 @@ +#ifndef LM_BUILDER_PRINT__ +#define LM_BUILDER_PRINT__ + +#include "lm/builder/ngram.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/header_info.hh" +#include "util/file.hh" +#include "util/mmap.hh" +#include "util/string_piece.hh" + +#include <ostream> + +#include <assert.h> + +// Warning: print routines read all unigrams before all bigrams before all +// trigrams etc.  So if other parts of the chain move jointly, you'll have to +// buffer.   + +namespace lm { namespace builder { + +class VocabReconstitute { +  public: +    // fd must be alive for life of this object; does not take ownership. +    explicit VocabReconstitute(int fd); + +    const char *Lookup(WordIndex index) const { +      assert(index < map_.size() - 1); +      return map_[index]; +    } + +    StringPiece LookupPiece(WordIndex index) const { +      return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]); +    } + +    std::size_t Size() const { +      // There's an extra entry to support StringPiece lengths. +      return map_.size() - 1; +    } + +  private: +    util::scoped_memory memory_; +    std::vector<const char*> map_; +}; + +// Not defined, only specialized.   +template <class T> void PrintPayload(std::ostream &to, const Payload &payload); +template <> inline void PrintPayload<uint64_t>(std::ostream &to, const Payload &payload) { +  to << payload.count; +} +template <> inline void PrintPayload<Uninterpolated>(std::ostream &to, const Payload &payload) { +  to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma); +} +template <> inline void PrintPayload<ProbBackoff>(std::ostream &to, const Payload &payload) { +  to << payload.complete.prob << ' ' << payload.complete.backoff; +} + +// template parameter is the type stored.   +template <class V> class Print { +  public: +    explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {} + +    void Run(const ChainPositions &chains) { +      NGramStreams streams(chains); +      for (NGramStream *s = streams.begin(); s != streams.end(); ++s) { +        DumpStream(*s); +      } +    } + +    void Run(const util::stream::ChainPosition &position) { +      NGramStream stream(position); +      DumpStream(stream); +    } + +  private: +    void DumpStream(NGramStream &stream) { +      for (; stream; ++stream) { +        PrintPayload<V>(to_, stream->Value()); +        for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) { +          to_ << ' ' << vocab_.Lookup(*w) << '=' << *w; +        } +        to_ << '\n'; +      } +    } + +    const VocabReconstitute &vocab_; +    std::ostream &to_; +}; + +class PrintARPA { +  public: +    // header_info may be NULL to disable the header +    explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd); + +    void Run(const ChainPositions &positions); + +  private: +    const VocabReconstitute &vocab_; +    int out_fd_; +}; + +}} // namespaces +#endif // LM_BUILDER_PRINT__ diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh new file mode 100644 index 00000000..9989389b --- /dev/null +++ b/klm/lm/builder/sort.hh @@ -0,0 +1,103 @@ +#ifndef LM_BUILDER_SORT__ +#define LM_BUILDER_SORT__ + +#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram.hh" +#include "lm/word_index.hh" +#include "util/stream/sort.hh" + +#include "util/stream/timer.hh" + +#include <functional> +#include <string> + +namespace lm { +namespace builder { + +template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> { +  public: +    explicit Comparator(std::size_t order) : order_(order) {} + +    inline bool operator()(const void *lhs, const void *rhs) const { +      return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs)); +    } + +    std::size_t Order() const { return order_; } + +  protected: +    std::size_t order_; +}; + +class SuffixOrder : public Comparator<SuffixOrder> { +  public: +    explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {} + +    inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { +      for (std::size_t i = order_ - 1; i != 0; --i) { +        if (lhs[i] != rhs[i]) +          return lhs[i] < rhs[i]; +      } +      return lhs[0] < rhs[0]; +    } + +    static const unsigned kMatchOffset = 1; +}; + +class ContextOrder : public Comparator<ContextOrder> { +  public: +    explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {} + +    inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { +      for (int i = order_ - 2; i >= 0; --i) { +        if (lhs[i] != rhs[i]) +          return lhs[i] < rhs[i]; +      } +      return lhs[order_ - 1] < rhs[order_ - 1]; +    } +}; + +class PrefixOrder : public Comparator<PrefixOrder> { +  public: +    explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {} + +    inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { +      for (std::size_t i = 0; i < order_; ++i) { +        if (lhs[i] != rhs[i]) +          return lhs[i] < rhs[i]; +      } +      return false; +    } +     +    static const unsigned kMatchOffset = 0; +}; + +// Sum counts for the same n-gram. +struct AddCombiner { +  bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const { +    NGram first(first_void, compare.Order()); +    // There isn't a const version of NGram.   +    NGram second(const_cast<void*>(second_void), compare.Order()); +    if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false; +    first.Count() += second.Count(); +    return true; +  } +}; + +// The combiner is only used on a single chain, so I didn't bother to allow +// that template.   +template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > { +  private: +    typedef util::stream::Sort<Compare> S; +    typedef FixedArray<S> P; + +  public: +    void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { +      new (P::end()) S(chain, config, compare); +      P::Constructed(); +    } +}; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_SORT__ | 
