diff options
Diffstat (limited to 'klm')
| -rw-r--r-- | klm/lm/builder/lmplz_main.cc | 15 | ||||
| -rw-r--r-- | klm/lm/builder/ngram.hh | 2 | ||||
| -rw-r--r-- | klm/lm/model.cc | 21 | ||||
| -rw-r--r-- | klm/lm/model.hh | 5 | ||||
| -rw-r--r-- | klm/lm/search_hashed.cc | 29 | ||||
| -rw-r--r-- | klm/lm/search_hashed.hh | 19 | ||||
| -rw-r--r-- | klm/lm/search_trie.hh | 3 | ||||
| -rw-r--r-- | klm/lm/state.hh | 2 | ||||
| -rw-r--r-- | klm/lm/virtual_interface.hh | 3 | ||||
| -rw-r--r-- | klm/lm/vocab.hh | 2 | ||||
| -rw-r--r-- | klm/search/Makefile.am | 17 | ||||
| -rw-r--r-- | klm/search/context.hh | 12 | ||||
| -rw-r--r-- | klm/search/edge_generator.cc | 12 | ||||
| -rw-r--r-- | klm/search/vertex.cc | 204 | ||||
| -rw-r--r-- | klm/search/vertex.hh | 121 | ||||
| -rw-r--r-- | klm/search/vertex_generator.hh | 36 | ||||
| -rw-r--r-- | klm/util/double-conversion/utils.h | 15 | ||||
| -rw-r--r-- | klm/util/file.cc | 14 | ||||
| -rw-r--r-- | klm/util/pool.cc | 4 | ||||
| -rw-r--r-- | klm/util/probing_hash_table.hh | 23 | ||||
| -rw-r--r-- | klm/util/proxy_iterator.hh | 25 | ||||
| -rw-r--r-- | klm/util/sized_iterator.hh | 21 | ||||
| -rw-r--r-- | klm/util/stream/chain.hh | 2 | ||||
| -rw-r--r-- | klm/util/usage.cc | 80 | 
24 files changed, 473 insertions, 214 deletions
| diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc index 1e086dcc..c87abdb8 100644 --- a/klm/lm/builder/lmplz_main.cc +++ b/klm/lm/builder/lmplz_main.cc @@ -52,13 +52,14 @@ int main(int argc, char *argv[]) {        std::cerr <<           "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"          "Please cite:\n" -        "@inproceedings{kenlm,\n" -        "author    = {Kenneth Heafield},\n" -        "title     = {{KenLM}: Faster and Smaller Language Model Queries},\n" -        "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n" -        "month     = {July}, year={2011},\n" -        "address   = {Edinburgh, UK},\n" -        "publisher = {Association for Computational Linguistics},\n" +        "@inproceedings{Heafield-estimate,\n" +        "  author = {Kenneth Heafield and Ivan Pouzyrevsky and Jonathan H. Clark and Philipp Koehn},\n" +        "  title = {Scalable Modified {Kneser-Ney} Language Model Estimation},\n" +        "  year = {2013},\n" +        "  month = {8},\n" +        "  booktitle = {Proceedings of the 51st Annual Meeting of the Association for Computational Linguistics},\n" +        "  address = {Sofia, Bulgaria},\n" +        "  url = {http://kheafield.com/professional/edinburgh/estimate\\_paper.pdf},\n"          "}\n\n"          "Provide the corpus on stdin.  The ARPA file will be written to stdout.  Order of\n"          "the model (-o) is the only mandatory option.  
As this is an on-disk program,\n" diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh index 2984ed0b..f5681516 100644 --- a/klm/lm/builder/ngram.hh +++ b/klm/lm/builder/ngram.hh @@ -53,7 +53,7 @@ class NGram {      Payload &Value() { return *reinterpret_cast<Payload *>(end_); }      uint64_t &Count() { return Value().count; } -    const uint64_t Count() const { return Value().count; } +    uint64_t Count() const { return Value().count; }      std::size_t Order() const { return end_ - begin_; } diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a40fd2fb..a26654a6 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -304,5 +304,26 @@ template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiks  template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;  } // namespace detail + +base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) { +  RecognizeBinary(file_name, model_type); +  switch (model_type) { +    case PROBING: +      return new ProbingModel(file_name, config); +    case REST_PROBING: +      return new RestProbingModel(file_name, config); +    case TRIE: +      return new TrieModel(file_name, config); +    case QUANT_TRIE: +      return new QuantTrieModel(file_name, config); +    case ARRAY_TRIE: +      return new ArrayTrieModel(file_name, config); +    case QUANT_ARRAY_TRIE: +      return new QuantArrayTrieModel(file_name, config); +    default: +      UTIL_THROW(FormatLoadException, "Confused by model type " << model_type); +  } +} +  } // namespace ngram  } // namespace lm diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 13ff864e..60f55110 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -153,6 +153,11 @@ LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<Separat  typedef ::lm::ngram::ProbingVocabulary Vocabulary;  typedef ProbingModel Model; +/* Autorecognize the file type, load, and return the virtual base 
class.  Don't + * use the virtual base class if you can avoid it.  Instead, use the above + * classes as template arguments to your own virtual feature function.*/ +base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING); +  } // namespace ngram  } // namespace lm diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 2d6f15b2..62275d27 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -54,7 +54,7 @@ template <class Weights> class ActivateUnigram {      Weights *modify_;  }; -// Find the lower order entry, inserting blanks along the way as necessary.   +// Find the lower order entry, inserting blanks along the way as necessary.  template <class Value> void FindLower(      const std::vector<uint64_t> &keys,      typename Value::Weights &unigram, @@ -64,7 +64,7 @@ template <class Value> void FindLower(    typename Value::ProbingEntry entry;    // Backoff will always be 0.0.  We'll get the probability and rest in another pass.    entry.value.backoff = kNoExtensionBackoff; -  // Go back and find the longest right-aligned entry, informing it that it extends left.  Normally this will match immediately, but sometimes SRI is dumb.   +  // Go back and find the longest right-aligned entry, informing it that it extends left.  Normally this will match immediately, but sometimes SRI is dumb.    for (int lower = keys.size() - 2; ; --lower) {      if (lower == -1) {        between.push_back(&unigram); @@ -77,11 +77,11 @@ template <class Value> void FindLower(    }  } -// Between usually has  single entry, the value to adjust.  But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.   +// Between usually has  single entry, the value to adjust.  But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.  
template <class Added, class Build> void AdjustLower(      const Added &added,      const Build &build, -    std::vector<typename Build::Value::Weights *> &between,  +    std::vector<typename Build::Value::Weights *> &between,      const unsigned int n,      const std::vector<WordIndex> &vocab_ids,      typename Build::Value::Weights *unigrams, @@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower(    }    typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;    float prob = -fabs(between.back()->prob); -  // Order of the n-gram on which probabilities are based.   +  // Order of the n-gram on which probabilities are based.    unsigned char basis = n - between.size();    assert(basis != 0);    typename Build::Value::Weights **change = &between.back();    // Skip the basis.    --change;    if (basis == 1) { -    // Hallucinate a bigram based on a unigram's backoff and a unigram probability.   +    // Hallucinate a bigram based on a unigram's backoff and a unigram probability.      float &backoff = unigrams[vocab_ids[1]].backoff;      SetExtension(backoff);      prob += backoff; @@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower(    typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());    build.MarkExtends(**i, added);    const typename Value::Weights *longer = *i; -  // Everything has probability but is not marked as extending.   +  // Everything has probability but is not marked as extending.    for (++i; i != between.end(); ++i) {      build.MarkExtends(**i, *longer);      longer = *i;    }  } -// Continue marking lower entries even they know that they extend left.  This is used for upper/lower bounds.   +// Continue marking lower entries even they know that they extend left.  This is used for upper/lower bounds.  
template <class Build> void MarkLower(      const std::vector<uint64_t> &keys,      const Build &build, @@ -144,15 +144,15 @@ template <class Build> void MarkLower(      int start_order,      const typename Build::Value::Weights &longer) {    if (start_order == 0) return; -  typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter; -  // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.   +  // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.    for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) {      if (even_lower == -1) {        build.MarkExtends(unigram, longer);        return;      } -    middle[even_lower].UnsafeMutableFind(keys[even_lower], iter); -    if (!build.MarkExtends(iter->value, longer)) return; +    if (!build.MarkExtends( +          middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value, +          longer)) return;    }  } @@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams(      Store &store,      PositiveProbWarn &warn) {    typedef typename Build::Value Value; -  typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;    assert(n >= 2);    ReadNGramHeader(f, n); @@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(      for (unsigned int h = 1; h < n - 1; ++h) {        keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);      } -    // Initially the sign bit is on, indicating it does not extend left.  Most already have this but there might +0.0.   +    // Initially the sign bit is on, indicating it does not extend left.  Most already have this but there might +0.0.      
util::SetSign(entry.value.prob);      entry.key = keys[n-2]; @@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(  } // namespace  namespace detail { -  +  template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {    std::size_t allocated = Unigram::Size(counts[0]);    unigram_ = Unigram(start, counts[0], allocated); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 00595796..9d067bc2 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -71,7 +71,7 @@ template <class Value> class HashedSearch {      static const bool kDifferentRest = Value::kDifferentRest;      static const unsigned int kVersion = 0; -    // TODO: move probing_multiplier here with next binary file format update.   +    // TODO: move probing_multiplier here with next binary file format update.      static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}      static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { @@ -102,14 +102,9 @@ template <class Value> class HashedSearch {        return ret;      } -#pragma GCC diagnostic ignored "-Wuninitialized"      MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {        node = extend_pointer; -      typename Middle::ConstIterator found; -      bool got = middle_[extend_length - 2].Find(extend_pointer, found); -      assert(got); -      (void)got; -      return MiddlePointer(found->value); +      return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value);      }      MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const { @@ -126,14 +121,14 @@ template <class Value> class HashedSearch {      }      LongestPointer LookupLongest(WordIndex word, const Node &node) const { -      // Sign bit is always on 
because longest n-grams do not extend left.   +      // Sign bit is always on because longest n-grams do not extend left.        typename Longest::ConstIterator found;        if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();        return LongestPointer(found->value.prob);      } -    // Generate a node without necessarily checking that it actually exists.   -    // Optionally return false if it's know to not exist.   +    // Generate a node without necessarily checking that it actually exists. +    // Optionally return false if it's know to not exist.      bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {        assert(begin != end);        node = static_cast<Node>(*begin); @@ -144,7 +139,7 @@ template <class Value> class HashedSearch {      }    private: -    // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.   +    // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.      
void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);      template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); @@ -153,7 +148,7 @@ template <class Value> class HashedSearch {        public:          Unigram() {} -        Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :  +        Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :            unigram_(static_cast<typename Value::Weights*>(start))  #ifdef DEBUG           ,  count_(count) diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 1264baf5..60be416b 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -41,7 +41,8 @@ template <class Quant, class Bhiksha> class TrieSearch {      static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {        Quant::UpdateConfigFromBinary(fd, counts, config);        util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); -      Bhiksha::UpdateConfigFromBinary(fd, config); +      // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2. 
+      if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config);      }      static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { diff --git a/klm/lm/state.hh b/klm/lm/state.hh index d8e6c132..a6b9accb 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -91,7 +91,7 @@ inline uint64_t hash_value(const Left &left) {  }  struct ChartState { -  bool operator==(const ChartState &other) { +  bool operator==(const ChartState &other) const {      return (right == other.right) && (left == other.left);    } diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 6a5a0196..17f064b2 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -6,6 +6,7 @@  #include "util/string_piece.hh"  #include <string> +#include <string.h>  namespace lm {  namespace base { @@ -119,7 +120,9 @@ class Model {      size_t StateSize() const { return state_size_; }      const void *BeginSentenceMemory() const { return begin_sentence_memory_; } +    void BeginSentenceWrite(void *to) const { memcpy(to, begin_sentence_memory_, StateSize()); }      const void *NullContextMemory() const { return null_context_memory_; } +    void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }      // Requires in_state != out_state      virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 3902f117..226ae438 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -25,7 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len);  inline uint64_t HashForVocab(const StringPiece &str) {    return HashForVocab(str.data(), str.length());  } -class ProbingVocabularyHeader; +struct ProbingVocabularyHeader;  } // namespace detail  class WriteWordsWrapper : public EnumerateVocab { diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index 03554276..b8c8a050 100644 --- a/klm/search/Makefile.am +++ 
b/klm/search/Makefile.am @@ -1,23 +1,10 @@  noinst_LIBRARIES = libksearch.a  libksearch_a_SOURCES = \ -  applied.hh \ -  config.hh \ -  context.hh \ -  dedupe.hh \ -  edge.hh \ -  edge_generator.hh \ -  header.hh \ -  nbest.hh \ -  rule.hh \ -  types.hh \ -  vertex.hh \ -  vertex_generator.hh \    edge_generator.cc \  	nbest.cc \    rule.cc \ -  vertex.cc \ -  vertex_generator.cc +  vertex.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/context.hh b/klm/search/context.hh index 08f21bbf..c3c8e53b 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -12,16 +12,6 @@ class ContextBase {    public:      explicit ContextBase(const Config &config) : config_(config) {} -    VertexNode *NewVertexNode() { -      VertexNode *ret = vertex_node_pool_.construct(); -      assert(ret); -      return ret; -    } - -    void DeleteVertexNode(VertexNode *node) { -      vertex_node_pool_.destroy(node); -    } -      unsigned int PopLimit() const { return config_.PopLimit(); }      Score LMWeight() const { return config_.LMWeight(); } @@ -29,8 +19,6 @@ class ContextBase {      const Config &GetConfig() const { return config_; }    private: -    boost::object_pool<VertexNode> vertex_node_pool_; -      Config config_;  }; diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index eacf5de5..dd9d61e4 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -54,20 +54,20 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {    Arity victim = 0;    Arity victim_completed;    Arity incomplete; +  unsigned char lowest_niceness = 255;    // Select victim or return if complete.       
{      Arity completed = 0; -    unsigned char lowest_length = 255;      for (Arity i = 0; i != arity; ++i) {        if (top_nt[i].Complete()) {          ++completed; -      } else if (top_nt[i].Length() < lowest_length) { -        lowest_length = top_nt[i].Length(); +      } else if (top_nt[i].Niceness() < lowest_niceness) { +        lowest_niceness = top_nt[i].Niceness();          victim = i;          victim_completed = completed;        }      } -    if (lowest_length == 255) { +    if (lowest_niceness == 255) {        return top;      }      incomplete = arity - completed; @@ -92,10 +92,14 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {      generate_.push(alternate);    } +#ifndef NDEBUG   +  Score before = top.GetScore(); +#endif    // top is now the continuation.    FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);    // TODO: dedupe?      generate_.push(top); +  assert(lowest_niceness != 254 || top.GetScore() == before);    // Invalid indicates no new hypothesis generated.      return PartialEdge(); diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 45842982..bf40810e 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -2,6 +2,8 @@  #include "search/context.hh" +#include <boost/unordered_map.hpp> +  #include <algorithm>  #include <functional> @@ -11,45 +13,193 @@ namespace search {  namespace { -struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { -  bool operator()(const VertexNode *first, const VertexNode *second) const { -    return first->Bound() > second->Bound(); +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); + +class DivideLeft { +  public: +    explicit DivideLeft(unsigned char index) +      : index_(index) {} + +    uint64_t operator()(const lm::ngram::ChartState &state) const { +      return (index_ < state.left.length) ?  
+        state.left.pointers[index_] : +        (kCompleteAdd - state.left.full); +    } + +  private: +    unsigned char index_; +}; + +class DivideRight { +  public: +    explicit DivideRight(unsigned char index) +      : index_(index) {} + +    uint64_t operator()(const lm::ngram::ChartState &state) const { +      return (index_ < state.right.length) ? +        static_cast<uint64_t>(state.right.words[index_]) : +        (kCompleteAdd - state.left.full); +    } + +  private: +    unsigned char index_; +}; + +template <class Divider> void Split(const Divider &divider, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) { +  // Map from divider to index in extend. +  typedef boost::unordered_map<uint64_t, std::size_t> Lookup; +  Lookup lookup; +  for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) { +    uint64_t key = divider(i->state); +    std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size()))); +    if (res.second) { +      extend.resize(extend.size() + 1); +      extend.back().AppendHypothesis(*i); +    } else { +      extend[res.first->second].AppendHypothesis(*i); +    }    } +  //assert((extend.size() != 1) || (hypos.size() == 1)); +} + +lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) { +  return right.words[index]; +} + +uint64_t Identify(const lm::ngram::Left &left, unsigned char index) { +  return left.pointers[index]; +} + +template <class Side> class DetermineSame { +  public: +    DetermineSame(const Side &side, unsigned char guaranteed)  +      : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {} + +    void Consider(const Side &other) { +      if (shared_ != other.length) { +        complete_ = false; +        if (shared_ > other.length) +          shared_ = other.length; +      } +      for (unsigned char i = guaranteed_; i < shared_; ++i) { +        if (Identify(side_, i) != Identify(other, i)) { +          shared_ 
= i; +          complete_ = false; +          return; +        } +      } +    } + +    unsigned char Shared() const { return shared_; } + +    bool Complete() const { return complete_; } + +  private: +    const Side &side_; +    unsigned char guaranteed_, shared_; +    bool complete_;  }; +// Custom enum to save memory: valid values of policy_. +// Alternate and there is still alternation to do. +const unsigned char kPolicyAlternate = 0; +// Branch based on left state only, because right ran out or this is a left tree. +const unsigned char kPolicyOneLeft = 1; +// Branch based on right state only. +const unsigned char kPolicyOneRight = 2; +// Reveal everything in the next branch.  Used to terminate the left/right policies. +//    static const unsigned char kPolicyEverything = 3; + +} // namespace + +namespace { +struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> { +  bool operator()(const HypoState &first, const HypoState &second) const { +    return first.score > second.score; +  } +};  } // namespace -void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) { -  if (Complete()) { -    assert(end_); -    assert(extend_.empty()); -    return; +void VertexNode::FinishRoot() { +  std::sort(hypos_.begin(), hypos_.end(), GreaterByScore()); +  extend_.clear(); +  // HACK: extend to one hypo so that root can be blank. 
+  state_.left.full = false; +  state_.left.length = 0; +  state_.right.length = 0; +  right_full_ = false; +  niceness_ = 0; +  policy_ = kPolicyAlternate; +  if (hypos_.size() == 1) { +    extend_.resize(1); +    extend_.front().AppendHypothesis(hypos_.front()); +    extend_.front().FinishedAppending(0, 0); +  } +  if (hypos_.empty()) { +    bound_ = -INFINITY; +  } else { +    bound_ = hypos_.front().score;    } -  if (extend_.size() == 1) { -    parent_ptr = extend_[0]; -    extend_[0]->RecursiveSortAndSet(context, parent_ptr); -    context.DeleteVertexNode(this); -    return; +} + +void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) { +  assert(!hypos_.empty()); +  assert(extend_.empty()); +  bound_ = hypos_.front().score; +  state_ = hypos_.front().state; +  bool all_full = state_.left.full; +  bool all_non_full = !state_.left.full; +  DetermineSame<lm::ngram::Left> left(state_.left, common_left); +  DetermineSame<lm::ngram::Right> right(state_.right, common_right); +  for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) { +    all_full &= i->state.left.full; +    all_non_full &= !i->state.left.full; +    left.Consider(i->state.left); +    right.Consider(i->state.right);    } -  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { -    (*i)->RecursiveSortAndSet(context, *i); +  state_.left.full = all_full && left.Complete(); +  right_full_ = all_full && right.Complete(); +  state_.left.length = left.Shared(); +  state_.right.length = right.Shared(); + +  if (!all_full && !all_non_full) { +    policy_ = kPolicyAlternate; +  } else if (left.Complete()) { +    policy_ = kPolicyOneRight; +  } else if (right.Complete()) { +    policy_ = kPolicyOneLeft; +  } else { +    policy_ = kPolicyAlternate;    } -  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); -  bound_ = extend_.front()->Bound(); +  niceness_ = state_.left.length + 
state_.right.length;  } -void VertexNode::SortAndSet(ContextBase &context) { -  // This is the root.  The root might be empty.   -  if (extend_.empty()) { -    bound_ = -INFINITY; -    return; +void VertexNode::BuildExtend() { +  // Already built. +  if (!extend_.empty()) return; +  // Nothing to build since this is a leaf. +  if (hypos_.size() <= 1) return; +  bool left_branch = true; +  switch (policy_) { +    case kPolicyAlternate: +      left_branch = (state_.left.length <= state_.right.length); +      break; +    case kPolicyOneLeft: +      left_branch = true; +      break; +    case kPolicyOneRight: +      left_branch = false; +      break; +  } +  if (left_branch) { +    Split(DivideLeft(state_.left.length), hypos_, extend_); +  } else { +    Split(DivideRight(state_.right.length), hypos_, extend_);    } -  // The root cannot be replaced.  There's always one transition.   -  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { -    (*i)->RecursiveSortAndSet(context, *i); +  for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) { +    // TODO: provide more here for branching? 
+    i->FinishedAppending(state_.left.length, state_.right.length);    } -  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); -  bound_ = extend_.front()->Bound();  }  } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index ca9a4fcd..81c3cfed 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,59 +16,74 @@ namespace search {  class ContextBase; +struct HypoState { +  History history; +  lm::ngram::ChartState state; +  Score score; +}; +  class VertexNode {    public: -    VertexNode() : end_() {} - -    void InitRoot() { -      extend_.clear(); -      state_.left.full = false; -      state_.left.length = 0; -      state_.right.length = 0; -      right_full_ = false; -      end_ = History(); +    VertexNode() {} + +    void InitRoot() { hypos_.clear(); } + +    /* The steps of building a VertexNode: +     * 1. Default construct. +     * 2. AppendHypothesis at least once, possibly multiple times. +     * 3. FinishAppending with the number of words on left and right guaranteed +     * to be common. +     * 4. If !Complete(), call BuildExtend to construct the extensions +     */ +    // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending. +    void AppendHypothesis(const NBestComplete &best) { +      assert(hypos_.empty() || !(hypos_.front().state == *best.state)); +      HypoState hypo; +      hypo.history = best.history; +      hypo.state = *best.state; +      hypo.score = best.score; +      hypos_.push_back(hypo); +    } +    void AppendHypothesis(const HypoState &hypo) { +      hypos_.push_back(hypo);      } -    lm::ngram::ChartState &MutableState() { return state_; } -    bool &MutableRightFull() { return right_full_; } +    // Sort hypotheses for the root. 
+    void FinishRoot(); -    void AddExtend(VertexNode *next) { -      extend_.push_back(next); -    } +    void FinishedAppending(const unsigned char common_left, const unsigned char common_right); -    void SetEnd(History end, Score score) { -      assert(!end_); -      end_ = end; -      bound_ = score; -    } -     -    void SortAndSet(ContextBase &context); +    void BuildExtend();      // Should only happen to a root node when the entire vertex is empty.         bool Empty() const { -      return !end_ && extend_.empty(); +      return hypos_.empty() && extend_.empty();      }      bool Complete() const { -      return end_; +      // HACK: prevent root from being complete.  TODO: allow root to be complete. +      return hypos_.size() == 1 && extend_.empty();      }      const lm::ngram::ChartState &State() const { return state_; }      bool RightFull() const { return right_full_; } +    // Priority relative to other non-terminals.  0 is highest. +    unsigned char Niceness() const { return niceness_; } +      Score Bound() const {        return bound_;      } -    unsigned char Length() const { -      return state_.left.length + state_.right.length; -    } -      // Will be invalid unless this is a leaf.    -    History End() const { return end_; } +    History End() const { +      assert(hypos_.size() == 1); +      return hypos_.front().history; +    } -    const VertexNode &operator[](size_t index) const { -      return *extend_[index]; +    VertexNode &operator[](size_t index) { +      assert(!extend_.empty()); +      return extend_[index];      }      size_t Size() const { @@ -76,22 +91,26 @@ class VertexNode {      }    private: -    void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); +    // Hypotheses to be split. 
+    std::vector<HypoState> hypos_; -    std::vector<VertexNode*> extend_; +    std::vector<VertexNode> extend_;      lm::ngram::ChartState state_;      bool right_full_; +    unsigned char niceness_; + +    unsigned char policy_; +      Score bound_; -    History end_;  };  class PartialVertex {    public:      PartialVertex() {} -    explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} +    explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {}      bool Empty() const { return back_->Empty(); } @@ -100,17 +119,14 @@ class PartialVertex {      const lm::ngram::ChartState &State() const { return back_->State(); }      bool RightFull() const { return back_->RightFull(); } -    Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } - -    unsigned char Length() const { return back_->Length(); } +    Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); } -    bool HasAlternative() const { -      return index_ + 1 < back_->Size(); -    } +    unsigned char Niceness() const { return back_->Niceness(); }      // Split into continuation and alternative, rendering this the continuation.      
bool Split(PartialVertex &alternative) {        assert(!Complete()); +      back_->BuildExtend();        bool ret;        if (index_ + 1 < back_->Size()) {          alternative.index_ = index_ + 1; @@ -129,7 +145,7 @@ class PartialVertex {      }    private: -    const VertexNode *back_; +    VertexNode *back_;      unsigned int index_;  }; @@ -139,10 +155,21 @@ class Vertex {    public:      Vertex() {} -    PartialVertex RootPartial() const { return PartialVertex(root_); } +    //PartialVertex RootFirst() const { return PartialVertex(right_); } +    PartialVertex RootAlternate() { return PartialVertex(root_); } +    //PartialVertex RootLast() const { return PartialVertex(left_); } + +    bool Empty() const { +      return root_.Empty(); +    } + +    Score Bound() const { +      return root_.Bound(); +    } -    History BestChild() const { -      PartialVertex top(RootPartial()); +    History BestChild() { +      // left_ and right_ are not set at the root. +      PartialVertex top(RootAlternate());        if (top.Empty()) {          return History();        } else { @@ -158,6 +185,12 @@ class Vertex {      template <class Output> friend class VertexGenerator;      template <class Output> friend class RootVertexGenerator;      VertexNode root_; + +    // These will not be set for the root vertex. +    // Branches only on left state. +    //VertexNode left_; +    // Branches only on right state. 
+    //VertexNode right_;  };  } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 646b8189..91000012 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -4,10 +4,8 @@  #include "search/edge.hh"  #include "search/types.hh"  #include "search/vertex.hh" -#include "util/exception.hh"  #include <boost/unordered_map.hpp> -#include <boost/version.hpp>  namespace lm {  namespace ngram { @@ -19,45 +17,25 @@ namespace search {  class ContextBase; -#if BOOST_VERSION > 104200 -// Parallel structure to VertexNode.   -struct Trie { -  Trie() : under(NULL) {} - -  VertexNode *under; -  boost::unordered_map<uint64_t, Trie> extend; -}; - -void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); - -#endif // BOOST_VERSION -  // Output makes the single-best or n-best list.     template <class Output> class VertexGenerator {    public: -    VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { -      gen.root_.InitRoot(); -    } +    VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {}      void NewHypothesis(PartialEdge partial) {        nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);      }      void FinishedSearch() { -#if BOOST_VERSION > 104200 -      Trie root; -      root.under = &gen_.root_; +      gen_.root_.InitRoot();        for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { -        AddHypothesis(context_, root, nbest_.Complete(i->second)); +        gen_.root_.AppendHypothesis(nbest_.Complete(i->second));        }        existing_.clear(); -      root.under->SortAndSet(context_); -#else -      UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); -#endif +      gen_.root_.FinishRoot();      } -    const Vertex &Generating() const { return gen_; } +    Vertex 
&Generating() { return gen_; }    private:      ContextBase &context_; @@ -84,8 +62,8 @@ template <class Output> class RootVertexGenerator {      void FinishedSearch() {        gen_.root_.InitRoot(); -      NBestComplete completed(out_.Complete(combine_)); -      gen_.root_.SetEnd(completed.history, completed.score); +      gen_.root_.AppendHypothesis(out_.Complete(combine_)); +      gen_.root_.FinishRoot();      }    private: diff --git a/klm/util/double-conversion/utils.h b/klm/util/double-conversion/utils.h index 767094b8..9ccb3b65 100644 --- a/klm/util/double-conversion/utils.h +++ b/klm/util/double-conversion/utils.h @@ -218,7 +218,8 @@ class StringBuilder {    // 0-characters; use the Finalize() method to terminate the string    // instead.    void AddCharacter(char c) { -    ASSERT(c != '\0'); +    // I just extract raw data not a cstr so null is fine. +    //ASSERT(c != '\0');      ASSERT(!is_finalized() && position_ < buffer_.length());      buffer_[position_++] = c;    } @@ -233,7 +234,8 @@ class StringBuilder {    // builder. The input string must have enough characters.    void AddSubstring(const char* s, int n) {      ASSERT(!is_finalized() && position_ + n < buffer_.length()); -    ASSERT(static_cast<size_t>(n) <= strlen(s)); +    // I just extract raw data not a cstr so null is fine. +    //ASSERT(static_cast<size_t>(n) <= strlen(s));      memmove(&buffer_[position_], s, n * kCharSize);      position_ += n;    } @@ -253,7 +255,8 @@ class StringBuilder {      buffer_[position_] = '\0';      // Make sure nobody managed to add a 0-character to the      // buffer while building the string. -    ASSERT(strlen(buffer_.start()) == static_cast<size_t>(position_)); +    // I just extract raw data not a cstr so null is fine. 
+    //ASSERT(strlen(buffer_.start()) == static_cast<size_t>(position_));      position_ = -1;      ASSERT(is_finalized());      return buffer_.start(); @@ -296,7 +299,11 @@ template <class Dest, class Source>  inline Dest BitCast(const Source& source) {    // Compile time assertion: sizeof(Dest) == sizeof(Source)    // A compile error here means your Dest and Source have different sizes. -  typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1]; +  typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1] +#if __GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ >= 8 +      __attribute__((unused)) +#endif +      ;    Dest dest;    memmove(&dest, &source, sizeof(dest)); diff --git a/klm/util/file.cc b/klm/util/file.cc index c7d8e23b..bef04cb1 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -116,7 +116,7 @@ std::size_t GuardLarge(std::size_t size) {    // The following operating systems have broken read/write/pread/pwrite that    // only supports up to 2^31.  #if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID) -  return std::min(static_cast<std::size_t>(INT_MAX), size); +  return std::min(static_cast<std::size_t>(static_cast<unsigned>(-1)), size);  #else    return size;  #endif @@ -209,7 +209,7 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) {  #endif      errno = 0;      do { -      ret =  +      ret =  #if defined(_WIN32) || defined(_WIN64)          _write  #else @@ -229,7 +229,7 @@ void WriteOrThrow(FILE *to, const void *data, std::size_t size) {  }  void FSyncOrThrow(int fd) { -// Apparently windows doesn't have fsync?   +// Apparently windows doesn't have fsync?  #if !defined(_WIN32) && !defined(_WIN64)    UTIL_THROW_IF_ARG(-1 == fsync(fd), FDException, (fd), "while syncing");  #endif @@ -248,7 +248,7 @@ template <> struct CheckOffT<8> {  typedef CheckOffT<sizeof(off_t)>::True IgnoredType;  #endif -// Can't we all just get along?   +// Can't we all just get along?  
void InternalSeek(int fd, int64_t off, int whence) {    if (  #if defined(_WIN32) || defined(_WIN64) @@ -457,9 +457,9 @@ bool TryName(int fd, std::string &out) {    std::ostringstream convert;    convert << fd;    name += convert.str(); -   +    struct stat sb; -  if (-1 == lstat(name.c_str(), &sb))  +  if (-1 == lstat(name.c_str(), &sb))      return false;    out.resize(sb.st_size + 1);    ssize_t ret = readlink(name.c_str(), &out[0], sb.st_size + 1); @@ -471,7 +471,7 @@ bool TryName(int fd, std::string &out) {    }    out.resize(ret);    // Don't use the non-file names. -  if (!out.empty() && out[0] != '/')  +  if (!out.empty() && out[0] != '/')      return false;    return true;  #endif diff --git a/klm/util/pool.cc b/klm/util/pool.cc index 429ba158..db72a8ec 100644 --- a/klm/util/pool.cc +++ b/klm/util/pool.cc @@ -25,7 +25,9 @@ void Pool::FreeAll() {  }  void *Pool::More(std::size_t size) { -  std::size_t amount = std::max(static_cast<size_t>(32) << free_list_.size(), size); +  // Double until we hit 2^21 (2 MB).  Then grow in 2 MB blocks.  +  std::size_t desired_size = static_cast<size_t>(32) << std::min(static_cast<std::size_t>(16), free_list_.size()); +  std::size_t amount = std::max(desired_size, size);    uint8_t *ret = static_cast<uint8_t*>(MallocOrThrow(amount));    free_list_.push_back(ret);    current_ = ret + size; diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 57866ff9..51a2944d 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -109,9 +109,20 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry          if (equal_(got, key)) { out = i; return true; }          if (equal_(got, invalid_)) return false;          if (++i == end_) i = begin_; -      }    +      } +    } + +    // Like UnsafeMutableFind, but the key must be there. 
+    template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) { +       for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) { +        Key got(i->GetKey()); +        if (equal_(got, key)) { return i; } +        assert(!equal_(got, invalid_)); +        if (++i == end_) i = begin_; +      }      } +      template <class Key> bool Find(const Key key, ConstIterator &out) const {  #ifdef DEBUG        assert(initialized_); @@ -124,6 +135,16 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry        }          } +    // Like Find but we're sure it must be there. +    template <class Key> ConstIterator MustFind(const Key key) const { +      for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) { +        Key got(i->GetKey()); +        if (equal_(got, key)) { return i; } +        assert(!equal_(got, invalid_)); +        if (++i == end_) i = begin_; +      } +    } +      void Clear() {        Entry invalid;        invalid.SetKey(invalid_); diff --git a/klm/util/proxy_iterator.hh b/klm/util/proxy_iterator.hh index 121a45fa..0ee1716f 100644 --- a/klm/util/proxy_iterator.hh +++ b/klm/util/proxy_iterator.hh @@ -6,11 +6,11 @@  /* This is a RandomAccessIterator that uses a proxy to access the underlying   * data.  Useful for packing data at bit offsets but still using STL - * algorithms.   + * algorithms.   *   * Normally I would use boost::iterator_facade but some people are too lazy to   * install boost and still want to use my language model.  It's amazing how - * many operators an iterator has.  + * many operators an iterator has.   *   * The Proxy needs to provide:   *   class InnerIterator; @@ -22,15 +22,15 @@   *   operator<(InnerIterator)   *   operator+=(std::ptrdiff_t)   *   operator-(InnerIterator) - * and of course whatever Proxy needs to dereference it.   + * and of course whatever Proxy needs to dereference it.   * - * It's also a good idea to specialize std::swap for Proxy.   
+ * It's also a good idea to specialize std::swap for Proxy.   */  namespace util {  template <class Proxy> class ProxyIterator {    private: -    // Self.   +    // Self.      typedef ProxyIterator<Proxy> S;      typedef typename Proxy::InnerIterator InnerIterator; @@ -38,16 +38,21 @@ template <class Proxy> class ProxyIterator {      typedef std::random_access_iterator_tag iterator_category;      typedef typename Proxy::value_type value_type;      typedef std::ptrdiff_t difference_type; -    typedef Proxy reference; +    typedef Proxy & reference;      typedef Proxy * pointer;      ProxyIterator() {} -    // For cast from non const to const.   +    // For cast from non const to const.      template <class AlternateProxy> ProxyIterator(const ProxyIterator<AlternateProxy> &in) : p_(*in) {}      explicit ProxyIterator(const Proxy &p) : p_(p) {} -    // p_'s operator= does value copying, but here we want iterator copying.   +    // p_'s swap does value swapping, but here we want iterator swapping +    friend inline void swap(ProxyIterator<Proxy> &first, ProxyIterator<Proxy> &second) { +      swap(first.I(), second.I()); +    } + +    // p_'s operator= does value copying, but here we want iterator copying.      
S &operator=(const S &other) {        I() = other.I();        return *this; @@ -72,8 +77,8 @@ template <class Proxy> class ProxyIterator {      std::ptrdiff_t operator-(const S &other) const { return I() - other.I(); } -    Proxy operator*() { return p_; } -    const Proxy operator*() const { return p_; } +    Proxy &operator*() { return p_; } +    const Proxy &operator*() const { return p_; }      Proxy *operator->() { return &p_; }      const Proxy *operator->() const { return &p_; }      Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); } diff --git a/klm/util/sized_iterator.hh b/klm/util/sized_iterator.hh index cf998953..dce8f229 100644 --- a/klm/util/sized_iterator.hh +++ b/klm/util/sized_iterator.hh @@ -36,6 +36,11 @@ class SizedInnerIterator {      void *Data() { return ptr_; }      std::size_t EntrySize() const { return size_; } +    friend inline void swap(SizedInnerIterator &first, SizedInnerIterator &second) { +      std::swap(first.ptr_, second.ptr_); +      std::swap(first.size_, second.size_); +    } +    private:      uint8_t *ptr_;      std::size_t size_; @@ -64,9 +69,19 @@ class SizedProxy {      const void *Data() const { return inner_.Data(); }      void *Data() { return inner_.Data(); } +  /** +     // TODO: this (deep) swap was recently added. why? if any std heap sort etc +     // algs are using swap, that's going to be worse performance than using +     // =. i'm not sure why we *want* a deep swap. if C++11 compilers are +     // choosing between move constructor and swap, then we'd better implement a +     // (deep) move constructor. it may also be that this is moot since i made +     // ProxyIterator a reference and added a shallow ProxyIterator swap? (I +     // need Ken or someone competent to judge whether that's correct also. 
- +     // let me know at graehl@gmail.com +  */      friend void swap(SizedProxy &first, SizedProxy &second) {        std::swap_ranges( -          static_cast<char*>(first.inner_.Data()),  +          static_cast<char*>(first.inner_.Data()),            static_cast<char*>(first.inner_.Data()) + first.inner_.EntrySize(),            static_cast<char*>(second.inner_.Data()));      } @@ -87,7 +102,7 @@ typedef ProxyIterator<SizedProxy> SizedIterator;  inline SizedIterator SizedIt(void *ptr, std::size_t size) { return SizedIterator(SizedProxy(ptr, size)); } -// Useful wrapper for a comparison function i.e. sort.   +// Useful wrapper for a comparison function i.e. sort.  template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public std::binary_function<const Proxy &, const Proxy &, bool> {    public:      explicit SizedCompare(const Delegate &delegate = Delegate()) : delegate_(delegate) {} @@ -106,7 +121,7 @@ template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public      }      const Delegate &GetDelegate() const { return delegate_; } -     +    private:      const Delegate delegate_;  }; diff --git a/klm/util/stream/chain.hh b/klm/util/stream/chain.hh index 154b9b33..0cc83a85 100644 --- a/klm/util/stream/chain.hh +++ b/klm/util/stream/chain.hh @@ -122,7 +122,7 @@ class Chain {        threads_.push_back(new Thread(Complete(), kRecycle));      } -    Chain &operator>>(const Recycler &recycle) { +    Chain &operator>>(const Recycler &) {        CompleteLoop();        return *this;      } diff --git a/klm/util/usage.cc b/klm/util/usage.cc index ad4dc7b4..8db375e1 100644 --- a/klm/util/usage.cc +++ b/klm/util/usage.cc @@ -5,51 +5,95 @@  #include <fstream>  #include <ostream>  #include <sstream> +#include <set> +#include <string>  #include <string.h>  #include <ctype.h>  #if !defined(_WIN32) && !defined(_WIN64)  #include <sys/resource.h>  #include <sys/time.h> +#include <time.h>  #include <unistd.h>  #endif  namespace util { -namespace { 
 #if !defined(_WIN32) && !defined(_WIN64) +namespace { + +// On Mac OS X, clock_gettime is not implemented. +// CLOCK_MONOTONIC is not defined either. +#ifdef __MACH__ +#define CLOCK_MONOTONIC 0 + +int clock_gettime(int clk_id, struct timespec *tp) { +  struct timeval tv; +  gettimeofday(&tv, NULL); +  tp->tv_sec = tv.tv_sec; +  tp->tv_nsec = tv.tv_usec * 1000; +  return 0; +} +#endif // __MACH__ +  float FloatSec(const struct timeval &tv) {    return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000.0);  } -#endif +float FloatSec(const struct timespec &tv) { +  return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_nsec) / 1000000000.0); +}  const char *SkipSpaces(const char *at) { -  for (; *at == ' '; ++at) {} +  for (; *at == ' ' || *at == '\t'; ++at) {}    return at;  } + +class RecordStart { +  public: +    RecordStart() { +      clock_gettime(CLOCK_MONOTONIC, &started_); +    } + +    const struct timespec &Started() const { +      return started_; +    } + +  private: +    struct timespec started_; +}; + +const RecordStart kRecordStart;  } // namespace +#endif  void PrintUsage(std::ostream &out) {  #if !defined(_WIN32) && !defined(_WIN64) +  // Linux doesn't set memory usage in getrusage :-( +  std::set<std::string> headers; +  headers.insert("VmPeak:"); +  headers.insert("VmRSS:"); +  headers.insert("Name:"); + +  std::ifstream status("/proc/self/status", std::ios::in); +  std::string header, value; +  while ((status >> header) && getline(status, value)) { +    if (headers.find(header) != headers.end()) { +      out << header << SkipSpaces(value.c_str()) << '\t'; +    } +  } +    struct rusage usage; -  if (getrusage(RUSAGE_SELF, &usage)) { +  if (getrusage(RUSAGE_CHILDREN, &usage)) {      perror("getrusage");      return;    } -  out << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n'; -  out << "CPU\t" << (FloatSec(usage.ru_utime) + FloatSec(usage.ru_stime)) << '\n'; -  // Linux 
doesn't set memory usage :-(.   -  std::ifstream status("/proc/self/status", std::ios::in); -  std::string line; -  while (getline(status, line)) { -    if (!strncmp(line.c_str(), "VmRSS:\t", 7)) { -      out << "RSSCur\t" << SkipSpaces(line.c_str() + 7) << '\n'; -      break; -    } else if (!strncmp(line.c_str(), "VmPeak:\t", 8)) { -      out << "VmPeak\t" << SkipSpaces(line.c_str() + 8) << '\n'; -    } -  } -  out << "RSSMax\t" << usage.ru_maxrss << " kB" << '\n'; +  out << "RSSMax:" << usage.ru_maxrss << " kB" << '\t'; +  out << "user:" << FloatSec(usage.ru_utime) << "\tsys:" << FloatSec(usage.ru_stime) << '\t'; +  out << "CPU:" << (FloatSec(usage.ru_utime) + FloatSec(usage.ru_stime)); + +  struct timespec current; +  clock_gettime(CLOCK_MONOTONIC, &current); +  out << "\treal:" << (FloatSec(current) - FloatSec(kRecordStart.Started())) << '\n';  #endif  } | 
