diff options
Diffstat (limited to 'klm/lm')
| -rw-r--r-- | klm/lm/binary_format.cc | 1 | ||||
| -rw-r--r-- | klm/lm/binary_format.hh | 9 | ||||
| -rw-r--r-- | klm/lm/build_binary.cc | 112 | ||||
| -rw-r--r-- | klm/lm/enumerate_vocab.hh | 7 | ||||
| -rw-r--r-- | klm/lm/lm_exception.cc | 8 | ||||
| -rw-r--r-- | klm/lm/lm_exception.hh | 18 | ||||
| -rw-r--r-- | klm/lm/model.cc | 1 | ||||
| -rw-r--r-- | klm/lm/model.hh | 23 | ||||
| -rw-r--r-- | klm/lm/model_test.cc | 4 | ||||
| -rw-r--r-- | klm/lm/read_arpa.cc | 4 | ||||
| -rw-r--r-- | klm/lm/read_arpa.hh | 2 | ||||
| -rw-r--r-- | klm/lm/search_trie.cc | 200 | ||||
| -rw-r--r-- | klm/lm/trie.cc | 42 | 
13 files changed, 329 insertions, 102 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 2a075b6b..69a06355 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -141,7 +141,6 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t  }  uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) { -  if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0");    if (config.write_mmap) {      std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size;      // Write out an mmap file.   diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index f95f05f7..a43c883c 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -67,9 +67,12 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)      if (detail::IsBinaryFormat(backing.file.get())) {        detail::ReadHeader(backing.file.get(), params);        detail::MatchCheck(To::kModelType, params); -      std::size_t memory_size = To::Size(params.counts, config); -      uint8_t *start = detail::SetupBinary(config, params, memory_size, backing); -      to.InitializeFromBinary(start, params, config, backing.file.get()); +      // Replace the probing_multiplier.   
+      Config new_config(config); +      new_config.probing_multiplier = params.fixed.probing_multiplier; +      std::size_t memory_size = To::Size(params.counts, new_config); +      uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); +      to.InitializeFromBinary(start, params, new_config, backing.file.get());      } else {        detail::ComplainAboutARPA(config, To::kModelType);        util::FilePiece f(backing.file.release(), file, config.messages); diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 4db631a2..ec034640 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -1,13 +1,113 @@  #include "lm/model.hh" +#include "util/file_piece.hh"  #include <iostream> +#include <iomanip> + +#include <math.h> +#include <stdlib.h> +#include <unistd.h> + +namespace lm { +namespace ngram { +namespace { + +void Usage(const char *name) { +  std::cerr << "Usage: " << name << " [-u unknown_probability] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" +"Where type is one of probing, trie, or sorted:\n\n" +"probing uses a probing hash table.  It is the fastest but uses the most memory.\n" +"-p sets the space multiplier and must be >1.0.  The default is 1.5.\n\n" +"trie is a straightforward trie with bit-level packing.  It uses the least\n" +"memory and is still faster than SRI or IRST.  Building the trie format uses an\n" +"on-disk sort to save memory.\n" +"-t is the temporary directory prefix.  Default is the output file name.\n" +"-m is the amount of memory to use, in MB.  
Default is 1024MB (1GB).\n\n" +"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" +"It uses more memory than trie and is also slower, so there's no real reason to\n" +"use it.\n\n" +"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" +"Passing only an input file will print memory usage of each data structure.\n" +"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n"; +  exit(1); +} + +// I could really use boost::lexical_cast right about now.   +float ParseFloat(const char *from) { +  char *end; +  float ret = strtod(from, &end); +  if (*end) throw util::ParseNumberException(from); +  return ret; +} +unsigned long int ParseUInt(const char *from) { +  char *end; +  unsigned long int ret = strtoul(from, &end, 10); +  if (*end) throw util::ParseNumberException(from); +  return ret; +} + +void ShowSizes(const char *file, const lm::ngram::Config &config) { +  std::vector<uint64_t> counts; +  util::FilePiece f(file); +  lm::ReadARPACounts(f, counts); +  std::size_t probing_size = ProbingModel::Size(counts, config); +  // probing is always largest so use it to determine number of columns.   +  long int length = std::max<long int>(5, lrint(ceil(log10(probing_size)))); +  std::cout << "Memory usage:\ntype    "; +  // right align bytes.   
+  for (long int i = 0; i < length - 5; ++i) std::cout << ' '; +  std::cout << "bytes\n" +    "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n" +    "trie    " << std::setw(length) << TrieModel::Size(counts, config) << "\n" +    "sorted  " << std::setw(length) << SortedModel::Size(counts, config) << "\n"; +} + +} // namespace +} // namespace ngram +} // namespace lm  int main(int argc, char *argv[]) { -  if (argc != 3) { -    std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl; -    return 1; -  } +  using namespace lm::ngram; +    lm::ngram::Config config; -  config.write_mmap = argv[2]; -  lm::ngram::Model(argv[1], config); +  int opt; +  while ((opt = getopt(argc, argv, "u:p:t:m:")) != -1) { +    switch(opt) { +      case 'u': +        config.unknown_missing_prob = ParseFloat(optarg); +        break; +      case 'p': +        config.probing_multiplier = ParseFloat(optarg); +        break; +      case 't': +        config.temporary_directory_prefix = optarg; +        break; +      case 'm': +        config.building_memory = ParseUInt(optarg) * 1048576; +        break; +      default: +        Usage(argv[0]); +    } +  } +  if (optind + 1 == argc) { +    ShowSizes(argv[optind], config); +  } else if (optind + 2 == argc) { +    config.write_mmap = argv[optind + 1]; +    ProbingModel(argv[optind], config); +  } else if (optind + 3 == argc) { +    const char *model_type = argv[optind]; +    const char *from_file = argv[optind + 1]; +    config.write_mmap = argv[optind + 2]; +    if (!strcmp(model_type, "probing")) { +      ProbingModel(from_file, config); +    } else if (!strcmp(model_type, "sorted")) { +      SortedModel(from_file, config); +    } else if (!strcmp(model_type, "trie")) { +      TrieModel(from_file, config); +    } else { +      Usage(argv[0]); +    } +  } else { +    Usage(argv[0]); +  } +  return 0;  } diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh 
index 7a2f7d12..e734316b 100644 --- a/klm/lm/enumerate_vocab.hh +++ b/klm/lm/enumerate_vocab.hh @@ -8,9 +8,10 @@ namespace lm {  namespace ngram {  /* If you need the actual strings in the vocabulary, inherit from this class - * and implement Add.  Then put a pointer in Config.enumerate_vocab.   - * Add is called once per n-gram.  index starts at 0 and increases by 1 each - * time.   + * and implement Add.  Then put a pointer in Config.enumerate_vocab; it does + * not take ownership.  Add is called once per vocab word.  index starts at 0 + * and increases by 1 each time.  This is only used by the Model constructor; + * the pointer is not retained by the class.     */  class EnumerateVocab {    public: diff --git a/klm/lm/lm_exception.cc b/klm/lm/lm_exception.cc index ab2ec52f..473849d1 100644 --- a/klm/lm/lm_exception.cc +++ b/klm/lm/lm_exception.cc @@ -5,14 +5,18 @@  namespace lm { +ConfigException::ConfigException() throw() {} +ConfigException::~ConfigException() throw() {} +  LoadException::LoadException() throw() {}  LoadException::~LoadException() throw() {} -VocabLoadException::VocabLoadException() throw() {} -VocabLoadException::~VocabLoadException() throw() {}  FormatLoadException::FormatLoadException() throw() {}  FormatLoadException::~FormatLoadException() throw() {} +VocabLoadException::VocabLoadException() throw() {} +VocabLoadException::~VocabLoadException() throw() {} +  SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() {    *this << "Missing special word " << which;  } diff --git a/klm/lm/lm_exception.hh b/klm/lm/lm_exception.hh index 1216c4c7..3773c572 100644 --- a/klm/lm/lm_exception.hh +++ b/klm/lm/lm_exception.hh @@ -11,6 +11,12 @@  namespace lm { +class ConfigException : public util::Exception { +  public: +    ConfigException() throw(); +    ~ConfigException() throw(); +}; +  class LoadException : public util::Exception {     public:        virtual ~LoadException() throw(); @@ -19,18 +25,18 @@ class 
LoadException : public util::Exception {        LoadException() throw();  }; -class VocabLoadException : public LoadException { -  public: -    virtual ~VocabLoadException() throw(); -    VocabLoadException() throw(); -}; -  class FormatLoadException : public LoadException {    public:      FormatLoadException() throw();      ~FormatLoadException() throw();  }; +class VocabLoadException : public LoadException { +  public: +    virtual ~VocabLoadException() throw(); +    VocabLoadException() throw(); +}; +  class SpecialWordMissingException : public VocabLoadException {    public:      explicit SpecialWordMissingException(StringPiece which) throw(); diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 6921d4d9..421e72fa 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -23,6 +23,7 @@ namespace detail {  template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {    if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  
Edit ngram.hh's kMaxOrder to at least this value and recompile.");    if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); +  if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");    return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);  } diff --git a/klm/lm/model.hh b/klm/lm/model.hh index e0eeee17..53e5773d 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -12,6 +12,8 @@  #include <algorithm>  #include <vector> +#include <string.h> +  namespace util { class FilePiece; }  namespace lm { @@ -21,9 +23,10 @@ namespace ngram {  // Having this limit means that State can be  // (kMaxOrder - 1) * sizeof(float) bytes instead of  // sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead -const std::size_t kMaxOrder = 6; +const unsigned char kMaxOrder = 6; -// This is a POD.   +// This is a POD but if you want memcmp to return the same as operator==, call +// ZeroRemaining first.      class State {    public:      bool operator==(const State &other) const { @@ -37,6 +40,22 @@ class State {        return true;      } +    // Three way comparison function.   +    int Compare(const State &other) const { +      if (valid_length_ == other.valid_length_) { +        return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex)); +      } +      return (valid_length_ < other.valid_length_) ? -1 : 1; +    } + +    // Call this before using raw memcmp.   +    void ZeroRemaining() { +      for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) { +        history_[i] = 0; +        backoff_[i] = 0.0; +      } +    } +      // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.        // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.        
WordIndex history_[kMaxOrder - 1]; diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 159628d4..b5125a95 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -4,6 +4,7 @@  #define BOOST_TEST_MODULE ModelTest  #include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp>  namespace lm {  namespace ngram { @@ -123,7 +124,7 @@ class ExpectEnumerateVocab : public EnumerateVocab {      }      void Check(const base::Vocabulary &vocab) { -      BOOST_CHECK_EQUAL(34, seen.size()); +      BOOST_CHECK_EQUAL(34ULL, seen.size());        BOOST_REQUIRE(!seen.empty());        BOOST_CHECK_EQUAL("<unk>", seen[0]);        for (WordIndex i = 0; i < seen.size(); ++i) { @@ -144,6 +145,7 @@ template <class ModelT> void LoadingTest() {    config.messages = NULL;    ExpectEnumerateVocab enumerate;    config.enumerate_vocab = &enumerate; +  config.probing_multiplier = 2.0;    ModelT m("test.arpa", config);    enumerate.Check(m.GetVocabulary());    Starters(m); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 8e9a770d..262a9c6a 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -49,7 +49,7 @@ template <class F> void GenericReadNGramHeader(F &in, unsigned int length) {    while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}    std::stringstream expected;    expected << '\\' << length << "-grams:"; -  if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead.  
"); +  if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead");  }  template <class F> void GenericReadEnd(F &in) { @@ -110,7 +110,7 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {        {          float got = in.ReadFloat();          if (got != 0.0) -          UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff."); +          UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff");        }        break;      case '\n': diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index cabdb195..571fcbc5 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -54,7 +54,7 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns      }      ReadBackoff(f, weights);    } catch(util::Exception &e) { -    e << " in the " <<  n << "-gram at byte " << f.Offset(); +    e << " in the " << static_cast<unsigned int>(n) << "-gram at byte " << f.Offset();      throw;    }  } diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 182e27f5..12294682 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,3 +1,4 @@ +/* This is where the trie is built.  It's on-disk.  */  #include "lm/search_trie.hh"  #include "lm/lm_exception.hh" @@ -8,6 +9,7 @@  #include "lm/word_index.hh"  #include "util/ersatz_progress.hh"  #include "util/file_piece.hh" +#include "util/proxy_iterator.hh"  #include "util/scoped.hh"  #include <algorithm> @@ -30,43 +32,119 @@ namespace ngram {  namespace trie {  namespace { -template <unsigned char Order> class FullEntry { +/* An entry is a n-gram with probability.  It consists of: + * WordIndex[order] + * float probability + * backoff probability (omitted for highest order n-gram) + * These are stored consecutively in memory.  We want to sort them.   
+ * + * The problem is the length depends on order (but all n-grams being compared + * have the same order).  Allocating each entry on the heap (i.e. std::vector + * or std::string) then sorting pointers is the normal solution.  But that's + * too memory inefficient.  A lot of this code is just here to force std::sort + * to work with records where length is specified at runtime (and avoid using + * Boost for LM code).  I could have used qsort, but the point is to also + * support __gnu_parallel::sort which doesn't have a qsort version.   + */ + +class EntryIterator {    public: -    typedef ProbBackoff Weights; -    static const unsigned char kOrder = Order; +    EntryIterator() {} -    // reverse order -    WordIndex words[Order]; -    Weights weights; +    EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast<uint8_t*>(ptr)), size_(size) {} -    bool operator<(const FullEntry<Order> &other) const { -      for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { -        if (*i < *j) return true; -        if (*i > *j) return false; -      } -      return false; +    bool operator==(const EntryIterator &other) const { +      return ptr_ == other.ptr_; +    } +    bool operator<(const EntryIterator &other) const { +      return ptr_ < other.ptr_; +    } +    EntryIterator &operator+=(std::ptrdiff_t amount) { +      ptr_ += amount * size_; +      return *this; +    } +    std::ptrdiff_t operator-(const EntryIterator &other) const { +      return (ptr_ - other.ptr_) / size_;      } + +    const void *Data() const { return ptr_; } +    void *Data() { return ptr_; } +    std::size_t EntrySize() const { return size_; } +     +  private: +    uint8_t *ptr_; +    std::size_t size_;  }; -template <unsigned char Order> class ProbEntry { +class EntryProxy {    public: -    typedef Prob Weights; -    static const unsigned char kOrder = Order; +    EntryProxy() {} + +    EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} + +    
operator std::string() const { +      return std::string(reinterpret_cast<const char*>(inner_.Data()), inner_.EntrySize()); +    } + +    EntryProxy &operator=(const EntryProxy &from) { +      memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); +      return *this; +    } + +    EntryProxy &operator=(const std::string &from) { +      memcpy(inner_.Data(), from.data(), inner_.EntrySize()); +      return *this; +    } + +    const WordIndex *Indices() const { +      return static_cast<const WordIndex*>(inner_.Data()); +    } + +  private: +    friend class util::ProxyIterator<EntryProxy>; + +    typedef std::string value_type; -    // reverse order -    WordIndex words[Order]; -    Weights weights; +    typedef EntryIterator InnerIterator; +    InnerIterator &Inner() { return inner_; } +    const InnerIterator &Inner() const { return inner_; }  +    InnerIterator inner_; +}; -    bool operator<(const ProbEntry<Order> &other) const { -      for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { -        if (*i < *j) return true; -        if (*i > *j) return false; +typedef util::ProxyIterator<EntryProxy> NGramIter; + +class CompareRecords : public std::binary_function<const EntryProxy &, const EntryProxy &, bool> { +  public: +    explicit CompareRecords(unsigned char order) : order_(order) {} + +    bool operator()(const EntryProxy &first, const EntryProxy &second) const { +      return Compare(first.Indices(), second.Indices()); +    } +    bool operator()(const EntryProxy &first, const std::string &second) const { +      return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data())); +    } +    bool operator()(const std::string &first, const EntryProxy &second) const { +      return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices()); +    } +    bool operator()(const std::string &first, const std::string &second) const { +      return Compare(reinterpret_cast<const 
WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data())); +    } +     +  private: +    bool Compare(const WordIndex *first, const WordIndex *second) const { +      const WordIndex *end = first + order_; +      for (; first != end; ++first, ++second) { +        if (*first < *second) return true; +        if (*first > *second) return false;        }        return false;      } + +    unsigned char order_;  };  void WriteOrThrow(FILE *to, const void *data, size_t size) { +  assert(size);    if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);  } @@ -84,21 +162,24 @@ void CopyOrThrow(FILE *from, FILE *to, size_t size) {    }  } -template <class Entry> std::string DiskFlush(const Entry *begin, const Entry *end, const std::string &file_prefix, std::size_t batch) { +std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) { +  const std::size_t entry_size = sizeof(WordIndex) * order + weights_size; +  const std::size_t prefix_size = sizeof(WordIndex) * (order - 1);    std::stringstream assembled; -  assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << '_' << batch; +  assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;    std::string ret(assembled.str());    util::scoped_FILE out(fopen(ret.c_str(), "w"));    if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); -  for (const Entry *group_begin = begin; group_begin != end;) { -    const Entry *group_end = group_begin; -    for (++group_end; (group_end != end) && !memcmp(group_begin->words, group_end->words, sizeof(WordIndex) * (Entry::kOrder - 1)); ++group_end) {} -    WriteOrThrow(out.get(), group_begin->words, sizeof(WordIndex) * (Entry::kOrder - 1)); -    WordIndex group_size = group_end - group_begin; +  // Compress entries that 
begin with the same (order-1) words. +  for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) { +    const uint8_t *group_end = group_begin; +    for (group_end += entry_size; (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); group_end += entry_size) {} +    WriteOrThrow(out.get(), group_begin, prefix_size); +    WordIndex group_size = (group_end - group_begin) / entry_size;      WriteOrThrow(out.get(), &group_size, sizeof(group_size)); -    for (const Entry *i = group_begin; i != group_end; ++i) { -      WriteOrThrow(out.get(), &i->words[Entry::kOrder - 1], sizeof(WordIndex)); -      WriteOrThrow(out.get(), &i->weights, sizeof(typename Entry::Weights)); +    for (const uint8_t *i = group_begin; i != group_end; i += entry_size) { +      WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex)); +      WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size);      }      group_begin = group_end;    } @@ -219,25 +300,37 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha    }  } -template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix) { -  ConvertToSorted<FullEntry<Entry::kOrder - 1> >(f, vocab, counts, mem, file_prefix); - -  ReadNGramHeader(f, Entry::kOrder); -  const size_t count = counts[Entry::kOrder - 1]; -  const size_t batch_size = std::min(count, mem.size() / sizeof(Entry)); -  Entry *const begin = reinterpret_cast<Entry*>(mem.get()); +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { +  if (order == 1) return; +  ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1); + +  ReadNGramHeader(f, order); +  const 
size_t count = counts[order - 1]; +  // Size of weights.  Does it include backoff?   +  const size_t words_size = sizeof(WordIndex) * order; +  const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); +  const size_t entry_size = words_size + weights_size; +  const size_t batch_size = std::min(count, mem.size() / entry_size); +  uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get());    std::deque<std::string> files;    for (std::size_t batch = 0, done = 0; done < count; ++batch) { -    Entry *out = begin; -    Entry *out_end = out + std::min(count - done, batch_size); -    for (; out != out_end; ++out) { -      ReadNGram(f, Entry::kOrder, vocab, out->words, out->weights); +    uint8_t *out = begin; +    uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; +    if (order == counts.size()) { +      for (; out != out_end; out += entry_size) { +        ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size)); +      } +    } else { +      for (; out != out_end; out += entry_size) { +        ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size)); +      }      } -    //__gnu_parallel::sort(begin, out_end); -    std::sort(begin, out_end); +    // TODO: __gnu_parallel::sort here. +    EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); +    std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); -    files.push_back(DiskFlush(begin, out_end, file_prefix, batch)); -    done += out_end - begin; +    files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); +    done += (out_end - begin) / entry_size;    }    // All individual files created.  Merge them.   
@@ -245,9 +338,9 @@ template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVoca    std::size_t merge_count = 0;    while (files.size() > 1) {      std::stringstream assembled; -    assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merge_" << (merge_count++); +    assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);      files.push_back(assembled.str()); -    MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), sizeof(typename Entry::Weights), Entry::kOrder); +    MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order);      if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);      files.pop_front();      if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); @@ -255,14 +348,12 @@ template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVoca    }    if (!files.empty()) {      std::stringstream assembled; -    assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merged"; +    assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";      std::string merged_name(assembled.str());      if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());    }  } -template <> void ConvertToSorted<FullEntry<1> >(util::FilePiece &/*f*/, const SortedVocabulary &/*vocab*/, const std::vector<uint64_t> &/*counts*/, util::scoped_memory &/*mem*/, const std::string &/*file_prefix*/) {} -  void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {    {      std::string unigram_name = file_prefix + "unigrams"; @@ -275,7 +366,7 @@ void ARPAToSortedFiles(util::FilePiece &f, const 
std::vector<uint64_t> &counts,    util::scoped_memory mem;    mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED);    if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); -  ConvertToSorted<ProbEntry<5> >(f, vocab, counts, mem, file_prefix); +  ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size());    ReadEnd(f);  } @@ -390,7 +481,8 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const    temporary_directory.resize(strlen(temporary_directory.c_str()));    // Add directory delimiter.  Assumes a real operating system.      temporary_directory += '/'; -  ARPAToSortedFiles(f, counts, config.building_memory, temporary_directory.c_str(), vocab); +  // At least 1MB sorting memory.   +  ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);    BuildTrie(temporary_directory.c_str(), counts, config.messages, *this);    if (rmdir(temporary_directory.c_str())) {      std::cerr << "Failed to delete " << temporary_directory << std::endl; diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 8ed7b2a2..04bd2079 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -15,21 +15,21 @@ namespace {  // Assumes key is first.    
class JustKeyProxy {    public: -    JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {} +    JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {}      operator uint64_t() const { return GetKey(); }      uint64_t GetKey() const {        uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_); -      return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_mask_); +      return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_);      }    private:      friend class util::ProxyIterator<JustKeyProxy>; -    friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); +    friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); -    JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits) -      : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), total_bits_(total_bits) {} +    JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) +      : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}      // This is a read-only iterator.        
JustKeyProxy &operator=(const JustKeyProxy &other); @@ -44,12 +44,12 @@ class JustKeyProxy {      uint64_t inner_;      const uint8_t *const base_;      const uint64_t key_mask_; -    const uint8_t total_bits_; +    const uint8_t key_bits_, total_bits_;  }; -bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { -  util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, total_bits)); -  util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, total_bits)); +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { +  util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits)); +  util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits));    util::ProxyIterator<JustKeyProxy> out;    if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false;    at_index = out.Inner(); @@ -96,67 +96,67 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t    assert(next <= next_mask_);    uint64_t at_pointer = insert_index_ * total_bits_; -  util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word); +  util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word);    at_pointer += word_bits_;    util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);    at_pointer += prob_bits_;    util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff);    at_pointer += backoff_bits_; -  util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next); +  util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next);    ++insert_index_;  }  bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange 
&range) const {    uint64_t at_pointer; -  if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; +  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;    at_pointer *= total_bits_;    at_pointer += word_bits_;    prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);    at_pointer += prob_bits_;    backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);    at_pointer += backoff_bits_; -  range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); +  range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);    // Read the next entry's pointer.      at_pointer += total_bits_; -  range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); +  range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);    return true;  }  bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {    uint64_t at_pointer; -  if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; +  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;    at_pointer *= total_bits_;    at_pointer += word_bits_;    at_pointer += prob_bits_;    backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);    at_pointer += backoff_bits_; -  range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); +  range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);    // Read the next entry's pointer.      
at_pointer += total_bits_; -  range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); +  range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);    return true;  }  void BitPackedMiddle::FinishedLoading(uint64_t next_end) {    assert(next_end <= next_mask_);    uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; -  util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_end); +  util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end);  }  void BitPackedLongest::Insert(WordIndex index, float prob) {    assert(index <= word_mask_);    uint64_t at_pointer = insert_index_ * total_bits_; -  util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, index); +  util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index);    at_pointer += word_bits_;    util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);    ++insert_index_;  } -bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const { +bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const {    uint64_t at_pointer; -  if (!FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, at_pointer)) return false; +  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;    at_pointer = at_pointer * total_bits_ + word_bits_;    prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);    return true;  | 
