|           |                                                  |                           |
|-----------|--------------------------------------------------|---------------------------|
| author    | Patrick Simianer <p@simianer.de>                 | 2013-01-21 12:29:43 +0100 |
| committer | Patrick Simianer <p@simianer.de>                 | 2013-01-21 12:29:43 +0100 |
| commit    | 0d23f8aecbfaf982cd165ebfc2a1611cefcc7275 (patch) |                           |
| tree      | 8eafa6ea43224ff70635cadd4d6f027d28f4986f         | /klm/lm                   |
| parent    | dbc66cd3944321961c5e11d5254fd914f05a98ad (diff)  |                           |
| parent    | 7cac43b858f3b681555bf0578f54b1f822c43207 (diff)  |                           |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/lm')
45 files changed, 4079 insertions, 78 deletions
````diff
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am
index a12c5f03..f15cbd77 100644
--- a/klm/lm/Makefile.am
+++ b/klm/lm/Makefile.am
@@ -1,7 +1,7 @@
 bin_PROGRAMS = build_binary
 
 build_binary_SOURCES = build_binary.cc
-build_binary_LDADD = libklm.a ../util/libklm_util.a -lz
+build_binary_LDADD = libklm.a ../util/libklm_util.a ../util/double-conversion/libklm_util_double.a -lz
 
 #noinst_PROGRAMS = \
 #  ngram_test
@@ -12,6 +12,34 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz
 noinst_LIBRARIES = libklm.a
 
 libklm_a_SOURCES = \
+  bhiksha.hh \
+  binary_format.hh \
+  blank.hh \
+  config.hh \
+  enumerate_vocab.hh \
+  facade.hh \
+  left.hh \
+  lm_exception.hh \
+  max_order.hh \
+  model.hh \
+  model_type.hh \
+  ngram_query.hh \
+  partial.hh \
+  quantize.hh \
+  read_arpa.hh \
+  return.hh \
+  search_hashed.hh \
+  search_trie.hh \
+  sizes.hh \
+  state.hh \
+  trie.hh \
+  trie_sort.hh \
+  value.hh \
+  value_build.hh \
+  virtual_interface.hh \
+  vocab.hh \
+  weights.hh \
+  word_index.hh \
   bhiksha.cc \
   binary_format.cc \
   config.cc \
@@ -22,11 +50,12 @@ libklm_a_SOURCES = \
   read_arpa.cc \
   search_hashed.cc \
   search_trie.cc \
+  sizes.cc \
   trie.cc \
 	trie_sort.cc \
 	value_build.cc \
   virtual_interface.cc \
   vocab.cc
 
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm
````

````diff
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index 2b8c9d5b..ab2c0c32 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -1,10 +1,14 @@
 #include "lm/model.hh"
+#include "lm/sizes.hh"
 #include "util/file_piece.hh"
+#include "util/usage.hh"
 
+#include <algorithm>
 #include <cstdlib>
 #include <exception>
 #include <iostream>
 #include <iomanip>
+#include <limits>
 
 #include <math.h>
 #include <stdlib.h>
@@ -19,8 +23,8 @@ namespace lm {
 namespace ngram {
 namespace {
 
-void Usage(const char *name) {
-  std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
+void Usage(const char *name, const char *default_mem) {
+  std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
 "-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
 "   Default is -100.  The ARPA file will always take precedence.\n"
 "-s allows models to be built even if they do not have <s> and </s>.\n"
@@ -38,8 +42,11 @@ void Usage(const char *name) {
 "trie is a straightforward trie with bit-level packing.  It uses the least\n"
 "memory and is still faster than SRI or IRST.  Building the trie format uses an\n"
 "on-disk sort to save memory.\n"
-"-t is the temporary directory prefix.  Default is the output file name.\n"
-"-m limits memory use for sorting.  Measured in MB.  Default is 1024MB.\n"
+"-T is the temporary directory prefix.  Default is the output file name.\n"
+"-S determines memory use for sorting.  Default is " << default_mem << ".  This is compatible\n"
+"   with GNU sort.  The number is followed by a unit: \% for percent of physical\n"
+"   memory, b for bytes, K for Kilobytes, M for megabytes, then G,T,P,E,Z,Y.  \n"
+"   Default unit is K for Kilobytes.\n"
 "-q turns quantization on and sets the number of bits (e.g. -q 8).\n"
 "-b sets backoff quantization bits.  Requires -q and defaults to that value.\n"
 "-a compresses pointers using an array of offsets.  The parameter is the\n"
@@ -83,47 +90,6 @@ void ParseFileList(const char *from, std::vector<std::string> &to) {
   }
 }
 
-void ShowSizes(const char *file, const lm::ngram::Config &config) {
-  std::vector<uint64_t> counts;
-  util::FilePiece f(file);
-  lm::ReadARPACounts(f, counts);
-  uint64_t sizes[6];
-  sizes[0] = ProbingModel::Size(counts, config);
-  sizes[1] = RestProbingModel::Size(counts, config);
-  sizes[2] = TrieModel::Size(counts, config);
-  sizes[3] = QuantTrieModel::Size(counts, config);
-  sizes[4] = ArrayTrieModel::Size(counts, config);
-  sizes[5] = QuantArrayTrieModel::Size(counts, config);
-  uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
-  uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
-  uint64_t divide;
-  char prefix;
-  if (min_length < (1 << 10) * 10) {
-    prefix = ' ';
-    divide = 1;
-  } else if (min_length < (1 << 20) * 10) {
-    prefix = 'k';
-    divide = 1 << 10;
-  } else if (min_length < (1ULL << 30) * 10) {
-    prefix = 'M';
-    divide = 1 << 20;
-  } else {
-    prefix = 'G';
-    divide = 1 << 30;
-  }
-  long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide))));
-  std::cout << "Memory estimate:\ntype    ";
-  // right align bytes.
-  for (long int i = 0; i < length - 2; ++i) std::cout << ' ';
-  std::cout << prefix << "B\n"
-    "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
-    "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n"
-    "trie    " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
-    "trie    " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
-    "trie    " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
-    "trie    " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
-}
-
 void ProbingQuantizationUnsupported() {
   std::cerr << "Quantization is only implemented in the trie data structure." << std::endl;
   exit(1);
@@ -136,11 +102,14 @@ void ProbingQuantizationUnsupported() {
 int main(int argc, char *argv[]) {
   using namespace lm::ngram;
 
+  const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G";
+
   try {
     bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;
     lm::ngram::Config config;
+    config.building_memory = util::ParseSize(default_mem);
     int opt;
-    while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:sir:")) != -1) {
+    while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) {
       switch(opt) {
         case 'q':
           config.prob_bits = ParseBitCount(optarg);
@@ -161,12 +130,16 @@ int main(int argc, char *argv[]) {
         case 'p':
           config.probing_multiplier = ParseFloat(optarg);
           break;
-        case 't':
+        case 't': // legacy
+        case 'T':
           config.temporary_directory_prefix = optarg;
           break;
-        case 'm':
+        case 'm': // legacy
           config.building_memory = ParseUInt(optarg) * 1048576;
           break;
+        case 'S':
+          config.building_memory = std::min(static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), util::ParseSize(optarg));
+          break;
         case 'w':
           set_write_method = true;
           if (!strcmp(optarg, "mmap")) {
@@ -174,7 +147,7 @@ int main(int argc, char *argv[]) {
           } else if (!strcmp(optarg, "after")) {
             config.write_method = Config::WRITE_AFTER;
           } else {
-            Usage(argv[0]);
+            Usage(argv[0], default_mem);
           }
           break;
         case 's':
@@ -189,7 +162,7 @@ int main(int argc, char *argv[]) {
           config.rest_function = Config::REST_LOWER;
           break;
         default:
-          Usage(argv[0]);
+          Usage(argv[0], default_mem);
       }
     }
     if (!quantize && set_backoff_bits) {
@@ -212,7 +185,7 @@ int main(int argc, char *argv[]) {
       from_file = argv[optind + 1];
       config.write_mmap = argv[optind + 2];
     } else {
-      Usage(argv[0]);
+      Usage(argv[0], default_mem);
    }
     if (!strcmp(model_type, "probing")) {
       if (!set_write_method) config.write_method = Config::WRITE_AFTER;
@@ -242,7 +215,7 @@ int main(int argc, char *argv[]) {
         }
       }
     } else {
-      Usage(argv[0]);
+      Usage(argv[0], default_mem);
     }
   }
   catch (const std::exception &e) {
````
"80%" : "1G"; +    try {      bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;      lm::ngram::Config config; +    config.building_memory = util::ParseSize(default_mem);      int opt; -    while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:sir:")) != -1) { +    while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) {        switch(opt) {          case 'q':            config.prob_bits = ParseBitCount(optarg); @@ -161,12 +130,16 @@ int main(int argc, char *argv[]) {          case 'p':            config.probing_multiplier = ParseFloat(optarg);            break; -        case 't': +        case 't': // legacy +        case 'T':            config.temporary_directory_prefix = optarg;            break; -        case 'm': +        case 'm': // legacy            config.building_memory = ParseUInt(optarg) * 1048576;            break; +        case 'S': +          config.building_memory = std::min(static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), util::ParseSize(optarg)); +          break;          case 'w':            set_write_method = true;            if (!strcmp(optarg, "mmap")) { @@ -174,7 +147,7 @@ int main(int argc, char *argv[]) {            } else if (!strcmp(optarg, "after")) {              config.write_method = Config::WRITE_AFTER;            } else { -            Usage(argv[0]); +            Usage(argv[0], default_mem);            }            break;          case 's': @@ -189,7 +162,7 @@ int main(int argc, char *argv[]) {            config.rest_function = Config::REST_LOWER;            break;          default: -          Usage(argv[0]); +          Usage(argv[0], default_mem);        }      }      if (!quantize && set_backoff_bits) { @@ -212,7 +185,7 @@ int main(int argc, char *argv[]) {        from_file = argv[optind + 1];        config.write_mmap = argv[optind + 2];      } else { -      Usage(argv[0]); +      Usage(argv[0], default_mem);      }      if (!strcmp(model_type, "probing")) {        if (!set_write_method) config.write_method = Config::WRITE_AFTER; @@ -242,7 +215,7 @@ int main(int argc, char *argv[]) {          }        }      } else { -      Usage(argv[0]); +      Usage(argv[0], default_mem);      }    }    catch (const std::exception &e) { diff --git a/klm/lm/builder/Makefile.am b/klm/lm/builder/Makefile.am new file mode 100644 index 00000000..b5c147fd --- /dev/null +++ b/klm/lm/builder/Makefile.am @@ -0,0 +1,28 @@ +bin_PROGRAMS = builder + +builder_SOURCES = \ +  main.cc \ +  adjust_counts.cc \ +  adjust_counts.hh \ +  corpus_count.cc \ +  corpus_count.hh \ +  discount.hh \ +  header_info.hh \ +  initial_probabilities.cc \ +  initial_probabilities.hh \ +  interpolate.cc \ +  interpolate.hh \ +  joint_order.hh \ +  multi_stream.hh \ +  ngram.hh \ +  ngram_stream.hh \ +  pipeline.cc \ +  pipeline.hh \ +  print.cc \ +  print.hh \ +  sort.hh + +builder_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_THREAD_LIBS) + +AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm + diff --git a/klm/lm/builder/README.md b/klm/lm/builder/README.md new file mode 100644 index 00000000..be0d35e2 --- /dev/null +++ b/klm/lm/builder/README.md @@ -0,0 +1,47 @@ +Dependencies +============ + +Boost >= 1.42.0 is required.   
````diff
diff --git a/klm/lm/builder/TODO b/klm/lm/builder/TODO
new file mode 100644
index 00000000..cb5aef3a
--- /dev/null
+++ b/klm/lm/builder/TODO
@@ -0,0 +1,5 @@
+More tests!
+Sharding.
+Some way to manage all the crazy config options.
+Option to build the binary file directly.
+Interpolation of different orders.
````

````diff
diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc
new file mode 100644
index 00000000..a6f48011
--- /dev/null
+++ b/klm/lm/builder/adjust_counts.cc
@@ -0,0 +1,216 @@
+#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/multi_stream.hh"
+#include "util/stream/timer.hh"
+
+#include <algorithm>
+
+namespace lm { namespace builder {
+
+BadDiscountException::BadDiscountException() throw() {}
+BadDiscountException::~BadDiscountException() throw() {}
+
+namespace {
+// Return last word in full that is different.
+const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
+  const WordIndex *cur_word = full.end() - 1;
+  const WordIndex *pre_word = lower_last.end() - 1;
+  // Find last difference.
+  for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
+  return cur_word;
+}
+
+class StatCollector {
+  public:
+    StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+      : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) {
+      memset(&orders_[0], 0, sizeof(OrderStat) * order);
+    }
+
+    ~StatCollector() {}
+
+    void CalculateDiscounts() {
+      counts_.resize(orders_.size());
+      discounts_.resize(orders_.size());
+      for (std::size_t i = 0; i < orders_.size(); ++i) {
+        const OrderStat &s = orders_[i];
+        counts_[i] = s.count;
+
+        for (unsigned j = 1; j < 4; ++j) {
+          // TODO: Specialize error message for j == 3, meaning 3+
+          UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
+              << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
+              << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
+        }
+
+        // See equation (26) in Chen and Goodman.
+        discounts_[i].amount[0] = 0.0;
+        float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
+        for (unsigned j = 1; j < 4; ++j) {
+          discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
+          UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+        }
+      }
+    }
+
+    void Add(std::size_t order_minus_1, uint64_t count) {
+      OrderStat &stat = orders_[order_minus_1];
+      ++stat.count;
+      if (count < 5) ++stat.n[count];
+    }
+
+    void AddFull(uint64_t count) {
+      ++full_.count;
+      if (count < 5) ++full_.n[count];
+    }
+
+  private:
+    struct OrderStat {
+      // n_1 in equation 26 of Chen and Goodman etc
+      uint64_t n[5];
+      uint64_t count;
+    };
+
+    std::vector<OrderStat> orders_;
+    OrderStat &full_;
+
+    std::vector<uint64_t> &counts_;
+    std::vector<Discount> &discounts_;
+};
+
+// Reads all entries in order like NGramStream does.
+// But deletes any entries that have <s> in the 1st (not 0th) position on the
+// way out by putting other entries in their place.  This disrupts the sort
+// order but we don't care because the data is going to be sorted again.
+class CollapseStream {
+  public:
+    CollapseStream(const util::stream::ChainPosition &position) :
+      current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
+      block_(position) {
+      StartBlock();
+    }
+
+    const NGram &operator*() const { return current_; }
+    const NGram *operator->() const { return &current_; }
+
+    operator bool() const { return block_; }
+
+    CollapseStream &operator++() {
+      assert(block_);
+      if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
+        memcpy(current_.Base(), copy_from_, current_.TotalSize());
+        UpdateCopyFrom();
+      }
+      current_.NextInMemory();
+      uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
+      if (current_.Base() == block_base + block_->ValidSize()) {
+        block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base);
+        ++block_;
+        StartBlock();
+      }
+      return *this;
+    }
+
+  private:
+    void StartBlock() {
+      for (; ; ++block_) {
+        if (!block_) return;
+        if (block_->ValidSize()) break;
+      }
+      current_.ReBase(block_->Get());
+      copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
+      UpdateCopyFrom();
+    }
+
+    // Find last without bos.
+    void UpdateCopyFrom() {
+      for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
+        if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
+      }
+    }
+
+    NGram current_;
+
+    // Goes backwards in the block
+    uint8_t *copy_from_;
+
+    util::stream::Link block_;
+};
+
+} // namespace
+
+void AdjustCounts::Run(const ChainPositions &positions) {
+  UTIL_TIMER("(%w s) Adjusted counts\n");
+
+  const std::size_t order = positions.size();
+  StatCollector stats(order, counts_, discounts_);
+  if (order == 1) {
+    // Only unigrams.  Just collect stats.
+    for (NGramStream full(positions[0]); full; ++full)
+      stats.AddFull(full->Count());
+    stats.CalculateDiscounts();
+    return;
+  }
+
+  NGramStreams streams;
+  streams.Init(positions, positions.size() - 1);
+  CollapseStream full(positions[positions.size() - 1]);
+
+  // Initialization: <unk> has count 0 and so does <s>.
+  NGramStream *lower_valid = streams.begin();
+  streams[0]->Count() = 0;
+  *streams[0]->begin() = kUNK;
+  stats.Add(0, 0);
+  (++streams[0])->Count() = 0;
+  *streams[0]->begin() = kBOS;
+  // not in stats because it will get put in later.
+
+  // iterate over full (the stream of the highest order ngrams)
+  for (; full; ++full) {
+    const WordIndex *different = FindDifference(*full, **lower_valid);
+    std::size_t same = full->end() - 1 - different;
+    // Increment the adjusted count.
+    if (same) ++streams[same - 1]->Count();
+
+    // Output all the valid ones that changed.
+    for (; lower_valid >= &streams[same]; --lower_valid) {
+      stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count());
+      ++*lower_valid;
+    }
+
+    // This is here because bos is also const WordIndex *, so copy gets
+    // consistent argument types.
+    const WordIndex *full_end = full->end();
+    // Initialize and mark as valid up to bos.
+    const WordIndex *bos;
+    for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
+      ++lower_valid;
+      std::copy(bos, full_end, (*lower_valid)->begin());
+      (*lower_valid)->Count() = 1;
+    }
+    // Now bos indicates where <s> is or is the 0th word of full.
+    if (bos != full->begin()) {
+      // There is an <s> beyond the 0th word.
+      NGramStream &to = *++lower_valid;
+      std::copy(bos, full_end, to->begin());
+      to->Count() = full->Count();
+    } else {
+      stats.AddFull(full->Count());
+    }
+    assert(lower_valid >= &streams[0]);
+  }
+
+  // Output everything valid.
+  for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
+    stats.Add(s - streams.begin(), (*s)->Count());
+    ++*s;
+  }
+  // Poison everyone!  Except the N-grams which were already poisoned by the input.
+  for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
+    s->Poison();
+
+  stats.CalculateDiscounts();
+
+  // NOTE: See special early-return case for unigrams near the top of this function
+}
+
+}} // namespaces
````
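For reference, `StatCollector::CalculateDiscounts` above is the closed-form estimate of the modified Kneser-Ney discounts from equation (26) of Chen and Goodman:

$$
Y = \frac{n_1}{n_1 + 2n_2}, \qquad D_j = j - (j+1)\,Y\,\frac{n_{j+1}}{n_j} \quad \text{for } j \in \{1, 2, 3\},
$$

where $n_j$ is the number of n-grams of the given order whose adjusted count is exactly $j$. $D_0$ is fixed at 0, and $D_3$ is applied to all adjusted counts of 3 or more (see `Discount::Get` in discount.hh below).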
````diff
diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh
new file mode 100644
index 00000000..f38ff79d
--- /dev/null
+++ b/klm/lm/builder/adjust_counts.hh
@@ -0,0 +1,44 @@
+#ifndef LM_BUILDER_ADJUST_COUNTS__
+#define LM_BUILDER_ADJUST_COUNTS__
+
+#include "lm/builder/discount.hh"
+#include "util/exception.hh"
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+
+class ChainPositions;
+
+class BadDiscountException : public util::Exception {
+  public:
+    BadDiscountException() throw();
+    ~BadDiscountException() throw();
+};
+
+/* Compute adjusted counts.
+ * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.
+ * Output: [1,N]-grams with adjusted counts.
+ * [1,N)-grams are in suffix order
+ * N-grams are in undefined order (they're going to be sorted anyway).
+ */
+class AdjustCounts {
+  public:
+    AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
+      : counts_(counts), discounts_(discounts) {}
+
+    void Run(const ChainPositions &positions);
+
+  private:
+    std::vector<uint64_t> &counts_;
+    std::vector<Discount> &discounts_;
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_ADJUST_COUNTS__
+
````

````diff
diff --git a/klm/lm/builder/adjust_counts_test.cc b/klm/lm/builder/adjust_counts_test.cc
new file mode 100644
index 00000000..68b5f33e
--- /dev/null
+++ b/klm/lm/builder/adjust_counts_test.cc
@@ -0,0 +1,106 @@
+#include "lm/builder/adjust_counts.hh"
+
+#include "lm/builder/multi_stream.hh"
+#include "util/scoped.hh"
+
+#include <boost/thread/thread.hpp>
+#define BOOST_TEST_MODULE AdjustCounts
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+class KeepCopy {
+  public:
+    KeepCopy() : size_(0) {}
+
+    void Run(const util::stream::ChainPosition &position) {
+      for (util::stream::Link link(position); link; ++link) {
+        mem_.call_realloc(size_ + link->ValidSize());
+        memcpy(static_cast<uint8_t*>(mem_.get()) + size_, link->Get(), link->ValidSize());
+        size_ += link->ValidSize();
+      }
+    }
+
+    uint8_t *Get() { return static_cast<uint8_t*>(mem_.get()); }
+    std::size_t Size() const { return size_; }
+
+  private:
+    util::scoped_malloc mem_;
+    std::size_t size_;
+};
+
+struct Gram4 {
+  WordIndex ids[4];
+  uint64_t count;
+};
+
+class WriteInput {
+  public:
+    void Run(const util::stream::ChainPosition &position) {
+      NGramStream input(position);
+      Gram4 grams[] = {
+        {{0,0,0,0},10},
+        {{0,0,3,0},3},
+        // bos
+        {{1,1,1,2},5},
+        {{0,0,3,2},5},
+      };
+      for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) {
+        memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4);
+        input->Count() = grams[i].count;
+      }
+      input.Poison();
+    }
+};
+
+BOOST_AUTO_TEST_CASE(Simple) {
+  KeepCopy outputs[4];
+  std::vector<uint64_t> counts;
+  std::vector<Discount> discount;
+  {
+    util::stream::ChainConfig config;
+    config.total_memory = 100;
+    config.block_count = 1;
+    Chains chains(4);
+    for (unsigned i = 0; i < 4; ++i) {
+      config.entry_size = NGram::TotalSize(i + 1);
+      chains.push_back(config);
+    }
+
+    chains[3] >> WriteInput();
+    ChainPositions for_adjust(chains);
+    for (unsigned i = 0; i < 4; ++i) {
+      chains[i] >> boost::ref(outputs[i]);
+    }
+    chains >> util::stream::kRecycle;
+    BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException);
+  }
+  BOOST_REQUIRE_EQUAL(4UL, counts.size());
+  BOOST_CHECK_EQUAL(4UL, counts[0]);
+  // These are no longer set because the discounts are bad.
+/*  BOOST_CHECK_EQUAL(4UL, counts[1]);
+  BOOST_CHECK_EQUAL(3UL, counts[2]);
+  BOOST_CHECK_EQUAL(3UL, counts[3]);*/
+  BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size());
+  NGram uni(outputs[0].Get(), 1);
+  BOOST_CHECK_EQUAL(kUNK, *uni.begin());
+  BOOST_CHECK_EQUAL(0ULL, uni.Count());
+  uni.NextInMemory();
+  BOOST_CHECK_EQUAL(kBOS, *uni.begin());
+  BOOST_CHECK_EQUAL(0ULL, uni.Count());
+  uni.NextInMemory();
+  BOOST_CHECK_EQUAL(0UL, *uni.begin());
+  BOOST_CHECK_EQUAL(2ULL, uni.Count());
+  uni.NextInMemory();
+  BOOST_CHECK_EQUAL(2ULL, uni.Count());
+  BOOST_CHECK_EQUAL(2UL, *uni.begin());
+
+  BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size());
+  NGram bi(outputs[1].Get(), 2);
+  BOOST_CHECK_EQUAL(0UL, *bi.begin());
+  BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1));
+  BOOST_CHECK_EQUAL(1ULL, bi.Count());
+  bi.NextInMemory();
+}
+
+}}} // namespaces
````
````diff
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
new file mode 100644
index 00000000..abea4ed0
--- /dev/null
+++ b/klm/lm/builder/corpus_count.cc
@@ -0,0 +1,224 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/lm_exception.hh"
+#include "lm/word_index.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/murmur_hash.hh"
+#include "util/probing_hash_table.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/timer.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_set.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <functional>
+
+#include <stdint.h>
+
+namespace lm {
+namespace builder {
+namespace {
+
+class VocabHandout {
+  public:
+    explicit VocabHandout(int fd) {
+      util::scoped_fd duped(util::DupOrThrow(fd));
+      word_list_.reset(util::FDOpenOrThrow(duped));
+
+      Lookup("<unk>"); // Force 0
+      Lookup("<s>"); // Force 1
+      Lookup("</s>"); // Force 2
+    }
+
+    WordIndex Lookup(const StringPiece &word) {
+      uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
+      std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
+      if (ret.second) {
+        char null_delimit = 0;
+        util::WriteOrThrow(word_list_.get(), word.data(), word.size());
+        util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
+        UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words.  Change WordIndex to uint64_t in lm/word_index.hh.");
+      }
+      return ret.first->second;
+    }
+
+    WordIndex Size() const {
+      return seen_.size();
+    }
+
+  private:
+    typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+
+    Seen seen_;
+
+    util::scoped_FILE word_list_;
+};
+
+class DedupeHash : public std::unary_function<const WordIndex *, bool> {
+  public:
+    explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+    std::size_t operator()(const WordIndex *start) const {
+      return util::MurmurHashNative(start, size_);
+    }
+
+  private:
+    const std::size_t size_;
+};
+
+class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> {
+  public:
+    explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}
+
+    bool operator()(const WordIndex *first, const WordIndex *second) const {
+      return !memcmp(first, second, size_);
+    }
+
+  private:
+    const std::size_t size_;
+};
+
+struct DedupeEntry {
+  typedef WordIndex *Key;
+  Key GetKey() const { return key; }
+  Key key;
+  static DedupeEntry Construct(WordIndex *at) {
+    DedupeEntry ret;
+    ret.key = at;
+    return ret;
+  }
+};
+
+typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
+
+const float kProbingMultiplier = 1.5;
+
+class Writer {
+  public:
+    Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
+      : block_(position), gram_(block_->Get(), order),
+        dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
+        dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
+        buffer_(new WordIndex[order - 1]),
+        block_size_(position.GetChain().BlockSize()) {
+      dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+      assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
+      if (order == 1) {
+        // Add special words.  AdjustCounts is responsible if order != 1.
+        AddUnigramWord(kUNK);
+        AddUnigramWord(kBOS);
+      }
+    }
+
+    ~Writer() {
+      block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
+      (++block_).Poison();
+    }
+
+    // Write context with a bunch of <s>
+    void StartSentence() {
+      for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
+        *i = kBOS;
+      }
+    }
+
+    void Append(WordIndex word) {
+      *(gram_.end() - 1) = word;
+      Dedupe::MutableIterator at;
+      bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
+      if (found) {
+        // Already present.
+        NGram already(at->key, gram_.Order());
+        ++(already.Count());
+        // Shift left by one.
+        memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
+        return;
+      }
+      // Complete the write.
+      gram_.Count() = 1;
+      // Prepare the next n-gram.
+      if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
+        NGram last(gram_);
+        gram_.NextInMemory();
+        std::copy(last.begin() + 1, last.end(), gram_.begin());
+        return;
+      }
+      // Block end.  Need to store the context in a temporary buffer.
+      std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
+      dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+      block_->SetValidSize(block_size_);
+      gram_.ReBase((++block_)->Get());
+      std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
+    }
+
+  private:
+    void AddUnigramWord(WordIndex index) {
+      *gram_.begin() = index;
+      gram_.Count() = 0;
+      gram_.NextInMemory();
+      if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
+        block_->SetValidSize(block_size_);
+        gram_.ReBase((++block_)->Get());
+      }
+    }
+
+    util::stream::Link block_;
+
+    NGram gram_;
+
+    // This is the memory behind the invalid value in dedupe_.
+    std::vector<WordIndex> dedupe_invalid_;
+    // Hash table combiner implementation.
+    Dedupe dedupe_;
+
+    // Small buffer to hold existing ngrams when shifting across a block boundary.
+    boost::scoped_array<WordIndex> buffer_;
+
+    const std::size_t block_size_;
+};
+
+} // namespace
+
+float CorpusCount::DedupeMultiplier(std::size_t order) {
+  return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
+}
+
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
+  : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
+    dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
+    dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
+  token_count_ = 0;
+  type_count_ = 0;
+}
+
+void CorpusCount::Run(const util::stream::ChainPosition &position) {
+  UTIL_TIMER("(%w s) Counted n-grams\n");
+
+  VocabHandout vocab(vocab_write_);
+  const WordIndex end_sentence = vocab.Lookup("</s>");
+  Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
+  uint64_t count = 0;
+  StringPiece delimiters("\0\t\r ", 4);
+  try {
+    while(true) {
+      StringPiece line(from_.ReadLine());
+      writer.StartSentence();
+      for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) {
+        WordIndex word = vocab.Lookup(*w);
+        UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus.  I plan to support models containing <unk> in the future.");
+        writer.Append(word);
+        ++count;
+      }
+      writer.Append(end_sentence);
+    }
+  } catch (const util::EndOfFileException &e) {}
+  token_count_ = count;
+  type_count_ = vocab.Size();
+}
+
+} // namespace builder
+} // namespace lm
````
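`CorpusCount` combines duplicate n-grams inside each block with a probing hash table keyed on the word IDs, emitting one entry per distinct n-gram with its count. Stripped of the chain/block machinery, the per-sentence windowing that `StartSentence` and `Append` implement amounts to this standalone sketch (a simplified illustration, not part of the patch):

```cpp
#include <cstdint>
#include <deque>
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Count order-N n-grams in one sentence: pad the left context with <s>,
// slide a window over the tokens, and append a final </s> n-gram.
std::map<std::vector<std::string>, uint64_t> CountSentence(const std::string &line, std::size_t order) {
  std::map<std::vector<std::string>, uint64_t> counts;
  std::deque<std::string> window(order - 1, "<s>");
  std::istringstream tokens(line);
  std::string word;
  while (tokens >> word) {
    window.push_back(word);
    ++counts[std::vector<std::string>(window.begin(), window.end())];
    window.pop_front();
  }
  window.push_back("</s>");  // every sentence is terminated with </s>
  ++counts[std::vector<std::string>(window.begin(), window.end())];
  return counts;
}
```

The test case in corpus_count_test.cc below exercises exactly this behavior; note that even an empty line yields one `<s> <s> </s>` trigram.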
````diff
diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh
new file mode 100644
index 00000000..e255bad1
--- /dev/null
+++ b/klm/lm/builder/corpus_count.hh
@@ -0,0 +1,42 @@
+#ifndef LM_BUILDER_CORPUS_COUNT__
+#define LM_BUILDER_CORPUS_COUNT__
+
+#include "lm/word_index.hh"
+#include "util/scoped.hh"
+
+#include <cstddef>
+#include <string>
+#include <stdint.h>
+
+namespace util {
+class FilePiece;
+namespace stream {
+class ChainPosition;
+} // namespace stream
+} // namespace util
+
+namespace lm {
+namespace builder {
+
+class CorpusCount {
+  public:
+    // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size
+    static float DedupeMultiplier(std::size_t order);
+
+    CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
+
+    void Run(const util::stream::ChainPosition &position);
+
+  private:
+    util::FilePiece &from_;
+    int vocab_write_;
+    uint64_t &token_count_;
+    WordIndex &type_count_;
+
+    std::size_t dedupe_mem_size_;
+    util::scoped_malloc dedupe_mem_;
+};
+
+} // namespace builder
+} // namespace lm
+#endif // LM_BUILDER_CORPUS_COUNT__
````

````diff
diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc
new file mode 100644
index 00000000..8d53ca9d
--- /dev/null
+++ b/klm/lm/builder/corpus_count_test.cc
@@ -0,0 +1,76 @@
+#include "lm/builder/corpus_count.hh"
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/ngram_stream.hh"
+
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/tokenize_piece.hh"
+#include "util/stream/chain.hh"
+#include "util/stream/stream.hh"
+
+#define BOOST_TEST_MODULE CorpusCountTest
+#include <boost/test/unit_test.hpp>
+
+namespace lm { namespace builder { namespace {
+
+#define Check(str, count) { \
+  BOOST_REQUIRE(stream); \
+  w = stream->begin(); \
+  for (util::TokenIter<util::AnyCharacter, true> t(str, " "); t; ++t, ++w) { \
+    BOOST_CHECK_EQUAL(*t, v[*w]); \
+  } \
+  BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \
+  ++stream; \
+}
+
+BOOST_AUTO_TEST_CASE(Short) {
+  util::scoped_fd input_file(util::MakeTemp("corpus_count_test_temp"));
+  const char input[] = "looking on a little more loin\non a little more loin\non foo little more loin\nbar\n\n";
+  // Blocks of 10 are
+  // looking on a little more loin </s> on a little[duplicate] more[duplicate] loin[duplicate] </s>[duplicate] on[duplicate] foo
+  // little more loin </s> bar </s> </s>
+
+  util::WriteOrThrow(input_file.get(), input, sizeof(input) - 1);
+  util::FilePiece input_piece(input_file.release(), "temp file");
+
+  util::stream::ChainConfig config;
+  config.entry_size = NGram::TotalSize(3);
+  config.total_memory = config.entry_size * 20;
+  config.block_count = 2;
+
+  util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab"));
+
+  util::stream::Chain chain(config);
+  NGramStream stream;
+  uint64_t token_count;
+  WordIndex type_count;
+  CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
+  chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
+
+  const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"};
+
+  WordIndex *w;
+
+  Check("<s> <s> looking", 1);
+  Check("<s> looking on", 1);
+  Check("looking on a", 1);
+  Check("on a little", 2);
+  Check("a little more", 2);
+  Check("little more loin", 2);
+  Check("more loin </s>", 2);
+  Check("<s> <s> on", 2);
+  Check("<s> on a", 1);
+  Check("<s> on foo", 1);
+  Check("on foo little", 1);
+  Check("foo little more", 1);
+  Check("little more loin", 1);
+  Check("more loin </s>", 1);
+  Check("<s> <s> bar", 1);
+  Check("<s> bar </s>", 1);
+  Check("<s> <s> </s>", 1);
+  BOOST_CHECK(!stream);
+  BOOST_CHECK_EQUAL(sizeof(v) / sizeof(const char*), type_count);
+}
+
+}}} // namespaces
````
"</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"}; + +  WordIndex *w; + +  Check("<s> <s> looking", 1); +  Check("<s> looking on", 1); +  Check("looking on a", 1); +  Check("on a little", 2); +  Check("a little more", 2); +  Check("little more loin", 2); +  Check("more loin </s>", 2); +  Check("<s> <s> on", 2); +  Check("<s> on a", 1); +  Check("<s> on foo", 1); +  Check("on foo little", 1); +  Check("foo little more", 1); +  Check("little more loin", 1); +  Check("more loin </s>", 1); +  Check("<s> <s> bar", 1); +  Check("<s> bar </s>", 1); +  Check("<s> <s> </s>", 1); +  BOOST_CHECK(!stream); +  BOOST_CHECK_EQUAL(sizeof(v) / sizeof(const char*), type_count); +} + +}}} // namespaces diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh new file mode 100644 index 00000000..754fb20d --- /dev/null +++ b/klm/lm/builder/discount.hh @@ -0,0 +1,26 @@ +#ifndef BUILDER_DISCOUNT__ +#define BUILDER_DISCOUNT__ + +#include <algorithm> + +#include <inttypes.h> + +namespace lm { +namespace builder { + +struct Discount { +  float amount[4]; + +  float Get(uint64_t count) const { +    return amount[std::min<uint64_t>(count, 3)]; +  } + +  float Apply(uint64_t count) const { +    return static_cast<float>(count) - Get(count); +  } +}; + +} // namespace builder +} // namespace lm + +#endif // BUILDER_DISCOUNT__ diff --git a/klm/lm/builder/header_info.hh b/klm/lm/builder/header_info.hh new file mode 100644 index 00000000..ccca1456 --- /dev/null +++ b/klm/lm/builder/header_info.hh @@ -0,0 +1,20 @@ +#ifndef LM_BUILDER_HEADER_INFO__ +#define LM_BUILDER_HEADER_INFO__ + +#include <string> +#include <stdint.h> + +// Some configuration info that is used to add +// comments to the beginning of an ARPA file +struct HeaderInfo { +  const std::string input_file; +  const uint64_t token_count; + +  HeaderInfo(const std::string& input_file_in, uint64_t token_count_in) +    : input_file(input_file_in), token_count(token_count_in) {} + +  // TODO: Add smoothing type +  // TODO: More info if multiple models were interpolated +}; + +#endif diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc new file mode 100644 index 00000000..58b42a20 --- /dev/null +++ b/klm/lm/builder/initial_probabilities.cc @@ -0,0 +1,136 @@ +#include "lm/builder/initial_probabilities.hh" + +#include "lm/builder/discount.hh" +#include "lm/builder/ngram_stream.hh" +#include "lm/builder/sort.hh" +#include "util/file.hh" +#include "util/stream/chain.hh" +#include "util/stream/io.hh" +#include "util/stream/stream.hh" + +#include <vector> + +namespace lm { namespace builder { + +namespace { +struct BufferEntry { +  // Gamma from page 20 of Chen and Goodman.   +  float gamma; +  // \sum_w a(c w) for all w.   +  float denominator; +}; + +// Extract an array of gamma from an array of BufferEntry.   
+class OnlyGamma { +  public: +    void Run(const util::stream::ChainPosition &position) { +      for (util::stream::Link block_it(position); block_it; ++block_it) { +        float *out = static_cast<float*>(block_it->Get()); +        const float *in = out; +        const float *end = static_cast<const float*>(block_it->ValidEnd()); +        for (out += 1, in += 2; in < end; out += 1, in += 2) { +          *out = *in; +        } +        block_it->SetValidSize(block_it->ValidSize() / 2); +      } +    } +}; + +class AddRight { +  public: +    AddRight(const Discount &discount, const util::stream::ChainPosition &input)  +      : discount_(discount), input_(input) {} + +    void Run(const util::stream::ChainPosition &output) { +      NGramStream in(input_); +      util::stream::Stream out(output); + +      std::vector<WordIndex> previous(in->Order() - 1); +      const std::size_t size = sizeof(WordIndex) * previous.size(); +      for(; in; ++out) { +        memcpy(&previous[0], in->begin(), size); +        uint64_t denominator = 0; +        uint64_t counts[4]; +        memset(counts, 0, sizeof(counts)); +        do { +          denominator += in->Count(); +          ++counts[std::min(in->Count(), static_cast<uint64_t>(3))]; +        } while (++in && !memcmp(&previous[0], in->begin(), size)); +        BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get()); +        entry.denominator = static_cast<float>(denominator); +        entry.gamma = 0.0; +        for (unsigned i = 1; i <= 3; ++i) { +          entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]); +        } +        entry.gamma /= entry.denominator; +      } +      out.Poison(); +    } + +  private: +    const Discount &discount_; +    const util::stream::ChainPosition input_; +}; + +class MergeRight { +  public: +    MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount) +      : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {} + +    // calculate the initial probability of each n-gram (before order-interpolation) +    // Run() gets invoked once for each order +    void Run(const util::stream::ChainPosition &primary) { +      util::stream::Stream summed(from_adder_); + +      NGramStream grams(primary); + +      // Without interpolation, the interpolation weight goes to <unk>. 
+      if (grams->Order() == 1 && !interpolate_unigrams_) { +        BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get())); +        assert(*grams->begin() == kUNK); +        grams->Value().uninterp.prob = sums.gamma; +        grams->Value().uninterp.gamma = 0.0; +        while (++grams) { +          grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator; +          grams->Value().uninterp.gamma = 0.0; +        } +        ++summed; +        return; +      } + +      std::vector<WordIndex> previous(grams->Order() - 1); +      const std::size_t size = sizeof(WordIndex) * previous.size(); +      for (; grams; ++summed) { +        memcpy(&previous[0], grams->begin(), size); +        const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get()); +        do { +          Payload &pay = grams->Value(); +          pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator; +          pay.uninterp.gamma = sums.gamma; +        } while (++grams && !memcmp(&previous[0], grams->begin(), size)); +      } +    } + +  private: +    bool interpolate_unigrams_; +    util::stream::ChainPosition from_adder_; +    Discount discount_; +}; + +} // namespace + +void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) { +  util::stream::ChainConfig gamma_config = config.adder_out; +  gamma_config.entry_size = sizeof(BufferEntry); +  for (size_t i = 0; i < primary.size(); ++i) { +    util::stream::ChainPosition second(second_in[i].Add()); +    second_in[i] >> util::stream::kRecycle; +    gamma_out.push_back(gamma_config); +    gamma_out[i] >> AddRight(discounts[i], second); +    primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); +    // Don't bother with the OnlyGamma thread for something to discard.   +    if (i) gamma_out[i] >> OnlyGamma(); +  } +} + +}} // namespaces diff --git a/klm/lm/builder/initial_probabilities.hh b/klm/lm/builder/initial_probabilities.hh new file mode 100644 index 00000000..626388eb --- /dev/null +++ b/klm/lm/builder/initial_probabilities.hh @@ -0,0 +1,34 @@ +#ifndef LM_BUILDER_INITIAL_PROBABILITIES__ +#define LM_BUILDER_INITIAL_PROBABILITIES__ + +#include "lm/builder/discount.hh" +#include "util/stream/config.hh" + +#include <vector> + +namespace lm { +namespace builder { +class Chains; + +struct InitialProbabilitiesConfig { +  // These should be small buffers to keep the adder from getting too far ahead +  util::stream::ChainConfig adder_in; +  util::stream::ChainConfig adder_out; +  // SRILM doesn't normally interpolate unigrams.   +  bool interpolate_unigrams; +}; + +/* Compute initial (uninterpolated) probabilities + * primary: the normal chain of n-grams.  Incoming is context sorted adjusted + *   counts.  Outgoing has uninterpolated probabilities for use by Interpolate. + * second_in: a second copy of the primary input.  Discard the output.   + * gamma_out: Computed gamma values are output on these chains in suffix order. + *   The values are bare floats and should be buffered for interpolation to + *   use.   
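In `AddRight` above, each context $c$ is summarized by a normalizer and a backoff mass; in the notation of the comment on `BufferEntry` (Chen and Goodman, page 20):

$$
\text{denominator}(c) = \sum_{w} a(c\,w), \qquad \gamma(c) = \frac{D_1\,n_1(c) + D_2\,n_2(c) + D_3\,n_{3+}(c)}{\sum_{w} a(c\,w)},
$$

where $a(c\,w)$ is the adjusted count and $n_j(c)$ counts the extensions of $c$ whose adjusted count falls in discount bucket $j$. `MergeRight` then assigns each n-gram the uninterpolated probability $\bigl(a(c\,w) - D(a(c\,w))\bigr) / \sum_w a(c\,w)$.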
````diff
diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc
new file mode 100644
index 00000000..50026806
--- /dev/null
+++ b/klm/lm/builder/interpolate.cc
@@ -0,0 +1,65 @@
+#include "lm/builder/interpolate.hh"
+
+#include "lm/builder/joint_order.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/sort.hh"
+#include "lm/lm_exception.hh"
+
+#include <assert.h>
+
+namespace lm { namespace builder {
+namespace {
+
+class Callback {
+  public:
+    Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) {
+      probs_[0] = uniform_prob;
+      for (std::size_t i = 0; i < backoffs.size(); ++i) {
+        backoffs_.push_back(backoffs[i]);
+      }
+    }
+
+    ~Callback() {
+      for (std::size_t i = 0; i < backoffs_.size(); ++i) {
+        if (backoffs_[i]) {
+          std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
+          abort();
+        }
+      }
+    }
+
+    void Enter(unsigned order_minus_1, NGram &gram) {
+      Payload &pay = gram.Value();
+      pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
+      probs_[order_minus_1 + 1] = pay.complete.prob;
+      pay.complete.prob = log10(pay.complete.prob);
+      // TODO: this is a hack to skip n-grams that don't appear as context.  Pruning will require some different handling.
+      if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
+        pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
+        ++backoffs_[order_minus_1];
+      } else {
+        // Not a context.
+        pay.complete.backoff = 0.0;
+      }
+    }
+
+    void Exit(unsigned, const NGram &) const {}
+
+  private:
+    FixedArray<util::stream::Stream> backoffs_;
+
+    std::vector<float> probs_;
+};
+} // namespace
+
+Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
+  : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {}
+
+// perform order-wise interpolation
+void Interpolate::Run(const ChainPositions &positions) {
+  assert(positions.size() == backoffs_.size() + 1);
+  Callback callback(uniform_prob_, backoffs_);
+  JointOrder<Callback, SuffixOrder>(positions, callback);
+}
+
+}} // namespaces
````

````diff
diff --git a/klm/lm/builder/interpolate.hh b/klm/lm/builder/interpolate.hh
new file mode 100644
index 00000000..9268d404
--- /dev/null
+++ b/klm/lm/builder/interpolate.hh
@@ -0,0 +1,27 @@
+#ifndef LM_BUILDER_INTERPOLATE__
+#define LM_BUILDER_INTERPOLATE__
+
+#include <stdint.h>
+
+#include "lm/builder/multi_stream.hh"
+
+namespace lm { namespace builder {
+
+/* Interpolate step.
+ * Input: suffix sorted n-grams with (p_uninterpolated, gamma) from
+ * InitialProbabilities.
+ * Output: suffix sorted n-grams with complete probability
+ */
+class Interpolate {
+  public:
+    explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs);
+
+    void Run(const ChainPositions &positions);
+
+  private:
+    float uniform_prob_;
+    ChainPositions backoffs_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_INTERPOLATE__
````
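`Callback::Enter` above applies the interpolation recursion one order at a time as `JointOrder` walks the streams:

$$
p(w \mid c) = p_{\text{uninterp}}(w \mid c) + \gamma(c)\, p(w \mid c'),
$$

where $c'$ drops the leftmost word of $c$. The recursion bottoms out at the uniform probability $1/(\text{unigram count} - 1)$ passed to the constructor, and the result is stored as $\log_{10}$ for the ARPA output.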
````diff
diff --git a/klm/lm/builder/joint_order.hh b/klm/lm/builder/joint_order.hh
new file mode 100644
index 00000000..b5620144
--- /dev/null
+++ b/klm/lm/builder/joint_order.hh
@@ -0,0 +1,43 @@
+#ifndef LM_BUILDER_JOINT_ORDER__
+#define LM_BUILDER_JOINT_ORDER__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/lm_exception.hh"
+
+#include <string.h>
+
+namespace lm { namespace builder {
+
+template <class Callback, class Compare> void JointOrder(const ChainPositions &positions, Callback &callback) {
+  // Allow matching to reference streams[-1].
+  NGramStreams streams_with_dummy;
+  streams_with_dummy.InitWithDummy(positions);
+  NGramStream *streams = streams_with_dummy.begin() + 1;
+
+  unsigned int order;
+  for (order = 0; order < positions.size() && streams[order]; ++order) {}
+  assert(order); // should always have <unk>.
+  unsigned int current = 0;
+  while (true) {
+    // Does the context match the lower one?
+    if (!memcmp(streams[static_cast<int>(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) {
+      callback.Enter(current, *streams[current]);
+      // Transition to looking for extensions.
+      if (++current < order) continue;
+    }
+    // No extension left.
+    while(true) {
+      assert(current > 0);
+      --current;
+      callback.Exit(current, *streams[current]);
+      if (++streams[current]) break;
+      UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix");
+      order = current;
+      if (!order) return;
+    }
+  }
+}
+
+}} // namespaces
+
+#endif // LM_BUILDER_JOINT_ORDER__
````

````diff
diff --git a/klm/lm/builder/main.cc b/klm/lm/builder/main.cc
new file mode 100644
index 00000000..90b9dca2
--- /dev/null
+++ b/klm/lm/builder/main.cc
@@ -0,0 +1,94 @@
+#include "lm/builder/pipeline.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/usage.hh"
+
+#include <iostream>
+
+#include <boost/program_options.hpp>
+
+namespace {
+class SizeNotify {
+  public:
+    SizeNotify(std::size_t &out) : behind_(out) {}
+
+    void operator()(const std::string &from) {
+      behind_ = util::ParseSize(from);
+    }
+
+  private:
+    std::size_t &behind_;
+};
+
+boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) {
+  return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
+}
+
+} // namespace
+
+int main(int argc, char *argv[]) {
+  try {
+    namespace po = boost::program_options;
+    po::options_description options("Language model building options");
+    lm::builder::PipelineConfig pipeline;
+
+    options.add_options()
+      ("order,o", po::value<std::size_t>(&pipeline.order)->required(), "Order of the model")
+      ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
+      ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
+      ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
+      ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
+      ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
+      ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+      ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
+      ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
+      ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
+    if (argc == 1) {
+      std::cerr <<
+        "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
+        "Please cite:\n"
+        "@inproceedings{kenlm,\n"
+        "author    = {Kenneth Heafield},\n"
+        "title     = {{KenLM}: Faster and Smaller Language Model Queries},\n"
+        "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
+        "month     = {July}, year={2011},\n"
+        "address   = {Edinburgh, UK},\n"
+        "publisher = {Association for Computational Linguistics},\n"
+        "}\n\n"
+        "Provide the corpus on stdin.  The ARPA file will be written to stdout.  Order of\n"
+        "the model (-o) is the only mandatory option.  As this is an on-disk program,\n"
+        "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
+        "Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
+        "Valid units are \% for percentage of memory (supported platforms only) and (in\n"
+        "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y.  Default is K (*1024).\n\n";
+      std::cerr << options << std::endl;
+      return 1;
+    }
+    po::variables_map vm;
+    po::store(po::parse_command_line(argc, argv, options), vm);
+    po::notify(vm);
+
+    util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
+
+    lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
+    // TODO: evaluate options for these.
+    initial.adder_in.total_memory = 32768;
+    initial.adder_in.block_count = 2;
+    initial.adder_out.total_memory = 32768;
+    initial.adder_out.block_count = 2;
+    pipeline.read_backoffs = initial.adder_out;
+
+    // Read from stdin
+    try {
+      lm::builder::Pipeline(pipeline, 0, 1);
+    } catch (const util::MallocException &e) {
+      std::cerr << e.what() << std::endl;
+      std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
+      return 1;
+    }
+    util::PrintUsage(std::cerr);
+  } catch (const std::exception &e) {
+    std::cerr << e.what() << std::endl;
+    return 1;
+  }
+}
````
+    initial.adder_in.total_memory = 32768; +    initial.adder_in.block_count = 2; +    initial.adder_out.total_memory = 32768; +    initial.adder_out.block_count = 2; +    pipeline.read_backoffs = initial.adder_out; + +    // Read from stdin +    try { +      lm::builder::Pipeline(pipeline, 0, 1); +    } catch (const util::MallocException &e) { +      std::cerr << e.what() << std::endl; +      std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl; +      return 1; +    } +    util::PrintUsage(std::cerr); +  } catch (const std::exception &e) { +    std::cerr << e.what() << std::endl; +    return 1; +  } +} diff --git a/klm/lm/builder/multi_stream.hh b/klm/lm/builder/multi_stream.hh new file mode 100644 index 00000000..707a98c7 --- /dev/null +++ b/klm/lm/builder/multi_stream.hh @@ -0,0 +1,180 @@ +#ifndef LM_BUILDER_MULTI_STREAM__ +#define LM_BUILDER_MULTI_STREAM__ + +#include "lm/builder/ngram_stream.hh" +#include "util/scoped.hh" +#include "util/stream/chain.hh" + +#include <cstddef> +#include <new> + +#include <assert.h> +#include <stdlib.h> + +namespace lm { namespace builder { + +template <class T> class FixedArray { +  public: +    explicit FixedArray(std::size_t count) { +      Init(count); +    } + +    FixedArray() : newed_end_(NULL) {} + +    void Init(std::size_t count) { +      assert(!block_.get()); +      block_.reset(malloc(sizeof(T) * count)); +      if (!block_.get()) throw std::bad_alloc(); +      newed_end_ = begin(); +    } + +    FixedArray(const FixedArray &from) { +      std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get()); +      Init(size); +      for (std::size_t i = 0; i < size; ++i) { +        new(end()) T(from[i]); +        Constructed(); +      } +    } + +    ~FixedArray() { clear(); } + +    T *begin() { return static_cast<T*>(block_.get()); } +    const T *begin() const { return static_cast<const T*>(block_.get()); } +    // Always call Constructed after successful completion of new.   
+    T *end() { return newed_end_; } +    const T *end() const { return newed_end_; } + +    T &back() { return *(end() - 1); } +    const T &back() const { return *(end() - 1); } + +    std::size_t size() const { return end() - begin(); } +    bool empty() const { return begin() == end(); } + +    T &operator[](std::size_t i) { return begin()[i]; } +    const T &operator[](std::size_t i) const { return begin()[i]; } + +    template <class C> void push_back(const C &c) { +      new (end()) T(c); +      Constructed(); +    } + +    void clear() { +      for (T *i = begin(); i != end(); ++i) +        i->~T(); +      newed_end_ = begin(); +    } + +  protected: +    void Constructed() { +      ++newed_end_; +    } + +  private: +    util::scoped_malloc block_; + +    T *newed_end_; +}; + +class Chains; + +class ChainPositions : public FixedArray<util::stream::ChainPosition> { +  public: +    ChainPositions() {} + +    void Init(Chains &chains); + +    explicit ChainPositions(Chains &chains) { +      Init(chains); +    } +}; + +class Chains : public FixedArray<util::stream::Chain> { +  private: +    template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun { +      typedef Chains type; +    }; + +  public: +    explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {} + +    template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) { +      threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); +      return *this; +    } + +    template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) { +      threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); +      return *this; +    } + +    Chains &operator>>(const util::stream::Recycler &recycler) { +      for (util::stream::Chain *i = begin(); i != end(); ++i)  +        *i >> recycler; +      return *this; +    } + +    void Wait(bool release_memory = true) { +      threads_.clear(); +      for (util::stream::Chain *i = begin(); i != end(); ++i) { +        i->Wait(release_memory); +      } +    } + +  private: +    boost::ptr_vector<util::stream::Thread> threads_; + +    Chains(const Chains &); +    void operator=(const Chains &); +}; + +inline void ChainPositions::Init(Chains &chains) { +  FixedArray<util::stream::ChainPosition>::Init(chains.size()); +  for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) { +    new (end()) util::stream::ChainPosition(i->Add()); Constructed(); +  } +} + +inline Chains &operator>>(Chains &chains, ChainPositions &positions) { +  positions.Init(chains); +  return chains; +} + +class NGramStreams : public FixedArray<NGramStream> { +  public: +    NGramStreams() {} + +    // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning). 
+    void InitWithDummy(const ChainPositions &positions) { +      FixedArray<NGramStream>::Init(positions.size() + 1); +      new (end()) NGramStream(); Constructed(); +      for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) { +        push_back(*i); +      } +    } + +    // Limit restricts to positions[0,limit) +    void Init(const ChainPositions &positions, std::size_t limit) { +      FixedArray<NGramStream>::Init(limit); +      for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) { +        push_back(*i); +      } +    } +    void Init(const ChainPositions &positions) { +      Init(positions, positions.size()); +    } + +    NGramStreams(const ChainPositions &positions) { +      Init(positions); +    } +}; + +inline Chains &operator>>(Chains &chains, NGramStreams &streams) { +  ChainPositions positions; +  chains >> positions; +  streams.Init(positions); +  return chains; +} + +}} // namespaces +#endif // LM_BUILDER_MULTI_STREAM__ diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh new file mode 100644 index 00000000..2984ed0b --- /dev/null +++ b/klm/lm/builder/ngram.hh @@ -0,0 +1,84 @@ +#ifndef LM_BUILDER_NGRAM__ +#define LM_BUILDER_NGRAM__ + +#include "lm/weights.hh" +#include "lm/word_index.hh" + +#include <cstddef> + +#include <assert.h> +#include <stdint.h> +#include <string.h> + +namespace lm { +namespace builder { + +struct Uninterpolated { +  float prob;  // Uninterpolated probability. +  float gamma; // Interpolation weight for lower order. +}; + +union Payload { +  uint64_t count; +  Uninterpolated uninterp; +  ProbBackoff complete; +}; + +class NGram { +  public: +    NGram(void *begin, std::size_t order)  +      : begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {} + +    const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); } +    uint8_t *Base() { return reinterpret_cast<uint8_t*>(begin_); } + +    void ReBase(void *to) { +      std::size_t difference = end_ - begin_; +      begin_ = reinterpret_cast<WordIndex*>(to); +      end_ = begin_ + difference; +    } + +    // Would do operator++ but that can get confusing for a stream.   +    void NextInMemory() { +      ReBase(&Value() + 1); +    } + +    // Lower-case in deference to STL.   +    const WordIndex *begin() const { return begin_; } +    WordIndex *begin() { return begin_; } +    const WordIndex *end() const { return end_; } +    WordIndex *end() { return end_; } + +    const Payload &Value() const { return *reinterpret_cast<const Payload *>(end_); } +    Payload &Value() { return *reinterpret_cast<Payload *>(end_); } + +    uint64_t &Count() { return Value().count; } +    const uint64_t Count() const { return Value().count; } + +    std::size_t Order() const { return end_ - begin_; } + +    static std::size_t TotalSize(std::size_t order) { +      return order * sizeof(WordIndex) + sizeof(Payload); +    } +    std::size_t TotalSize() const { +      // Compiler should optimize this.   
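+      // e.g. assuming a 4-byte WordIndex and an 8-byte Payload union, a
+      // trigram occupies 3 * 4 + 8 = 20 bytes.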
+      return TotalSize(Order()); +    } +    static std::size_t OrderFromSize(std::size_t size) { +      std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex); +      assert(size == TotalSize(ret)); +      return ret; +    } + +  private: +    WordIndex *begin_, *end_; +}; + +const WordIndex kUNK = 0; +const WordIndex kBOS = 1; +const WordIndex kEOS = 2; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_NGRAM__ diff --git a/klm/lm/builder/ngram_stream.hh b/klm/lm/builder/ngram_stream.hh new file mode 100644 index 00000000..3c994664 --- /dev/null +++ b/klm/lm/builder/ngram_stream.hh @@ -0,0 +1,55 @@ +#ifndef LM_BUILDER_NGRAM_STREAM__ +#define LM_BUILDER_NGRAM_STREAM__ + +#include "lm/builder/ngram.hh" +#include "util/stream/chain.hh" +#include "util/stream/stream.hh" + +#include <cstddef> + +namespace lm { namespace builder { + +class NGramStream { +  public: +    NGramStream() : gram_(NULL, 0) {} + +    NGramStream(const util::stream::ChainPosition &position) : gram_(NULL, 0) { +      Init(position); +    } + +    void Init(const util::stream::ChainPosition &position) { +      stream_.Init(position); +      gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize())); +    } + +    NGram &operator*() { return gram_; } +    const NGram &operator*() const { return gram_; } + +    NGram *operator->() { return &gram_; } +    const NGram *operator->() const { return &gram_; } + +    void *Get() { return stream_.Get(); } +    const void *Get() const { return stream_.Get(); } + +    operator bool() const { return stream_; } +    bool operator!() const { return !stream_; } +    void Poison() { stream_.Poison(); } + +    NGramStream &operator++() { +      ++stream_; +      gram_.ReBase(stream_.Get()); +      return *this; +    } + +  private: +    NGram gram_; +    util::stream::Stream stream_; +}; + +inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) { +  str.Init(chain.Add()); +  return chain; +} + +}} // namespaces +#endif // LM_BUILDER_NGRAM_STREAM__ diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc new file mode 100644 index 00000000..14a1f721 --- /dev/null +++ b/klm/lm/builder/pipeline.cc @@ -0,0 +1,320 @@ +#include "lm/builder/pipeline.hh" + +#include "lm/builder/adjust_counts.hh" +#include "lm/builder/corpus_count.hh" +#include "lm/builder/initial_probabilities.hh" +#include "lm/builder/interpolate.hh" +#include "lm/builder/print.hh" +#include "lm/builder/sort.hh" + +#include "lm/sizes.hh" + +#include "util/exception.hh" +#include "util/file.hh" +#include "util/stream/io.hh" + +#include <algorithm> +#include <iostream> +#include <vector> + +namespace lm { namespace builder { + +namespace { +void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) { +  std::cerr << "Statistics:\n"; +  for (size_t i = 0; i < counts.size(); ++i) { +    std::cerr << (i + 1) << ' ' << counts[i]; +    for (size_t d = 1; d <= 3; ++d) +      std::cerr << " D" << d << (d == 3 ? 
"+=" : "=") << discounts[i].amount[d]; +    std::cerr << '\n'; +  } +} + +class Master { +  public: +    explicit Master(const PipelineConfig &config)  +      : config_(config), chains_(config.order), files_(config.order) { +      config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block); +    } + +    const PipelineConfig &Config() const { return config_; } + +    Chains &MutableChains() { return chains_; } + +    template <class T> Master &operator>>(const T &worker) { +      chains_ >> worker; +      return *this; +    } + +    // This takes the (partially) sorted ngrams and sets up for adjusted counts. +    void InitForAdjust(util::stream::Sort<SuffixOrder, AddCombiner> &ngrams, WordIndex types) { +      const std::size_t each_order_min = config_.minimum_block * config_.block_count; +      // We know how many unigrams there are.  Don't allocate more than needed to them. +      const std::size_t min_chains = (config_.order - 1) * each_order_min + +        std::min(types * NGram::TotalSize(1), each_order_min); +      // Do merge sort with calculated laziness. +      const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy())); + +      std::vector<uint64_t> count_bounds(1, types); +      CreateChains(config_.TotalMemory() - merge_using, count_bounds); +      ngrams.Output(chains_.back(), merge_using); + +      // Setup unigram file.   +      files_.push_back(util::MakeTemp(config_.TempPrefix())); +    } + +    // For initial probabilities, but this is generic. +    void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) { +      // Do merge first before allocating chain memory. +      for (std::size_t i = 1; i < config_.order; ++i) { +        sorts[i - 1].Merge(0); +      } +      // There's no lazy merge, so just divide memory amongst the chains. +      CreateChains(config_.TotalMemory(), counts); +      chains_.back().ActivateProgress(); +      chains_[0] >> files_[0].Source(); +      second_config.entry_size = NGram::TotalSize(1); +      second.push_back(second_config); +      second.back() >> files_[0].Source(); +      for (std::size_t i = 1; i < config_.order; ++i) { +        util::scoped_fd fd(sorts[i - 1].StealCompleted()); +        chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get())); +        chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true); +        second_config.entry_size = NGram::TotalSize(i + 1); +        second.push_back(second_config); +        second.back() >> util::stream::PRead(fd.release(), true); +      } +    } + +    // There is no sort after this, so go for broke on lazy merging. +    template <class Compare> void MaximumLazyInput(const std::vector<uint64_t> &counts, Sorts<Compare> &sorts) { +      // Determine the minimum we can use for all the chains. +      std::size_t min_chains = 0; +      for (std::size_t i = 0; i < config_.order; ++i) { +        min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block)); +      } +      std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains); +      std::vector<std::size_t> laziness; +      // Prioritize longer n-grams. 
+      for (util::stream::Sort<Compare> *i = sorts.end() - 1; i >= sorts.begin(); --i) {
+        laziness.push_back(i->Merge(for_merge));
+        assert(for_merge >= laziness.back());
+        for_merge -= laziness.back();
+      }
+      std::reverse(laziness.begin(), laziness.end());
+
+      CreateChains(for_merge + min_chains, counts);
+      chains_.back().ActivateProgress();
+      chains_[0] >> files_[0].Source();
+      for (std::size_t i = 1; i < config_.order; ++i) {
+        sorts[i - 1].Output(chains_[i], laziness[i - 1]);
+      }
+    }
+
+    void BufferFinal(const std::vector<uint64_t> &counts) {
+      chains_[0] >> files_[0].Sink();
+      for (std::size_t i = 1; i < config_.order; ++i) {
+        files_.push_back(util::MakeTemp(config_.TempPrefix()));
+        chains_[i] >> files_[i].Sink();
+      }
+      chains_.Wait(true);
+      // Use less memory.  Because we can.
+      CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts);
+      for (std::size_t i = 0; i < config_.order; ++i) {
+        chains_[i] >> files_[i].Source();
+      }
+    }
+
+    template <class Compare> void SetupSorts(Sorts<Compare> &sorts) {
+      sorts.Init(config_.order - 1);
+      // Unigrams don't get sorted because their order is always the same.
+      chains_[0] >> files_[0].Sink();
+      for (std::size_t i = 1; i < config_.order; ++i) {
+        sorts.push_back(chains_[i], config_.sort, Compare(i + 1));
+      }
+      chains_.Wait(true);
+    }
+
+  private:
+    // Create chains, allocating memory to them.  Totally heuristic.  Count
+    // bounds are upper bounds on the counts or not present.
+    void CreateChains(std::size_t remaining_mem, const std::vector<uint64_t> &count_bounds) {
+      std::vector<std::size_t> assignments;
+      assignments.reserve(config_.order);
+      // Start by assigning maximum memory usage (to be refined later).
+      for (std::size_t i = 0; i < count_bounds.size(); ++i) {
+        assignments.push_back(static_cast<std::size_t>(std::min(
+            static_cast<uint64_t>(remaining_mem),
+            count_bounds[i] * static_cast<uint64_t>(NGram::TotalSize(i + 1)))));
+      }
+      assignments.resize(config_.order, remaining_mem);
+
+      // Now we know how much memory everybody wants.  How much will they get?
+      // Proportional to this.
+      std::vector<float> portions;
+      // Indices of orders that have yet to be assigned.
+      std::vector<std::size_t> unassigned;
+      for (std::size_t i = 0; i < config_.order; ++i) {
+        portions.push_back(static_cast<float>((i+1) * NGram::TotalSize(i+1)));
+        unassigned.push_back(i);
+      }
+      /* If somebody doesn't eat their full dinner, give it to the rest of the
+       * family.  Then somebody else might not eat their full dinner etc.  Ends
+       * when everybody unassigned is hungry.
+       */
+      float sum;
+      bool found_more;
+      std::vector<std::size_t> block_count(config_.order);
+      do {
+        sum = 0.0;
+        for (std::size_t i = 0; i < unassigned.size(); ++i) {
+          sum += portions[unassigned[i]];
+        }
+        found_more = false;
+        // If the proportional assignment is more than needed, give it just what it needs.
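+        // e.g. with 100 MB left and proportional shares of 30/70, an order
+        // that only asked for 5 MB keeps its 5 MB; the remaining 95 MB is
+        // re-divided among the still-hungry orders on the next pass.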
+        for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end();) { +          if (assignments[*i] <= remaining_mem * (portions[*i] / sum)) { +            remaining_mem -= assignments[*i]; +            block_count[*i] = 1; +            i = unassigned.erase(i); +            found_more = true; +          } else { +            ++i; +          } +        } +      } while (found_more); +      for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end(); ++i) { +        assignments[*i] = remaining_mem * (portions[*i] / sum); +        block_count[*i] = config_.block_count; +      } +      chains_.clear(); +      std::cerr << "Chain sizes:"; +      for (std::size_t i = 0; i < config_.order; ++i) { +        std::cerr << ' ' << (i+1) << ":" << assignments[i]; +        chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i])); +      } +      std::cerr << std::endl; +    } + +    PipelineConfig config_; + +    Chains chains_; +    // Often only unigrams, but sometimes all orders.   +    FixedArray<util::stream::FileBuffer> files_; +}; + +void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) { +  const PipelineConfig &config = master.Config(); +  std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl; + +  UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory()); +  std::size_t memory_for_chain =  +    // This much memory to work with after vocab hash table. +    static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) / +    // Solve for block size including the dedupe multiplier for one block. +    (static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) * +    // Chain likes memory expressed in terms of total memory. +    static_cast<float>(config.block_count); +  util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain)); + +  WordIndex type_count; +  util::FilePiece text(text_file, NULL, &std::cerr); +  text_file_name = text.FileName(); +  CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); +  chain >> boost::ref(counter); + +  util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); +  chain.Wait(true); +  std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl; +  master.InitForAdjust(sorter, type_count); +} + +void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { +  const PipelineConfig &config = master.Config(); +  Chains second(config.order); + +  { +    Sorts<ContextOrder> sorts; +    master.SetupSorts(sorts); +    PrintStatistics(counts, discounts); +    lm::ngram::ShowSizes(counts); +    std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl; +    master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in); +  } + +  Chains gamma_chains(config.order); +  InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains); +  // Don't care about gamma for 0.   
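+  // (Unigram interpolation weights are never read back, so just recycle.)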
+  gamma_chains[0] >> util::stream::kRecycle; +  gammas.Init(config.order - 1); +  for (std::size_t i = 1; i < config.order; ++i) { +    gammas.push_back(util::MakeTemp(config.TempPrefix())); +    gamma_chains[i] >> gammas[i - 1].Sink(); +  } +  // Has to be done here due to gamma_chains scope. +  master.SetupSorts(primary); +} + +void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { +  std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; +  const PipelineConfig &config = master.Config(); +  master.MaximumLazyInput(counts, primary); + +  Chains gamma_chains(config.order - 1); +  util::stream::ChainConfig read_backoffs(config.read_backoffs); +  read_backoffs.entry_size = sizeof(float); +  for (std::size_t i = 0; i < config.order - 1; ++i) { +    gamma_chains.push_back(read_backoffs); +    gamma_chains.back() >> gammas[i].Source(); +  } +  master >> Interpolate(counts[0], ChainPositions(gamma_chains)); +  gamma_chains >> util::stream::kRecycle; +  master.BufferFinal(counts); +} + +} // namespace + +void Pipeline(PipelineConfig config, int text_file, int out_arpa) { +  // Some fail-fast sanity checks. +  if (config.sort.buffer_size * 4 > config.TotalMemory()) { +    config.sort.buffer_size = config.TotalMemory() / 4; +    std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl; +  } +  if (config.minimum_block < NGram::TotalSize(config.order)) { +    config.minimum_block = NGram::TotalSize(config.order); +    std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl; +  } +  UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << "."); +  UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception, +      "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ".  Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size."); + +  UTIL_TIMER("(%w s) Total wall time elapsed\n"); +  Master master(config); + +  util::scoped_fd vocab_file(config.vocab_file.empty() ?  +      util::MakeTemp(config.TempPrefix()) :  +      util::CreateOrThrow(config.vocab_file.c_str())); +  uint64_t token_count; +  std::string text_file_name; +  CountText(text_file, vocab_file.get(), master, token_count, text_file_name); + +  std::vector<uint64_t> counts; +  std::vector<Discount> discounts; +  master >> AdjustCounts(counts, discounts); + +  { +    FixedArray<util::stream::FileBuffer> gammas; +    Sorts<SuffixOrder> primary; +    InitialProbabilities(counts, discounts, master, primary, gammas); +    InterpolateProbabilities(counts, master, primary, gammas); +  } + +  std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; +  VocabReconstitute vocab(vocab_file.get()); +  UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up.  Is there a null byte in the input?"); +  HeaderInfo header_info(text_file_name, token_count); +  master >> PrintARPA(vocab, counts, (config.verbose_header ? 
&header_info : NULL), out_arpa) >> util::stream::kRecycle; +  master.MutableChains().Wait(true); +} + +}} // namespaces diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh new file mode 100644 index 00000000..f1d6c5f6 --- /dev/null +++ b/klm/lm/builder/pipeline.hh @@ -0,0 +1,40 @@ +#ifndef LM_BUILDER_PIPELINE__ +#define LM_BUILDER_PIPELINE__ + +#include "lm/builder/initial_probabilities.hh" +#include "lm/builder/header_info.hh" +#include "util/stream/config.hh" +#include "util/file_piece.hh" + +#include <string> +#include <cstddef> + +namespace lm { namespace builder { + +struct PipelineConfig { +  std::size_t order; +  std::string vocab_file; +  util::stream::SortConfig sort; +  InitialProbabilitiesConfig initial_probs; +  util::stream::ChainConfig read_backoffs; +  bool verbose_header; + +  // Amount of memory to assume that the vocabulary hash table will use.  This +  // is subtracted from total memory for CorpusCount. +  std::size_t assume_vocab_hash_size; + +  // Minimum block size to tolerate. +  std::size_t minimum_block; + +  // Number of blocks to use.  This will be overridden to 1 if everything fits. +  std::size_t block_count; + +  const std::string &TempPrefix() const { return sort.temp_prefix; } +  std::size_t TotalMemory() const { return sort.total_memory; } +}; + +// Takes ownership of text_file. +void Pipeline(PipelineConfig config, int text_file, int out_arpa); + +}} // namespaces +#endif // LM_BUILDER_PIPELINE__ diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc new file mode 100644 index 00000000..b0323221 --- /dev/null +++ b/klm/lm/builder/print.cc @@ -0,0 +1,135 @@ +#include "lm/builder/print.hh" + +#include "util/double-conversion/double-conversion.h" +#include "util/double-conversion/utils.h" +#include "util/file.hh" +#include "util/mmap.hh" +#include "util/scoped.hh" +#include "util/stream/timer.hh" + +#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE +#include <boost/lexical_cast.hpp> + +#include <sstream> + +#include <string.h> + +namespace lm { namespace builder { + +VocabReconstitute::VocabReconstitute(int fd) { +  uint64_t size = util::SizeOrThrow(fd); +  util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_); +  const char *const start = static_cast<const char*>(memory_.get()); +  const char *i; +  for (i = start; i != start + size; i += strlen(i) + 1) { +    map_.push_back(i); +  } +  // Last one for LookupPiece. +  map_.push_back(i); +} + +namespace { +class OutputManager { +  public: +    static const std::size_t kOutBuf = 1048576; + +    // Does not take ownership of out. +    explicit OutputManager(int out) +      : buf_(util::MallocOrThrow(kOutBuf)), +        builder_(static_cast<char*>(buf_.get()), kOutBuf), +        // Mostly the default but with inf instead.  And no flags. +        convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0), +        fd_(out) {} + +    ~OutputManager() { +      Flush(); +    } + +    OutputManager &operator<<(float value) { +      // Odd, but this is the largest number found in the comments. 
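+      // The + 8 is presumably headroom for sign, decimal point, and exponent.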
+      EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); +      convert_.ToShortestSingle(value, &builder_); +      return *this; +    } + +    OutputManager &operator<<(StringPiece str) { +      if (str.size() > kOutBuf) { +        Flush(); +        util::WriteOrThrow(fd_, str.data(), str.size()); +      } else { +        EnsureRemaining(str.size()); +        builder_.AddSubstring(str.data(), str.size()); +      } +      return *this; +    } + +    // Inefficient! +    OutputManager &operator<<(unsigned val) { +      return *this << boost::lexical_cast<std::string>(val); +    } + +    OutputManager &operator<<(char c) { +      EnsureRemaining(1); +      builder_.AddCharacter(c); +      return *this; +    } + +    void Flush() { +      util::WriteOrThrow(fd_, buf_.get(), builder_.position()); +      builder_.Reset(); +    } + +  private: +    void EnsureRemaining(std::size_t amount) { +      if (static_cast<std::size_t>(builder_.size() - builder_.position()) < amount) { +        Flush(); +      } +    } + +    util::scoped_malloc buf_; +    double_conversion::StringBuilder builder_; +    double_conversion::DoubleToStringConverter convert_; +    int fd_; +}; +} // namespace + +PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)  +  : vocab_(vocab), out_fd_(out_fd) { +  std::stringstream stream; + +  if (header_info) { +    stream << "# Input file: " << header_info->input_file << '\n'; +    stream << "# Token count: " << header_info->token_count << '\n'; +    stream << "# Smoothing: Modified Kneser-Ney" << '\n'; +  } +  stream << "\\data\\\n"; +  for (size_t i = 0; i < counts.size(); ++i) { +    stream << "ngram " << (i+1) << '=' << counts[i] << '\n'; +  } +  stream << '\n'; +  std::string as_string(stream.str()); +  util::WriteOrThrow(out_fd, as_string.data(), as_string.size()); +} + +void PrintARPA::Run(const ChainPositions &positions) { +  UTIL_TIMER("(%w s) Wrote ARPA file\n"); +  OutputManager out(out_fd_); +  for (unsigned order = 1; order <= positions.size(); ++order) { +    out << "\\" << order << "-grams:" << '\n'; +    for (NGramStream stream(positions[order - 1]); stream; ++stream) { +      // Correcting for numerical precision issues.  Take that IRST.   +      out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin()); +      for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) { +        out << ' ' << vocab_.Lookup(*i); +      } +      float backoff = stream->Value().complete.backoff; +      if (backoff != 0.0) +        out << '\t' << backoff; +      out << '\n'; +    } +    out << '\n'; +  } +  out << "\\end\\\n"; +} + +}} // namespaces diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh new file mode 100644 index 00000000..aa932e75 --- /dev/null +++ b/klm/lm/builder/print.hh @@ -0,0 +1,102 @@ +#ifndef LM_BUILDER_PRINT__ +#define LM_BUILDER_PRINT__ + +#include "lm/builder/ngram.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/header_info.hh" +#include "util/file.hh" +#include "util/mmap.hh" +#include "util/string_piece.hh" + +#include <ostream> + +#include <assert.h> + +// Warning: print routines read all unigrams before all bigrams before all +// trigrams etc.  So if other parts of the chain move jointly, you'll have to +// buffer.   
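+// e.g. a worker sharing these chains that needs bigram blocks while unigrams
+// are still being printed will stall unless it buffers them itself.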
+ +namespace lm { namespace builder { + +class VocabReconstitute { +  public: +    // fd must be alive for life of this object; does not take ownership. +    explicit VocabReconstitute(int fd); + +    const char *Lookup(WordIndex index) const { +      assert(index < map_.size() - 1); +      return map_[index]; +    } + +    StringPiece LookupPiece(WordIndex index) const { +      return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]); +    } + +    std::size_t Size() const { +      // There's an extra entry to support StringPiece lengths. +      return map_.size() - 1; +    } + +  private: +    util::scoped_memory memory_; +    std::vector<const char*> map_; +}; + +// Not defined, only specialized.   +template <class T> void PrintPayload(std::ostream &to, const Payload &payload); +template <> inline void PrintPayload<uint64_t>(std::ostream &to, const Payload &payload) { +  to << payload.count; +} +template <> inline void PrintPayload<Uninterpolated>(std::ostream &to, const Payload &payload) { +  to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma); +} +template <> inline void PrintPayload<ProbBackoff>(std::ostream &to, const Payload &payload) { +  to << payload.complete.prob << ' ' << payload.complete.backoff; +} + +// template parameter is the type stored.   +template <class V> class Print { +  public: +    explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {} + +    void Run(const ChainPositions &chains) { +      NGramStreams streams(chains); +      for (NGramStream *s = streams.begin(); s != streams.end(); ++s) { +        DumpStream(*s); +      } +    } + +    void Run(const util::stream::ChainPosition &position) { +      NGramStream stream(position); +      DumpStream(stream); +    } + +  private: +    void DumpStream(NGramStream &stream) { +      for (; stream; ++stream) { +        PrintPayload<V>(to_, stream->Value()); +        for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) { +          to_ << ' ' << vocab_.Lookup(*w) << '=' << *w; +        } +        to_ << '\n'; +      } +    } + +    const VocabReconstitute &vocab_; +    std::ostream &to_; +}; + +class PrintARPA { +  public: +    // header_info may be NULL to disable the header +    explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd); + +    void Run(const ChainPositions &positions); + +  private: +    const VocabReconstitute &vocab_; +    int out_fd_; +}; + +}} // namespaces +#endif // LM_BUILDER_PRINT__ diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh new file mode 100644 index 00000000..9989389b --- /dev/null +++ b/klm/lm/builder/sort.hh @@ -0,0 +1,103 @@ +#ifndef LM_BUILDER_SORT__ +#define LM_BUILDER_SORT__ + +#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram.hh" +#include "lm/word_index.hh" +#include "util/stream/sort.hh" + +#include "util/stream/timer.hh" + +#include <functional> +#include <string> + +namespace lm { +namespace builder { + +template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> { +  public: +    explicit Comparator(std::size_t order) : order_(order) {} + +    inline bool operator()(const void *lhs, const void *rhs) const { +      return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs)); +    } + +    std::size_t Order() const { return order_; } + +  protected: +    std::size_t order_; +}; + +class 
SuffixOrder : public Comparator<SuffixOrder> { +  public: +    explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {} + +    inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { +      for (std::size_t i = order_ - 1; i != 0; --i) { +        if (lhs[i] != rhs[i]) +          return lhs[i] < rhs[i]; +      } +      return lhs[0] < rhs[0]; +    } + +    static const unsigned kMatchOffset = 1; +}; + +class ContextOrder : public Comparator<ContextOrder> { +  public: +    explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {} + +    inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { +      for (int i = order_ - 2; i >= 0; --i) { +        if (lhs[i] != rhs[i]) +          return lhs[i] < rhs[i]; +      } +      return lhs[order_ - 1] < rhs[order_ - 1]; +    } +}; + +class PrefixOrder : public Comparator<PrefixOrder> { +  public: +    explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {} + +    inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { +      for (std::size_t i = 0; i < order_; ++i) { +        if (lhs[i] != rhs[i]) +          return lhs[i] < rhs[i]; +      } +      return false; +    } +     +    static const unsigned kMatchOffset = 0; +}; + +// Sum counts for the same n-gram. +struct AddCombiner { +  bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const { +    NGram first(first_void, compare.Order()); +    // There isn't a const version of NGram.   +    NGram second(const_cast<void*>(second_void), compare.Order()); +    if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false; +    first.Count() += second.Count(); +    return true; +  } +}; + +// The combiner is only used on a single chain, so I didn't bother to allow +// that template.   +template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > { +  private: +    typedef util::stream::Sort<Compare> S; +    typedef FixedArray<S> P; + +  public: +    void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { +      new (P::end()) S(chain, config, compare); +      P::Constructed(); +    } +}; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_SORT__ diff --git a/klm/lm/filter/arpa_io.cc b/klm/lm/filter/arpa_io.cc new file mode 100644 index 00000000..f8568ac4 --- /dev/null +++ b/klm/lm/filter/arpa_io.cc @@ -0,0 +1,108 @@ +#include "lm/filter/arpa_io.hh" +#include "util/file_piece.hh" + +#include <iostream> +#include <ostream> +#include <string> +#include <vector> + +#include <ctype.h> +#include <errno.h> +#include <string.h> + +namespace lm { + +ARPAInputException::ARPAInputException(const StringPiece &message) throw() { +  *this << message; +} + +ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() { +  *this << message << " in line " << line; +} + +ARPAInputException::~ARPAInputException() throw() {} + +ARPAOutputException::ARPAOutputException(const char *message, const std::string &file_name) throw() { +  *this << message << " in file " << file_name; +} + +ARPAOutputException::~ARPAOutputException() throw() {} + +// Seeking is the responsibility of the caller. 
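+// Sample output for counts {4698, 24812}:
+//
+//   \data\
+//   ngram 1=4698
+//   ngram 2=24812
+//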
+void WriteCounts(std::ostream &out, const std::vector<uint64_t> &number) { +  out << "\n\\data\\\n"; +  for (unsigned int i = 0; i < number.size(); ++i) { +    out << "ngram " << i+1 << "=" << number[i] << '\n'; +  } +  out << '\n'; +} + +size_t SizeNeededForCounts(const std::vector<uint64_t> &number) { +  std::ostringstream buf; +  WriteCounts(buf, number); +  return buf.tellp(); +} + +bool IsEntirelyWhiteSpace(const StringPiece &line) { +  for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) { +    if (!isspace(line.data()[i])) return false; +  } +  return true; +} + +ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) : file_name_(name), buffer_(new char[buffer_size]) { +  try { +    file_.exceptions(std::ostream::eofbit | std::ostream::failbit | std::ostream::badbit); +    if (!file_.rdbuf()->pubsetbuf(buffer_.get(), buffer_size)) { +      std::cerr << "Warning: could not enlarge buffer for " << name << std::endl; +      buffer_.reset(); +    } +    file_.open(name, std::ios::out | std::ios::binary); +  } catch (const std::ios_base::failure &f) { +    throw ARPAOutputException("Opening", file_name_); +  } +} + +void ARPAOutput::ReserveForCounts(std::streampos reserve) { +  try { +    for (std::streampos i = 0; i < reserve; i += std::streampos(1)) { +      file_ << '\n'; +    } +  } catch (const std::ios_base::failure &f) { +    throw ARPAOutputException("Writing blanks to reserve space for counts to ", file_name_); +  } +} + +void ARPAOutput::BeginLength(unsigned int length) { +  fast_counter_ = 0; +  try { +    file_ << '\\' << length << "-grams:" << '\n'; +  } catch (const std::ios_base::failure &f) { +    throw ARPAOutputException("Writing n-gram header to ", file_name_); +  } +} + +void ARPAOutput::EndLength(unsigned int length) { +  try { +    file_ << '\n'; +  } catch (const std::ios_base::failure &f) { +    throw ARPAOutputException("Writing blank at end of count list to ", file_name_); +  } +  if (length > counts_.size()) { +    counts_.resize(length); +  } +  counts_[length - 1] = fast_counter_; +} + +void ARPAOutput::Finish() { +  try { +    file_ << "\\end\\\n"; +    file_.seekp(0); +    WriteCounts(file_, counts_); +    file_ << std::flush; +  } catch (const std::ios_base::failure &f) { +    throw ARPAOutputException("Finishing including writing counts at beginning to ", file_name_); +  } +} + +} // namespace lm diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh new file mode 100644 index 00000000..5b31620b --- /dev/null +++ b/klm/lm/filter/arpa_io.hh @@ -0,0 +1,115 @@ +#ifndef LM_FILTER_ARPA_IO__ +#define LM_FILTER_ARPA_IO__ +/* Input and output for ARPA format language model files. 
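+ * An ARPA file lists per-order n-gram counts in a \data\ section, then one
+ * \N-grams: section per order whose lines look like
+ *   log10(probability) <tab> n-gram [<tab> log10(backoff)]
+ * followed by a final \end\ marker.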
+ */ +#include "lm/read_arpa.hh" +#include "util/exception.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include <boost/noncopyable.hpp> +#include <boost/scoped_array.hpp> + +#include <fstream> +#include <string> +#include <vector> + +#include <err.h> +#include <string.h> +#include <stdint.h> + +namespace util { class FilePiece; } + +namespace lm { + +class ARPAInputException : public util::Exception { +  public: +    explicit ARPAInputException(const StringPiece &message) throw(); +    explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw(); +    virtual ~ARPAInputException() throw(); +}; + +class ARPAOutputException : public util::ErrnoException { +  public: +    ARPAOutputException(const char *prefix, const std::string &file_name) throw(); +    virtual ~ARPAOutputException() throw(); + +    const std::string &File() const throw() { return file_name_; } + +  private: +    const std::string file_name_; +}; + +// Handling for the counts of n-grams at the beginning of ARPA files. +size_t SizeNeededForCounts(const std::vector<uint64_t> &number); + +/* Writes an ARPA file.  This has to be seekable so the counts can be written + * at the end.  Hence, I just have it own a std::fstream instead of accepting + * a separately held std::ostream.  TODO: use the fast one from estimation. + */ +class ARPAOutput : boost::noncopyable { +  public: +    explicit ARPAOutput(const char *name, size_t buffer_size = 65536); + +    void ReserveForCounts(std::streampos reserve); + +    void BeginLength(unsigned int length); + +    void AddNGram(const StringPiece &line) { +      try { +        file_ << line << '\n'; +      } catch (const std::ios_base::failure &f) { +        throw ARPAOutputException("Writing an n-gram", file_name_); +      } +      ++fast_counter_; +    } + +    void AddNGram(const StringPiece &ngram, const StringPiece &line) { +      AddNGram(line); +    } + +    template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { +      AddNGram(line); +    } + +    void EndLength(unsigned int length); + +    void Finish(); + +  private: +    const std::string file_name_; +    boost::scoped_array<char> buffer_; +    std::fstream file_; +    size_t fast_counter_; +    std::vector<uint64_t> counts_; +}; + + +template <class Output> void ReadNGrams(util::FilePiece &in, unsigned int length, uint64_t number, Output &out) { +  ReadNGramHeader(in, length); +  out.BeginLength(length); +  for (uint64_t i = 0; i < number; ++i) { +    StringPiece line = in.ReadLine(); +    util::TokenIter<util::SingleCharacter> tabber(line, '\t'); +    if (!tabber) throw ARPAInputException("blank line", line); +    if (!++tabber) throw ARPAInputException("no tab", line); + +    out.AddNGram(*tabber, line); +  } +  out.EndLength(length); +} + +template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) { +  std::vector<uint64_t> number; +  ReadARPACounts(in_lm, number); +  out.ReserveForCounts(SizeNeededForCounts(number)); +  for (unsigned int i = 0; i < number.size(); ++i) { +    ReadNGrams(in_lm, i + 1, number[i], out); +  } +  ReadEnd(in_lm); +  out.Finish(); +} + +} // namespace lm + +#endif // LM_FILTER_ARPA_IO__ diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh new file mode 100644 index 00000000..97c0fa25 --- /dev/null +++ b/klm/lm/filter/count_io.hh @@ -0,0 +1,91 @@ +#ifndef LM_FILTER_COUNT_IO__ +#define LM_FILTER_COUNT_IO__ + +#include <fstream> +#include <iostream> +#include 
<string>
+#include <vector>
+
+#include <err.h>
+
+#include <boost/noncopyable.hpp>
+
+#include "util/file_piece.hh"
+
+namespace lm {
+
+class CountOutput : boost::noncopyable {
+  public:
+    explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
+
+    void AddNGram(const StringPiece &line) {
+      if (!(file_ << line << '\n')) {
+        err(3, "Writing counts file failed");
+      }
+    }
+
+    template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+      AddNGram(line);
+    }
+
+    void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+      AddNGram(line);
+    }
+
+  private:
+    std::fstream file_;
+};
+
+class CountBatch {
+  public:
+    explicit CountBatch(std::streamsize initial_read) 
+      : initial_read_(initial_read) {
+      buffer_.reserve(initial_read);
+    }
+
+    void Read(std::istream &in) {
+      buffer_.resize(initial_read_);
+      in.read(&*buffer_.begin(), initial_read_);
+      buffer_.resize(in.gcount());
+      char got;
+      while (in.get(got) && got != '\n')
+        buffer_.push_back(got);
+    }
+
+    template <class Output> void Send(Output &out) {
+      for (util::TokenIter<util::SingleCharacter> line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) {
+        util::TokenIter<util::SingleCharacter> tabber(*line, '\t');
+        if (!tabber) {
+          std::cerr << "Warning: empty n-gram count line being removed\n";
+          continue;
+        }
+        util::TokenIter<util::SingleCharacter, true> words(*tabber, ' ');
+        if (!words) {
+          std::cerr << "Line has a tab but no words.\n";
+          continue;
+        }
+        out.AddNGram(words, util::TokenIter<util::SingleCharacter, true>::end(), *line);
+      }
+    }
+
+  private:
+    std::streamsize initial_read_;
+
+    // This could have been a std::string but that's less happy with raw writes.
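+    // Holds one batch: initial_read_ bytes extended through the next newline
+    // so no count line is split across batches.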
+    std::vector<char> buffer_;
+};
+
+template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) {
+  try {
+    while (true) {
+      StringPiece line = in_file.ReadLine();
+      util::TokenIter<util::SingleCharacter> tabber(line, '\t');
+      if (!tabber) {
+        std::cerr << "Warning: empty n-gram count line being removed\n";
+        continue;
+      }
+      out.AddNGram(*tabber, line);
+    }
+  } catch (const util::EndOfFileException &e) {}
+}
+
+} // namespace lm
+
+#endif // LM_FILTER_COUNT_IO__
diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh
new file mode 100644
index 00000000..7f945b0d
--- /dev/null
+++ b/klm/lm/filter/format.hh
@@ -0,0 +1,250 @@
+#ifndef LM_FILTER_FORMAT_H__
+#define LM_FILTER_FORMAT_H__
+
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/count_io.hh"
+
+#include <boost/lexical_cast.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <iosfwd>
+
+namespace lm {
+
+template <class Single> class MultipleOutput {
+  private:
+    typedef boost::ptr_vector<Single> Singles;
+    typedef typename Singles::iterator SinglesIterator;
+
+  public:
+    MultipleOutput(const char *prefix, size_t number) {
+      files_.reserve(number);
+      std::string tmp;
+      for (unsigned int i = 0; i < number; ++i) {
+        tmp = prefix;
+        tmp += boost::lexical_cast<std::string>(i);
+        files_.push_back(new Single(tmp.c_str()));
+      }
+    }
+
+    void AddNGram(const StringPiece &line) {
+      for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+        i->AddNGram(line);
+    }
+
+    template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+      for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+        i->AddNGram(begin, end, line);
+    }
+
+    void SingleAddNGram(size_t offset, const StringPiece &line) {
+      files_[offset].AddNGram(line);
+    }
+
+    template <class Iterator> void SingleAddNGram(size_t offset, const Iterator &begin, const Iterator &end, const StringPiece &line) {
+      files_[offset].AddNGram(begin, end, line);
+    }
+
+  protected:
+    Singles files_;
+};
+
+class MultipleARPAOutput : public MultipleOutput<ARPAOutput> {
+  public:
+    MultipleARPAOutput(const char *prefix, size_t number) : MultipleOutput<ARPAOutput>(prefix, number) {}
+
+    void ReserveForCounts(std::streampos reserve) {
+      for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+        i->ReserveForCounts(reserve);
+    }
+
+    void BeginLength(unsigned int length) {
+      for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+        i->BeginLength(length);
+    }
+
+    void EndLength(unsigned int length) {
+      for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+        i->EndLength(length);
+    }
+
+    void Finish() {
+      for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+        i->Finish();
+    }
+};
+
+template <class Filter, class Output> class DispatchInput {
+  public:
+    DispatchInput(Filter &filter, Output &output) : filter_(filter), output_(output) {}
+
+/*    template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+      filter_.AddNGram(begin, end, line, output_);
+    }*/
+
+    void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+      filter_.AddNGram(ngram, line, output_);
+    } 
+ +  protected: +    Filter &filter_; +    Output &output_; +}; + +template <class Filter, class Output> class DispatchARPAInput : public DispatchInput<Filter, Output> { +  private: +    typedef DispatchInput<Filter, Output> B; + +  public: +    DispatchARPAInput(Filter &filter, Output &output) : B(filter, output) {} + +    void ReserveForCounts(std::streampos reserve) { B::output_.ReserveForCounts(reserve); } +    void BeginLength(unsigned int length) { B::output_.BeginLength(length); } + +    void EndLength(unsigned int length) { +      B::filter_.Flush(); +      B::output_.EndLength(length); +    } +    void Finish() { B::output_.Finish(); } +}; + +struct ARPAFormat { +  typedef ARPAOutput Output; +  typedef MultipleARPAOutput Multiple; +  static void Copy(util::FilePiece &in, Output &out) { +    ReadARPA(in, out); +  } +  template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) { +    DispatchARPAInput<Filter, Out> dispatcher(filter, output); +    ReadARPA(in, dispatcher); +  } +}; + +struct CountFormat { +  typedef CountOutput Output; +  typedef MultipleOutput<Output> Multiple; +  static void Copy(util::FilePiece &in, Output &out) { +    ReadCount(in, out); +  } +  template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) { +    DispatchInput<Filter, Out> dispatcher(filter, output); +    ReadCount(in, dispatcher); +  } +}; + +/* For multithreading, the buffer classes hold batches of filter inputs and + * outputs in memory.  The strings get reused a lot, so keep them around + * instead of clearing each time.   + */ +class InputBuffer { +  public: +    InputBuffer() : actual_(0) {} +     +    void Reserve(size_t size) { lines_.reserve(size); } + +    template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { +      if (lines_.size() == actual_) lines_.resize(lines_.size() + 1); +      // TODO avoid this copy. 
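+      // The incoming pieces typically point into a transient read buffer, so
+      // copy the line and re-base the ngram piece onto the copy.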
+      std::string &copied = lines_[actual_].line; +      copied.assign(line.data(), line.size()); +      lines_[actual_].ngram.set(copied.data() + (ngram.data() - line.data()), ngram.size()); +      ++actual_; +    } + +    template <class Filter, class Output> void CallFilter(Filter &filter, Output &output) const { +      for (std::vector<Line>::const_iterator i = lines_.begin(); i != lines_.begin() + actual_; ++i) { +        filter.AddNGram(i->ngram, i->line, output); +      } +    } + +    void Clear() { actual_ = 0; } +    bool Empty() { return actual_ == 0; } +    size_t Size() { return actual_; } + +  private: +    struct Line { +      std::string line; +      StringPiece ngram; +    }; + +    size_t actual_; + +    std::vector<Line> lines_; +}; + +class BinaryOutputBuffer { +  public: +    BinaryOutputBuffer() {} + +    void Reserve(size_t size) { +      lines_.reserve(size); +    } +     +    void AddNGram(const StringPiece &line) { +      lines_.push_back(line); +    } +     +    template <class Output> void Flush(Output &output) { +      for (std::vector<StringPiece>::const_iterator i = lines_.begin(); i != lines_.end(); ++i) { +        output.AddNGram(*i); +      } +      lines_.clear(); +    } +     +  private: +    std::vector<StringPiece> lines_; +}; + +class MultipleOutputBuffer { +  public: +    MultipleOutputBuffer() : last_(NULL) {} + +    void Reserve(size_t size) { +      annotated_.reserve(size); +    } + +    void AddNGram(const StringPiece &line) { +      annotated_.resize(annotated_.size() + 1); +      annotated_.back().line = line; +    } + +    void SingleAddNGram(size_t offset, const StringPiece &line) { +      if ((line.data() == last_.data()) && (line.length() == last_.length())) { +        annotated_.back().systems.push_back(offset); +      } else { +        annotated_.resize(annotated_.size() + 1); +        annotated_.back().systems.push_back(offset); +        annotated_.back().line = line; +        last_ = line; +      } +    } + +    template <class Output> void Flush(Output &output) { +      for (std::vector<Annotated>::const_iterator i = annotated_.begin(); i != annotated_.end(); ++i) { +        if (i->systems.empty()) { +          output.AddNGram(i->line); +        } else { +          for (std::vector<size_t>::const_iterator j = i->systems.begin(); j != i->systems.end(); ++j) { +            output.SingleAddNGram(*j, i->line); +          } +        } +      } +      annotated_.clear(); +    } + +  private: +    struct Annotated { +      // If this is empty, send to all systems.  +      // A filter should never send to all systems and send to a single one. 
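+      // (i.e. for any given line, broadcast and per-system output are
+      // mutually exclusive.)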
+      std::vector<size_t> systems;
+      StringPiece line;
+    };
+
+    StringPiece last_;
+
+    std::vector<Annotated> annotated_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_FORMAT_H__
diff --git a/klm/lm/filter/main.cc b/klm/lm/filter/main.cc
new file mode 100644
index 00000000..c42243e2
--- /dev/null
+++ b/klm/lm/filter/main.cc
@@ -0,0 +1,249 @@
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/format.hh"
+#include "lm/filter/phrase.hh"
+#ifndef NTHREAD
+#include "lm/filter/thread.hh"
+#endif
+#include "lm/filter/vocab.hh"
+#include "lm/filter/wrapper.hh"
+#include "util/file_piece.hh"
+
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+
+namespace lm {
+namespace {
+
+void DisplayHelp(const char *name) {
+  std::cerr
+    << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n"
+    "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n"
+    "    parser.\n"
+    "single mode treats the entire input as a single sentence.\n"
+    "multiple mode filters to multiple sentences in parallel.  Each sentence is on\n"
+    "    a separate line.  A separate file is created for each sentence by appending\n"
+    "    the 0-indexed line number to the output file name.\n"
+    "union mode produces one filtered model that is the union of models created by\n"
+    "    multiple mode.\n\n"
+    "context means only the context (all but last word) has to pass the filter, but\n"
+    "    the entire n-gram is output.\n\n"
+    "phrase means that the vocabulary is actually tab-delimited phrases and that the\n"
+    "    phrases can generate the n-gram when assembled in arbitrary order and\n"
+    "    clipped.  Currently works with multiple or union mode.\n\n"
+    "The file format is set by [raw|arpa] with default arpa:\n"
+    "raw means space-separated tokens, optionally followed by a tab and arbitrary\n"
+    "    text.  This is useful for ngram count files.\n"
+    "arpa means the ARPA file format for n-gram language models.\n\n"
+#ifndef NTHREAD
+    "threads:m sets m threads (default: concurrency detected by boost)\n"
+    "batch_size:m sets the batch size for threading.  Expect memory usage from this\n"
+    "    of 2*threads*batch_size n-grams.\n\n"
+#else
+    "This binary was compiled with -DNTHREAD, disabling threading.  If you wanted\n"
+    "    threading, compile without this flag against Boost >=1.42.0.\n\n"
+#endif
+    "There are two inputs: vocabulary and model.  Either may be given as a file\n"
+    "    while the other is on stdin.  Specify the type given as a file using\n"
+    "    vocab: or model: before the file name.  \n\n"
+    "For ARPA format, the output must be seekable.  For raw format, it can be a\n"
+    "    stream i.e. 
/dev/stdout\n"; +} + +typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION} FilterMode; +typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format; + +struct Config { +  Config() :  +#ifndef NTHREAD +  batch_size(25000), +  threads(boost::thread::hardware_concurrency()), +#endif +  phrase(false), +  context(false), +  format(FORMAT_ARPA) +  { +#ifndef NTHREAD +    if (!threads) threads = 1; +#endif +  } + +#ifndef NTHREAD +  size_t batch_size; +  size_t threads; +#endif +  bool phrase; +  bool context; +  FilterMode mode; +  Format format; +}; + +template <class Format, class Filter, class OutputBuffer, class Output> void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) { +#ifndef NTHREAD +  if (config.threads == 1) { +#endif +    Format::RunFilter(in_lm, filter, output); +#ifndef NTHREAD +  } else { +    typedef Controller<Filter, OutputBuffer, Output> Threaded; +    Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output); +    Format::RunFilter(in_lm, threading, output); +  } +#endif +} + +template <class Format, class Filter, class OutputBuffer, class Output> void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) { +  if (config.context) { +    ContextFilter<Filter> context_filter(filter); +    RunThreadedFilter<Format, ContextFilter<Filter>, OutputBuffer, Output>(config, in_lm, context_filter, output); +  } else { +    RunThreadedFilter<Format, Filter, OutputBuffer, Output>(config, in_lm, filter, output); +  } +} + +template <class Format, class Binary> void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) { +  typedef BinaryFilter<Binary> Filter; +  RunContextFilter<Format, Filter, BinaryOutputBuffer, typename Format::Output>(config, in_lm, Filter(binary), out); +} + +template <class Format> void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) { +  if (config.mode == MODE_MULTIPLE) { +    if (config.phrase) { +      typedef phrase::Multiple Filter; +      phrase::Substrings substrings; +      typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings)); +      RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(substrings), out); +    } else { +      typedef vocab::Multiple Filter; +      boost::unordered_map<std::string, std::vector<unsigned int> > words; +      typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words)); +      RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(words), out); +    } +    return; +  } + +  typename Format::Output out(out_name); + +  if (config.mode == MODE_COPY) { +    Format::Copy(in_lm, out); +    return; +  } + +  if (config.mode == MODE_SINGLE) { +    vocab::Single::Words words; +    vocab::ReadSingle(in_vocab, words); +    DispatchBinaryFilter<Format, vocab::Single>(config, in_lm, vocab::Single(words), out); +    return; +  } + +  if (config.mode == MODE_UNION) { +    if (config.phrase) { +      phrase::Substrings substrings; +      phrase::ReadMultiple(in_vocab, substrings); +      DispatchBinaryFilter<Format, phrase::Union>(config, in_lm, phrase::Union(substrings), out); +    } else { +      vocab::Union::Words words; +      vocab::ReadMultiple(in_vocab, words); +      DispatchBinaryFilter<Format, vocab::Union>(config, in_lm, vocab::Union(words), 
out); +    } +    return; +  } +} + +} // namespace +} // namespace lm + +int main(int argc, char *argv[]) { +  if (argc < 4) { +    lm::DisplayHelp(argv[0]); +    return 1; +  } + +  // I used to have boost::program_options, but some users didn't want to compile boost.   +  lm::Config config; +  boost::optional<lm::FilterMode> mode; +  for (int i = 1; i < argc - 2; ++i) { +    const char *str = argv[i]; +    if (!std::strcmp(str, "copy")) { +      mode = lm::MODE_COPY; +    } else if (!std::strcmp(str, "single")) { +      mode = lm::MODE_SINGLE; +    } else if (!std::strcmp(str, "multiple")) { +      mode = lm::MODE_MULTIPLE; +    } else if (!std::strcmp(str, "union")) { +      mode = lm::MODE_UNION; +    } else if (!std::strcmp(str, "phrase")) { +      config.phrase = true; +    } else if (!std::strcmp(str, "context")) { +      config.context = true; +    } else if (!std::strcmp(str, "arpa")) { +      config.format = lm::FORMAT_ARPA; +    } else if (!std::strcmp(str, "raw")) { +      config.format = lm::FORMAT_COUNT; +#ifndef NTHREAD +    } else if (!std::strncmp(str, "threads:", 8)) { +      config.threads = boost::lexical_cast<size_t>(str + 8); +      if (!config.threads) { +        std::cerr << "Specify at least one thread." << std::endl; +        return 1; +      } +    } else if (!std::strncmp(str, "batch_size:", 11)) { +      config.batch_size = boost::lexical_cast<size_t>(str + 11); +      if (config.batch_size < 5000) { +        std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; +        if (!config.batch_size) return 1; +      } +#endif +    } else { +      lm::DisplayHelp(argv[0]); +      return 1; +    } +  } +   +  if (!mode) { +    lm::DisplayHelp(argv[0]); +    return 1; +  } +  config.mode = *mode; + +  if (config.phrase && config.mode != lm::MODE_UNION && mode != lm::MODE_MULTIPLE) { +    std::cerr << "Phrase constraint currently only works in multiple or union mode.  If you really need it for single, put everything on one line and use union." << std::endl; +    return 1; +  } + +  bool cmd_is_model = true; +  const char *cmd_input = argv[argc - 2]; +  if (!strncmp(cmd_input, "vocab:", 6)) { +    cmd_is_model = false; +    cmd_input += 6; +  } else if (!strncmp(cmd_input, "model:", 6)) { +    cmd_input += 6; +  } else if (strchr(cmd_input, ':')) { +    errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); +  } else { +    std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; +  } +  std::ifstream cmd_file; +  std::istream *vocab; +  if (cmd_is_model) { +    vocab = &std::cin; +  } else { +    cmd_file.open(cmd_input, std::ios::in); +    if (!cmd_file) { +      err(2, "Could not open input file %s", cmd_input); +    } +    vocab = &cmd_file; +  } + +  util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? 
cmd_input : NULL, &std::cerr); + +  if (config.format == lm::FORMAT_ARPA) { +    lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); +  } else if (config.format == lm::FORMAT_COUNT) { +    lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); +  } +  return 0; +} diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc new file mode 100644 index 00000000..1bef2a3f --- /dev/null +++ b/klm/lm/filter/phrase.cc @@ -0,0 +1,281 @@ +#include "lm/filter/phrase.hh" + +#include "lm/filter/format.hh" + +#include <algorithm> +#include <functional> +#include <iostream> +#include <queue> +#include <string> +#include <vector> + +#include <ctype.h> + +namespace lm { +namespace phrase { + +unsigned int ReadMultiple(std::istream &in, Substrings &out) { +  bool sentence_content = false; +  unsigned int sentence_id = 0; +  std::vector<Hash> phrase; +  std::string word; +  while (in) { +    char c; +    // Gather a word. +    while (!isspace(c = in.get()) && in) word += c; +    // Treat EOF like a newline. +    if (!in) c = '\n'; +    // Add the word to the phrase. +    if (!word.empty()) { +      phrase.push_back(util::MurmurHashNative(word.data(), word.size())); +      word.clear(); +    } +    if (c == ' ') continue; +    // It's more than just a space.  Close out the phrase. +    if (!phrase.empty()) { +      sentence_content = true; +      out.AddPhrase(sentence_id, phrase.begin(), phrase.end()); +      phrase.clear(); +    } +    if (c == '\t' || c == '\v') continue; +    // It's more than a space or tab: a newline. +    if (sentence_content) { +      ++sentence_id; +      sentence_content = false; +    } +  } +  if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit); +  return sentence_id + sentence_content; +} + +namespace detail { const StringPiece kEndSentence("</s>"); } + +namespace { + +typedef unsigned int Sentence; +typedef std::vector<Sentence> Sentences; + +class Vertex; + +class Arc { +  public: +    Arc() {} + +    // For arcs from one vertex to another. +    void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) { +      Set(to, intersect); +      from_ = &from; +    } + +    /* For arcs from before the n-gram begins to somewhere in the n-gram (right +     * aligned).  These have no from_ vertex; it implicitly matches every +     * sentence.  This also handles when the n-gram is a substring of a phrase. +     */ +    void SetRight(Vertex &to, const Sentences &complete) { +      Set(to, complete); +      from_ = NULL; +    } + +    Sentence Current() const { +      return *current_; +    } + +    bool Empty() const { +      return current_ == last_; +    } + +    /* When this function returns: +     * If Empty() then there's nothing left from this intersection. +     * +     * If Current() == to then to is part of the intersection. +     * +     * Otherwise, Current() > to.  In this case, to is not part of the +     * intersection and neither is anything < Current().  To determine if +     * any value >= Current() is in the intersection, call LowerBound again +     * with the value.
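+     * Example (illustrative, following the contract above): if this
+     * intersection holds {2, 5, 9}, then LowerBound(3) leaves Current() == 5,
+     * and a subsequent LowerBound(5) confirms that 5 is a member.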
+     */ +    void LowerBound(const Sentence to); + +  private: +    void Set(Vertex &to, const Sentences &sentences); + +    const Sentence *current_; +    const Sentence *last_; +    Vertex *from_; +}; + +struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> { +  bool operator()(const Arc *first, const Arc *second) const { +    return first->Current() > second->Current(); +  } +}; + +class Vertex { +  public: +    Vertex() : current_(0) {} + +    Sentence Current() const { +      return current_; +    } + +    bool Empty() const { +      return incoming_.empty(); +    } + +    void LowerBound(const Sentence to); + +  private: +    friend class Arc; + +    void AddIncoming(Arc *arc) { +      if (!arc->Empty()) incoming_.push(arc); +    } + +    unsigned int current_; +    std::priority_queue<Arc*, std::vector<Arc*>, ArcGreater> incoming_; +}; + +void Arc::LowerBound(const Sentence to) { +  current_ = std::lower_bound(current_, last_, to); +  // If *current_ > to, don't advance from_.  The intervening values of +  // from_ may be useful for another one of its outgoing arcs. +  if (!from_ || Empty() || (Current() > to)) return; +  assert(Current() == to); +  from_->LowerBound(to); +  if (from_->Empty()) { +    current_ = last_; +    return; +  } +  assert(from_->Current() >= to); +  if (from_->Current() > to) { +    current_ = std::lower_bound(current_ + 1, last_, from_->Current()); +  } +} + +void Arc::Set(Vertex &to, const Sentences &sentences) { +  current_ = &*sentences.begin(); +  last_ = &*sentences.end(); +  to.AddIncoming(this); +} + +void Vertex::LowerBound(const Sentence to) { +  if (Empty()) return; +  // Union lower bound.   +  while (true) { +    Arc *top = incoming_.top(); +    if (top->Current() > to) { +      current_ = top->Current(); +      return; +    } +    // If top->Current() == to, we still need to verify that's an actual  +    // element and not just a bound.   +    incoming_.pop(); +    top->LowerBound(to); +    if (!top->Empty()) { +      incoming_.push(top); +      if (top->Current() == to) { +        current_ = to; +        return; +      } +    } else if (Empty()) { +      return; +    } +  } +} + +void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) { +  assert(!hashes.empty()); + +  const Hash *const first_word = &*hashes.begin(); +  const Hash *const last_word = &*hashes.end() - 1; + +  Hash hash = 0; +  const Sentences *found; +  // Phrases starting at or before the first word in the n-gram. +  { +    Vertex *vertex = vertices; +    for (const Hash *word = first_word; ; ++word, ++vertex) { +      hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word); +      // Now hash is [hashes.begin(), word]. +      if (word == last_word) { +        if (phrase.FindSubstring(hash, found)) +          (free_arc++)->SetRight(*vertex, *found); +        break; +      } +      if (!phrase.FindRight(hash, found)) break; +      (free_arc++)->SetRight(*vertex, *found); +    } +  } + +  // Phrases starting at the second or later word in the n-gram.    +  Vertex *vertex_from = vertices; +  for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) { +    hash = 0; +    Vertex *vertex_to = vertex_from + 1; +    for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) { +      // Notice that word_to and vertex_to have the same index.   
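+      // e.g. for the trigram "a b c" with word_from at "b": "b" is probed with
+      // FindPhrase, then "b c" reaches the last word and is probed with FindLeft.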
+      hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to); +      // Now hash covers [word_from, word_to]. +      if (word_to == last_word) { +        if (phrase.FindLeft(hash, found)) +          (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); +        break; +      } +      if (!phrase.FindPhrase(hash, found)) break; +      (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); +    } +  } +} + +} // namespace + +namespace detail { + +} // namespace detail + +bool Union::Evaluate() { +  assert(!hashes_.empty()); +  // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. +  Vertex vertices[hashes_.size()]; +  // One for every substring. +  Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; +  BuildGraph(substrings_, hashes_, vertices, arcs); +  Vertex &last_vertex = vertices[hashes_.size() - 1]; + +  unsigned int lower = 0; +  while (true) { +    last_vertex.LowerBound(lower); +    if (last_vertex.Empty()) return false; +    if (last_vertex.Current() == lower) return true; +    lower = last_vertex.Current(); +  } +} + +template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) { +  assert(!hashes_.empty()); +  // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. +  Vertex vertices[hashes_.size()]; +  // One for every substring. +  Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; +  BuildGraph(substrings_, hashes_, vertices, arcs); +  Vertex &last_vertex = vertices[hashes_.size() - 1]; + +  unsigned int lower = 0; +  while (true) { +    last_vertex.LowerBound(lower); +    if (last_vertex.Empty()) return; +    if (last_vertex.Current() == lower) { +      output.SingleAddNGram(lower, line); +      ++lower; +    } else { +      lower = last_vertex.Current(); +    } +  } +} + +template void Multiple::Evaluate<CountFormat::Multiple>(const StringPiece &line, CountFormat::Multiple &output); +template void Multiple::Evaluate<ARPAFormat::Multiple>(const StringPiece &line, ARPAFormat::Multiple &output); +template void Multiple::Evaluate<MultipleOutputBuffer>(const StringPiece &line, MultipleOutputBuffer &output); + +} // namespace phrase +} // namespace lm diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh new file mode 100644 index 00000000..07479dea --- /dev/null +++ b/klm/lm/filter/phrase.hh @@ -0,0 +1,153 @@ +#ifndef LM_FILTER_PHRASE_H__ +#define LM_FILTER_PHRASE_H__ + +#include "util/murmur_hash.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include <boost/unordered_map.hpp> + +#include <iosfwd> +#include <vector> + +#define LM_FILTER_PHRASE_METHOD(caps, lower) \ +bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\ +  Table::const_iterator i(table_.find(key));\ +  if (i==table_.end()) return false; \ +  out = &i->second.lower; \ +  return true; \ +} + +namespace lm { +namespace phrase { + +typedef uint64_t Hash; + +class Substrings { +  private: +    /* This is the value in a hash table where the key is a string.  It indicates +     * four sets of sentences: +     * substring is sentences with a phrase containing the key as a substring. +     * left is sentences with a phrase that begins with the key (left aligned). +     * right is sentences with a phrase that ends with the key (right aligned). +     * phrase is sentences where the key is a phrase. +     * Each set is encoded as a vector of sentence ids in increasing order.
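+     * Example: if sentence 3 contains the phrase "in the", the entry keyed on
+     * the hash of "in the" lists 3 under substring, left, right, and phrase,
+     * while the entry for "in" alone lists 3 under substring and left only.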
+     */ +    struct SentenceRelation { +      std::vector<unsigned int> substring, left, right, phrase; +    }; +    /* Most of the CPU is hash table lookups, so let's not complicate it with +     * vector equality comparisons.  If a collision happens, the SentenceRelation +     * structure will contain the union of sentence ids over the colliding strings. +     * In that case, the filter will be slightly more permissive.   +     * The key here is the same as boost's hash of std::vector<std::string>.   +     */ +    typedef boost::unordered_map<Hash, SentenceRelation> Table; + +  public: +    Substrings() {} + +    /* If the string isn't a substring of any phrase, return NULL.  Otherwise, +     * return a pointer to std::vector<unsigned int> listing sentences with +     * matching phrases.  This set may be empty for Left, Right, or Phrase. +     * Example: const std::vector<unsigned int> *FindSubstring(Hash key) +     */ +    LM_FILTER_PHRASE_METHOD(Substring, substring) +    LM_FILTER_PHRASE_METHOD(Left, left) +    LM_FILTER_PHRASE_METHOD(Right, right) +    LM_FILTER_PHRASE_METHOD(Phrase, phrase) + +    // sentence_id must be non-decreasing.  Iterators are over words in the phrase.   +    template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) { +      // Iterate over all substrings.   +      for (Iterator start = begin; start != end; ++start) { +        Hash hash = 0; +        SentenceRelation *relation; +        for (Iterator finish = start; finish != end; ++finish) { +          hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish); +          // Now hash is of [start, finish]. +          relation = &table_[hash]; +          AppendSentence(relation->substring, sentence_id); +          if (start == begin) AppendSentence(relation->left, sentence_id); +        } +        AppendSentence(relation->right, sentence_id); +        if (start == begin) AppendSentence(relation->phrase, sentence_id); +      } +    } + +  private: +    void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) { +      if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id); +    } + +    Table table_; +}; + +// Read a file with one sentence per line containing tab-delimited phrases of +// space-separated words.   +unsigned int ReadMultiple(std::istream &in, Substrings &out); + +namespace detail { +extern const StringPiece kEndSentence; + +template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) { +  hashes.clear(); +  if (i == end) return; +  // TODO: check strict phrase boundaries after <s> and before </s>.  For now, just skip tags.   
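+  // Skip a leading tag such as <s>; the loop below stops at </s>.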
+  if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) { +    ++i; +  } +  for (; i != end && (*i != kEndSentence); ++i) { +    hashes.push_back(util::MurmurHashNative(i->data(), i->size())); +  } +} + +} // namespace detail + +class Union { +  public: +    explicit Union(const Substrings &substrings) : substrings_(substrings) {} + +    template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) { +      detail::MakeHashes(begin, end, hashes_); +      return hashes_.empty() || Evaluate(); +    } + +  private: +    bool Evaluate(); + +    std::vector<Hash> hashes_; + +    const Substrings &substrings_; +}; + +class Multiple { +  public: +    explicit Multiple(const Substrings &substrings) : substrings_(substrings) {} + +    template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { +      detail::MakeHashes(begin, end, hashes_); +      if (hashes_.empty()) { +        output.AddNGram(line); +        return; +      } +      Evaluate(line, output); +    } + +    template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { +      AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output); +    } + +    void Flush() const {} + +  private: +    template <class Output> void Evaluate(const StringPiece &line, Output &output); + +    std::vector<Hash> hashes_; + +    const Substrings &substrings_; +}; + +} // namespace phrase +} // namespace lm +#endif // LM_FILTER_PHRASE_H__ diff --git a/klm/lm/filter/thread.hh b/klm/lm/filter/thread.hh new file mode 100644 index 00000000..e785b263 --- /dev/null +++ b/klm/lm/filter/thread.hh @@ -0,0 +1,167 @@ +#ifndef LM_FILTER_THREAD_H__ +#define LM_FILTER_THREAD_H__ + +#include "util/thread_pool.hh" + +#include <boost/utility/in_place_factory.hpp> + +#include <deque> +#include <stack> + +namespace lm { + +template <class OutputBuffer> class ThreadBatch { +  public: +    ThreadBatch() {} +     +    void Reserve(size_t size) { +      input_.Reserve(size); +      output_.Reserve(size); +     } + +    // File reading thread.   +    InputBuffer &Fill(uint64_t sequence) { +      sequence_ = sequence; +      // Why wait until now to clear instead of after output?  free in the same +      // thread as allocated.   +      input_.Clear(); +      return input_; +    } + +    // Filter worker thread.   +    template <class Filter> void CallFilter(Filter &filter) { +      input_.CallFilter(filter, output_); +    } + +    uint64_t Sequence() const { return sequence_; } + +    // File writing thread.   +    template <class RealOutput> void Flush(RealOutput &output) { +      output_.Flush(output); +    } + +  private: +    InputBuffer input_; +    OutputBuffer output_; + +    uint64_t sequence_; +}; + +template <class Batch, class Filter> class FilterWorker { +  public: +    typedef Batch *Request; + +    FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {} + +    void operator()(Request request) { +      request->CallFilter(filter_); +      done_.Produce(request); +    } + +  private: +    Filter filter_; + +    util::PCQueue<Request> &done_; +}; + +// There should only be one OutputWorker. 
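+// It buffers batches that arrive out of order and flushes them strictly by
+// sequence number, so a single writer keeps the output deterministic.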
+template <class Batch, class Output> class OutputWorker { +  public: +    typedef Batch *Request; + +    OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {} + +    void operator()(Request request) { +      assert(request->Sequence() >= base_sequence_); +      // Assemble the output in order.   +      uint64_t pos = request->Sequence() - base_sequence_; +      if (pos >= ordering_.size()) { +        ordering_.resize(pos + 1, NULL); +      } +      ordering_[pos] = request; +      while (!ordering_.empty() && ordering_.front()) { +        ordering_.front()->Flush(output_); +        done_.Produce(ordering_.front()); +        ordering_.pop_front(); +        ++base_sequence_; +      } +    } + +  private: +    Output &output_; + +    util::PCQueue<Request> &done_; + +    std::deque<Request> ordering_; + +    uint64_t base_sequence_; +}; + +template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable { +  private: +    typedef ThreadBatch<OutputBuffer> Batch; + +  public: +    Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)  +      : batch_size_(batch_size), queue_size_(queue), +        batches_(queue), +        to_read_(queue), +        output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL), +        filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL), +        sequence_(0) { +      for (size_t i = 0; i < queue; ++i) { +        batches_[i].Reserve(batch_size); +        local_read_.push(&batches_[i]); +      } +      NewInput(); +    } + +    void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) { +      input_->AddNGram(ngram, line, output); +      if (input_->Size() == batch_size_) { +        FlushInput(); +        NewInput(); +      } +    } + +    void Flush() { +      FlushInput(); +      while (local_read_.size() < queue_size_) { +        MoveRead(); +      } +      NewInput(); +    } + +  private: +    void FlushInput() { +      if (input_->Empty()) return; +      filter_.Produce(local_read_.top()); +      local_read_.pop(); +      if (local_read_.empty()) MoveRead(); +    } + +    void NewInput() { +      input_ = &local_read_.top()->Fill(sequence_++); +    } + +    void MoveRead() { +      local_read_.push(to_read_.Consume()); +    } + +    const size_t batch_size_; +    const size_t queue_size_; + +    std::vector<Batch> batches_; + +    util::PCQueue<Batch*> to_read_; +    std::stack<Batch*> local_read_; +    util::ThreadPool<OutputWorker<Batch, RealOutput> > output_; +    util::ThreadPool<FilterWorker<Batch, Filter> > filter_; + +    uint64_t sequence_; +    InputBuffer *input_; +}; + +} // namespace lm + +#endif // LM_FILTER_THREAD_H__ diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc new file mode 100644 index 00000000..7ee4e84b --- /dev/null +++ b/klm/lm/filter/vocab.cc @@ -0,0 +1,54 @@ +#include "lm/filter/vocab.hh" + +#include <istream> +#include <iostream> + +#include <ctype.h> +#include <err.h> + +namespace lm { +namespace vocab { + +void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out) { +  in.exceptions(std::istream::badbit); +  std::string word; +  while (in >> word) { +    out.insert(word); +  } +} + +namespace { +bool IsLineEnd(std::istream &in) { +  int got; +  do { +    got = in.get(); +    if (!in) return true; +    if (got == '\n') return true; +  } while (isspace(got)); +  in.unget(); +  
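+  // The character read back is neither a space nor a newline: the line continues.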
return false; +} +} // namespace + +// Read space-separated words in newline-separated lines.  These lines can be +// very long, so don't read an entire line at a time. +unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out) { +  in.exceptions(std::istream::badbit); +  unsigned int sentence = 0; +  bool used_id = false; +  std::string word; +  while (in >> word) { +    used_id = true; +    std::vector<unsigned int> &posting = out[word]; +    if (posting.empty() || (posting.back() != sentence)) +      posting.push_back(sentence); +    if (IsLineEnd(in)) { +      ++sentence; +      used_id = false; +    } +  } +  return sentence + used_id; +} + +} // namespace vocab +} // namespace lm diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh new file mode 100644 index 00000000..e2b6adff --- /dev/null +++ b/klm/lm/filter/vocab.hh @@ -0,0 +1,132 @@ +#ifndef LM_FILTER_VOCAB_H__ +#define LM_FILTER_VOCAB_H__ + +// Vocabulary-based filters for language models. + +#include "util/multi_intersection.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include <boost/noncopyable.hpp> +#include <boost/range/iterator_range.hpp> +#include <boost/unordered/unordered_map.hpp> +#include <boost/unordered/unordered_set.hpp> + +#include <string> +#include <vector> + +namespace lm { +namespace vocab { + +void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out); + +// Read one sentence vocabulary per line.  Return the number of sentences. +unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out); + +/* Is this a special tag like <s> or <UNK>?  This actually matches anything + * surrounded with < and >.  Since most tokenizers split such tags away from + * real words, checking a single token this way should not catch real words. + */ +inline bool IsTag(const StringPiece &value) { +  // The parser should never give an empty string.
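+  // (An empty token would make the front/back indexing below undefined.)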
+  assert(!value.empty()); +  return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>'); +} + +class Single { +  public: +    typedef boost::unordered_set<std::string> Words; + +    explicit Single(const Words &vocab) : vocab_(vocab) {} + +    template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) { +      for (Iterator i = begin; i != end; ++i) { +        if (IsTag(*i)) continue; +        if (FindStringPiece(vocab_, *i) == vocab_.end()) return false; +      } +      return true; +    } + +  private: +    const Words &vocab_; +}; + +class Union { +  public: +    typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words; + +    explicit Union(const Words &vocabs) : vocabs_(vocabs) {} + +    template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) { +      sets_.clear(); + +      for (Iterator i(begin); i != end; ++i) { +        if (IsTag(*i)) continue; +        Words::const_iterator found(FindStringPiece(vocabs_, *i)); +        if (vocabs_.end() == found) return false; +        sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end())); +      } +      return (sets_.empty() || util::FirstIntersection(sets_)); +    } + +  private: +    const Words &vocabs_; + +    std::vector<boost::iterator_range<const unsigned int*> > sets_; +}; + +class Multiple { +  public: +    typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words; + +    Multiple(const Words &vocabs) : vocabs_(vocabs) {} + +  private: +    // Callback from AllIntersection that does AddNGram. +    template <class Output> class Callback { +      public: +        Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {} + +        void operator()(unsigned int index) { +          out_.SingleAddNGram(index, line_); +        } + +      private: +        Output &out_; +        const StringPiece &line_; +    }; + +  public: +    template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { +      sets_.clear(); +      for (Iterator i(begin); i != end; ++i) { +        if (IsTag(*i)) continue; +        Words::const_iterator found(FindStringPiece(vocabs_, *i)); +        if (vocabs_.end() == found) return; +        sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end())); +      } +      if (sets_.empty()) { +        output.AddNGram(line); +        return; +      } + +      Callback<Output> cb(output, line); +      util::AllIntersection(sets_, cb); +    } + +    template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { +      AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output); +    } + +    void Flush() const {} + +  private: +    const Words &vocabs_; + +    std::vector<boost::iterator_range<const unsigned int*> > sets_; +}; + +} // namespace vocab +} // namespace lm + +#endif // LM_FILTER_VOCAB_H__ diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh new file mode 100644 index 00000000..90b07a08 --- /dev/null +++ b/klm/lm/filter/wrapper.hh @@ -0,0 +1,58 @@ +#ifndef LM_FILTER_WRAPPER_H__ +#define LM_FILTER_WRAPPER_H__ + +#include "util/string_piece.hh" + +#include <algorithm> +#include <string> +#include <vector> + +namespace lm { + +// Provide a single-output filter with the same interface as a +// 
multiple-output filter so clients code against one interface. +template <class Binary> class BinaryFilter { +  public: +    // Binary modes are just references (and a set) and it makes the API cleaner to copy them.   +    explicit BinaryFilter(Binary binary) : binary_(binary) {} + +    template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { +      if (binary_.PassNGram(begin, end)) +        output.AddNGram(line); +    } + +    template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { +      AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output); +    } + +    void Flush() const {} + +  private: +    Binary binary_; +}; + +// Wrap another filter to pay attention only to context words +template <class FilterT> class ContextFilter { +  public: +    typedef FilterT Filter; + +    explicit ContextFilter(Filter &backend) : backend_(backend) {} + +    template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { +      pieces_.clear(); +      // TODO: this copy could be avoided by a lookahead iterator. +      std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_)); +      backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output); +    } + +    void Flush() const {} + +  private: +    std::vector<StringPiece> pieces_; + +    Filter backend_; +}; + +} // namespace lm + +#endif // LM_FILTER_WRAPPER_H__ diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 32084b5b..eb159094 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -1,6 +1,7 @@  #include "lm/model.hh"  #include <stdlib.h> +#include <string.h>  #define BOOST_TEST_MODULE ModelTest  #include <boost/test/unit_test.hpp> @@ -22,17 +23,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) {  namespace { +// Stupid bjam reverses the command line arguments randomly.  const char *TestLocation() { -  if (boost::unit_test::framework::master_test_suite().argc < 2) { +  if (boost::unit_test::framework::master_test_suite().argc < 3) {      return "test.arpa";    } -  return boost::unit_test::framework::master_test_suite().argv[1]; +  char **argv = boost::unit_test::framework::master_test_suite().argv; +  return argv[strstr(argv[1], "nounk") ? 2 : 1];  }  const char *TestNoUnkLocation() {    if (boost::unit_test::framework::master_test_suite().argc < 3) {      return "test_nounk.arpa";    } -  return boost::unit_test::framework::master_test_suite().argv[2]; +  char **argv = boost::unit_test::framework::master_test_suite().argv; +  return argv[strstr(argv[1], "nounk") ? 
1 : 2];  }  template <class Model> State GetState(const Model &model, const char *word, const State &in) { diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index b709fef9..9ea08798 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -1,6 +1,7 @@  #include "lm/read_arpa.hh"  #include "lm/blank.hh" +#include "util/file.hh"  #include <cmath>  #include <cstdlib> @@ -45,8 +46,14 @@ uint64_t ReadCount(const std::string &from) {  void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {    number.clear(); -  StringPiece line; -  while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} +  StringPiece line = in.ReadLine(); +  // In general, ARPA files can have arbitrary text before "\data\" +  // But in KenLM, we require such lines to start with "#", so that +  // we can do stricter error checking +  while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) { +    line = in.ReadLine(); +  } +    if (line != "\\data\\") {      if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {        UTIL_THROW(FormatLoadException, "Looks like a gzip file.  If this is an ARPA file, pipe " << in.FileName() << " through zcat.  If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); diff --git a/klm/lm/sizes.cc b/klm/lm/sizes.cc new file mode 100644 index 00000000..55ad586c --- /dev/null +++ b/klm/lm/sizes.cc @@ -0,0 +1,63 @@ +#include "lm/sizes.hh" +#include "lm/model.hh" +#include "util/file_piece.hh" + +#include <vector> +#include <iomanip> + +namespace lm { +namespace ngram { + +void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config) { +  uint64_t sizes[6]; +  sizes[0] = ProbingModel::Size(counts, config); +  sizes[1] = RestProbingModel::Size(counts, config); +  sizes[2] = TrieModel::Size(counts, config); +  sizes[3] = QuantTrieModel::Size(counts, config); +  sizes[4] = ArrayTrieModel::Size(counts, config); +  sizes[5] = QuantArrayTrieModel::Size(counts, config); +  uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); +  uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); +  uint64_t divide; +  char prefix; +  if (min_length < (1 << 10) * 10) { +    prefix = ' '; +    divide = 1; +  } else if (min_length < (1 << 20) * 10) { +    prefix = 'k'; +    divide = 1 << 10; +  } else if (min_length < (1ULL << 30) * 10) { +    prefix = 'M'; +    divide = 1 << 20; +  } else { +    prefix = 'G'; +    divide = 1 << 30; +  } +  long int length = std::max<long int>(2, static_cast<long int>(ceil(log10((double) max_length / divide)))); +  std::cerr << "Memory estimate for binary LM:\ntype    "; + +  // right align bytes.   
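+  // length is the digit width of the largest estimate (minimum 2), so the
+  // unit header lines up with the right-aligned numbers below.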
+  for (long int i = 0; i < length - 2; ++i) std::cerr << ' '; + +  std::cerr << prefix << "B\n" +    "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" +    "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n" +    "trie    " << std::setw(length) << (sizes[2] / divide) << " without quantization\n" +    "trie    " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" +    "trie    " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" +    "trie    " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; +} + +void ShowSizes(const std::vector<uint64_t> &counts) { +  lm::ngram::Config config; +  ShowSizes(counts, config); +} + +void ShowSizes(const char *file, const lm::ngram::Config &config) { +  std::vector<uint64_t> counts; +  util::FilePiece f(file); +  lm::ReadARPACounts(f, counts); +  ShowSizes(counts, config); +} + +}} //namespaces diff --git a/klm/lm/sizes.hh b/klm/lm/sizes.hh new file mode 100644 index 00000000..85abade7 --- /dev/null +++ b/klm/lm/sizes.hh @@ -0,0 +1,17 @@ +#ifndef LM_SIZES__ +#define LM_SIZES__ + +#include <vector> + +#include <stdint.h> + +namespace lm { namespace ngram { + +struct Config; + +void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config); +void ShowSizes(const std::vector<uint64_t> &counts); +void ShowSizes(const char *file, const lm::ngram::Config &config); + +}} // namespaces +#endif // LM_SIZES__ diff --git a/klm/lm/state.hh b/klm/lm/state.hh index 551510a8..d8e6c132 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -56,14 +56,14 @@ inline uint64_t hash_value(const State &state, uint64_t seed = 0) {  struct Left {    bool operator==(const Left &other) const {      return  -      (length == other.length) &&  -      pointers[length - 1] == other.pointers[length - 1] && -      full == other.full; +      length == other.length && +      (!length || (pointers[length - 1] == other.pointers[length - 1] && full == other.full));    }    int Compare(const Left &other) const {      if (length < other.length) return -1;      if (length > other.length) return 1; +    if (length == 0) return 0; // Must be full.      
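+    // Lengths are equal and nonzero here: compare the last pointers, then
+    // fall through to the full flag as the tie breaker.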
if (pointers[length - 1] > other.pointers[length - 1]) return 1;      if (pointers[length - 1] < other.pointers[length - 1]) return -1;      return (int)full - (int)other.full; diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 8663e94e..dc542bb3 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -65,13 +65,13 @@ class PartialViewProxy {  typedef util::ProxyIterator<PartialViewProxy> PartialIter; -FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) { -  util::scoped_fd file(maker.Make()); +FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) { +  util::scoped_fd file(util::MakeTemp(temp_prefix));    util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);    return util::FDOpenOrThrow(file);  } -FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) { +FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) {    const size_t context_size = sizeof(WordIndex) * (order - 1);    // Sort just the contexts using the same memory.      PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); @@ -84,7 +84,7 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make  #endif      (context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1))); -  util::scoped_FILE out(maker.MakeFile()); +  util::scoped_FILE out(util::FMakeTemp(temp_prefix));    // Write out to file and uniqueify at the same time.  Could have used unique_copy if there was an appropriate OutputIterator.      if (context_begin == context_end) return out.release(); @@ -114,12 +114,12 @@ struct FirstCombine {    }  }; -template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) { +template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) {    std::size_t entry_size = sizeof(WordIndex) * order + weights_size;    RecordReader first, second;    first.Init(first_file, entry_size);    second.Init(second_file, entry_size); -  util::scoped_FILE out_file(maker.MakeFile()); +  util::scoped_FILE out_file(util::FMakeTemp(temp_prefix));    EntryCompare less(order);    while (first && second) {      if (less(first.Data(), second.Data())) { @@ -177,9 +177,8 @@ void RecordReader::Rewind() {  }  SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { -  util::TempMaker maker(file_prefix);    PositiveProbWarn warn(config.positive_log_probability); -  unigram_.reset(maker.Make()); +  unigram_.reset(util::MakeTemp(file_prefix));    {      // In case <unk> appears.        
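+      // Allocate one extra unigram slot beyond counts[0] for that case.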
size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff); @@ -202,7 +201,7 @@ SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<u    if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);    for (unsigned char order = 2; order <= counts.size(); ++order) { -    ConvertToSorted(f, vocab, counts, maker, order, warn, mem.get(), buffer); +    ConvertToSorted(f, vocab, counts, file_prefix, order, warn, mem.get(), buffer);    }    ReadEnd(f);  } @@ -227,7 +226,7 @@ class Closer {  };  } // namespace -void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { +void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {    ReadNGramHeader(f, order);    const size_t count = counts[order - 1];    // Size of weights.  Does it include backoff?   @@ -261,8 +260,8 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo      std::sort  #endif          (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order))); -    files.push_back(DiskFlush(begin, out_end, maker)); -    contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order)); +    files.push_back(DiskFlush(begin, out_end, file_prefix)); +    contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order));      done += (out_end - begin) / entry_size;    } @@ -270,10 +269,10 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo    // All individual files created.  Merge them.      while (files.size() > 1) { -    files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine())); +    files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine()));      files_closer.PopFront();      files_closer.PopFront(); -    contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine())); +    contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine()));      contexts_closer.PopFront();      contexts_closer.PopFront();    } diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index 2197b80c..1afd9562 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -18,7 +18,6 @@  namespace util {  class FilePiece; -class TempMaker;  } // namespace util  namespace lm { @@ -101,7 +100,7 @@ class SortedFiles {      }    private: -    void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size); +    void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size);      util::scoped_fd unigram_; | 
