| | | |
|---|---|---|
| author | Patrick Simianer <p@simianer.de> | 2011-09-09 15:33:35 +0200 |
| committer | Patrick Simianer <p@simianer.de> | 2011-09-23 19:13:58 +0200 |
| commit | edb0cc0cbae1e75e4aeedb6360eab325effe6573 (patch) | |
| tree | a2fed4614b88f177f91e88fef3b269fa75e80188 | |
| parent | 2e6ef7cbec77b22ce3d64416a5ada3a6c081f9e2 (diff) | |
partial merge, ruleid feature
46 files changed, 1490 insertions, 440 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 244da2de..e5f7505f 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -61,10 +61,12 @@ libcdec_a_SOURCES = \
   phrasetable_fst.cc \
   trule.cc \
   ff.cc \
+  ff_rules.cc \
   ff_wordset.cc \
   ff_charset.cc \
   ff_lm.cc \
   ff_klm.cc \
+  ff_ngrams.cc \
   ff_spans.cc \
   ff_ruleshape.cc \
   ff_wordalign.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 31f88a4f..588842f1 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -4,10 +4,12 @@
 #include "ff_spans.h"
 #include "ff_lm.h"
 #include "ff_klm.h"
+#include "ff_ngrams.h"
 #include "ff_csplit.h"
 #include "ff_wordalign.h"
 #include "ff_tagger.h"
 #include "ff_factory.h"
+#include "ff_rules.h"
 #include "ff_ruleshape.h"
 #include "ff_bleu.h"
 #include "ff_lm_fsa.h"
@@ -51,12 +53,11 @@ void register_feature_functions() {
   ff_registry.Register("RandLM", new FFFactory<LanguageModelRandLM>);
 #endif
   ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>());
+  ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>());
+  ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>());
   ff_registry.Register("RuleNgramFeatures", new FFFactory<RuleNgramFeatures>());
   ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory<CMR2008ReorderingFeatures>());
-  ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >());
-  ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >());
-  ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory<KLanguageModel<lm::ngram::QuantTrieModel> >());
-  ff_registry.Register("KLanguageModel_Probing", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >());
+  ff_registry.Register("KLanguageModel", new KLanguageModelFactory());
   ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>);
   ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
   ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 9b7fe2d3..24dcb9c3 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -9,6 +9,7 @@
 #include "stringlib.h"
 #include "hg.h"
 #include "tdict.h"
+#include "lm/model.hh"
 #include "lm/enumerate_vocab.hh"
 
 using namespace std;
@@ -434,8 +435,37 @@ void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state,
     features->set_value(oov_fid_, oovs);
 }
 
-// instantiate templates
-template class KLanguageModel<lm::ngram::ProbingModel>;
-template class KLanguageModel<lm::ngram::TrieModel>;
-template class KLanguageModel<lm::ngram::QuantTrieModel>;
+template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string &param) {
+  KLanguageModel<Model> *ret = new KLanguageModel<Model>(param);
+  ret->Init();
+  return boost::shared_ptr<FeatureFunction>(ret);
+}
+
+boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const {
+  using namespace lm::ngram;
+  std::string filename, ignored_map;
+  bool ignored_markers;
+  std::string ignored_featname;
+  ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname);
+  ModelType m;
+  if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING;
+
+  switch (m) {
+    case HASH_PROBING:
+      return CreateModel<ProbingModel>(param);
+    case TRIE_SORTED:
+      return CreateModel<TrieModel>(param);
+    case ARRAY_TRIE_SORTED:
+      return CreateModel<ArrayTrieModel>(param);
+    case QUANT_TRIE_SORTED:
+      return CreateModel<QuantTrieModel>(param);
+    case QUANT_ARRAY_TRIE_SORTED:
+      return CreateModel<QuantArrayTrieModel>(param);
+    default:
+      UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m);
+  }
+}
+
+std::string KLanguageModelFactory::usage(bool params, bool verbose) const {
+  return KLanguageModel<lm::ngram::Model>::usage(params, verbose);
+}
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
index 5eafe8be..6efe50f6 100644
--- a/decoder/ff_klm.h
+++ b/decoder/ff_klm.h
@@ -4,8 +4,8 @@
 #include <vector>
 #include <string>
 
+#include "ff_factory.h"
 #include "ff.h"
-#include "lm/model.hh"
 
 template <class Model> struct KLanguageModelImpl;
 
@@ -34,4 +34,9 @@ class KLanguageModel : public FeatureFunction {
   KLanguageModelImpl<Model>* pimpl_;
 };
 
+struct KLanguageModelFactory : public FactoryBase<FeatureFunction> {
+  FP Create(std::string param) const;
+  std::string usage(bool params, bool verbose) const;
+};
+
 #endif
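
The new KLanguageModelFactory collapses the four separately registered KLanguageModel variants into one name: Create inspects the file with RecognizeBinary and instantiates the matching template, falling back to the probing model for text ARPA input. A minimal, self-contained sketch of that dispatch pattern, with illustrative stand-in types (not cdec's or KenLM's actual API):

```cpp
#include <fstream>
#include <memory>
#include <stdexcept>
#include <string>

enum ModelTag { HASH_PROBING = 0, TRIE_SORTED = 2 };  // mirrors lm::ngram::ModelType values

struct Scorer {
  virtual ~Scorer() {}
};

template <class Backend> struct TypedScorer : Scorer {
  explicit TypedScorer(const std::string& /*param*/) {}  // a real factory would parse LM args here
};

struct ProbingBackend {};
struct TrieBackend {};

// Helper mirroring CreateModel<Model> above: construct, then hand back a shared_ptr.
template <class Backend>
std::shared_ptr<Scorer> Create(const std::string& param) {
  return std::make_shared<TypedScorer<Backend> >(param);
}

// Binary KenLM files record their model type in the header; text ARPA files do
// not, so the factory defaults to the probing representation.
std::shared_ptr<Scorer> CreateFromFile(const std::string& filename) {
  std::ifstream in(filename.c_str(), std::ios::binary);
  char tag = HASH_PROBING;
  in.get(tag);  // stand-in for RecognizeBinary(filename.c_str(), m)
  switch (tag) {
    case HASH_PROBING: return Create<ProbingBackend>(filename);
    case TRIE_SORTED:  return Create<TrieBackend>(filename);
    default: throw std::runtime_error("unrecognized model type");
  }
}
```

The design choice is that users configure one feature name ("KLanguageModel") and the binary file itself decides which of the five representations gets loaded.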
diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc
new file mode 100644
index 00000000..04dd1906
--- /dev/null
+++ b/decoder/ff_ngrams.cc
@@ -0,0 +1,341 @@
+#include "ff_ngrams.h"
+
+#include <cstring>
+#include <iostream>
+
+#include <boost/scoped_ptr.hpp>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "hg.h"
+#include "tdict.h"
+
+using namespace std;
+
+static const unsigned char HAS_FULL_CONTEXT = 1;
+static const unsigned char HAS_EOS_ON_RIGHT = 2;
+static const unsigned char MASK             = 7;
+
+namespace {
+template <unsigned MAX_ORDER = 5>
+struct State {
+  explicit State() {
+    memset(state, 0, sizeof(state));
+  }
+  explicit State(int order) {
+    memset(state, 0, (order - 1) * sizeof(WordID));
+  }
+  State<MAX_ORDER>(char order, const WordID* mem) {
+    memcpy(state, mem, (order - 1) * sizeof(WordID));
+  }
+  State(const State<MAX_ORDER>& other) {
+    memcpy(state, other.state, sizeof(state));
+  }
+  const State& operator=(const State<MAX_ORDER>& other) {
+    memcpy(state, other.state, sizeof(state));
+  }
+  explicit State(const State<MAX_ORDER>& other, unsigned order, WordID extend) {
+    char om1 = order - 1;
+    assert(om1 > 0);
+    for (char i = 1; i < om1; ++i) state[i - 1] = other.state[i];
+    state[om1 - 1] = extend;
+  }
+  const WordID& operator[](size_t i) const { return state[i]; }
+  WordID& operator[](size_t i) { return state[i]; }
+  WordID state[MAX_ORDER];
+};
+}
+
+namespace {
+  string Escape(const string& x) {
+    string y = x;
+    for (int i = 0; i < y.size(); ++i) {
+      if (y[i] == '=') y[i]='_';
+      if (y[i] == ';') y[i]='_';
+    }
+    return y;
+  }
+}
+
+class NgramDetectorImpl {
+
+  // returns the number of unscored words at the left edge of a span
+  inline int UnscoredSize(const void* state) const {
+    return *(static_cast<const char*>(state) + unscored_size_offset_);
+  }
+
+  inline void SetUnscoredSize(int size, void* state) const {
+    *(static_cast<char*>(state) + unscored_size_offset_) = size;
+  }
+
+  inline State<5> RemnantLMState(const void* cstate) const {
+    return State<5>(order_, static_cast<const WordID*>(cstate));
+  }
+
+  inline const State<5> BeginSentenceState() const {
+    State<5> state(order_);
+    state.state[0] = kSOS_;
+    return state;
+  }
+
+  inline void SetRemnantLMState(const State<5>& lmstate, void* state) const {
+    // if we were clever, we could use the memory pointed to by state to do all
+    // the work, avoiding this copy
+    memcpy(state, lmstate.state, (order_-1) * sizeof(WordID));
+  }
+
+  WordID IthUnscoredWord(int i, const void* state) const {
+    const WordID* const mem = reinterpret_cast<const WordID*>(static_cast<const char*>(state) + unscored_words_offset_);
+    return mem[i];
+  }
+
+  void SetIthUnscoredWord(int i, const WordID index, void *state) const {
+    WordID* mem = reinterpret_cast<WordID*>(static_cast<char*>(state) + unscored_words_offset_);
+    mem[i] = index;
+  }
+
+  inline bool GetFlag(const void *state, unsigned char flag) const {
+    return (*(static_cast<const char*>(state) + is_complete_offset_) & flag);
+  }
+
+  inline void SetFlag(bool on, unsigned char flag, void *state) const {
+    if (on) {
+      *(static_cast<char*>(state) + is_complete_offset_) |= flag;
+    } else {
+      *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag);
+    }
+  }
+
+  inline bool HasFullContext(const void *state) const {
+    return GetFlag(state, HAS_FULL_CONTEXT);
+  }
+
+  inline void SetHasFullContext(bool flag, void *state) const {
+    SetFlag(flag, HAS_FULL_CONTEXT, state);
+  }
+
+  void FireFeatures(const State<5>& state, WordID cur, SparseVector<double>* feats) {
+    FidTree* ft = &fidroot_;
+    int n = 0;
+    WordID buf[10];
+    int ci = order_ - 1;
+    WordID curword = cur;
+    while(curword) {
+      buf[n] = curword;
+      int& fid = ft->fids[curword];
+      ++n;
+      if (!fid) {
+        const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.)
+        ostringstream os;
+        os << code[n] << ':';
+        for (int i = n-1; i >= 0; --i) {
+          os << (i != n-1 ? "_" : "");
+          const string& tok = TD::Convert(buf[i]);
+          if (tok.find('=') == string::npos)
+            os << tok;
+          else
+            os << Escape(tok);
+        }
+        fid = FD::Convert(os.str());
+      }
+      feats->set_value(fid, 1);
+      ft = &ft->levels[curword];
+      --ci;
+      if (ci < 0) break;
+      curword = state[ci];
+    }
+  }
+
+ public:
+  void LookupWords(const TRule& rule, const vector<const void*>& ant_states, SparseVector<double>* feats, SparseVector<double>* est_feats, void* remnant) {
+    double sum = 0.0;
+    double est_sum = 0.0;
+    int num_scored = 0;
+    int num_estimated = 0;
+    bool saw_eos = false;
+    bool has_some_history = false;
+    State<5> state;
+    const vector<WordID>& e = rule.e();
+    bool context_complete = false;
+    for (int j = 0; j < e.size(); ++j) {
+      if (e[j] < 1) {   // handle non-terminal substitution
+        const void* astate = (ant_states[-e[j]]);
+        int unscored_ant_len = UnscoredSize(astate);
+        for (int k = 0; k < unscored_ant_len; ++k) {
+          const WordID cur_word = IthUnscoredWord(k, astate);
+          const bool is_oov = (cur_word == 0);
+          SparseVector<double> p;
+          if (cur_word == kSOS_) {
+            state = BeginSentenceState();
+            if (has_some_history) {  // this is immediately fully scored, and bad
+              p.set_value(FD::Convert("Malformed"), 1.0);
+              context_complete = true;
+            } else {  // this might be a real <s>
+              num_scored = max(0, order_ - 2);
+            }
+          } else {
+            FireFeatures(state, cur_word, &p);
+            const State<5> scopy = State<5>(state, order_, cur_word);
+            state = scopy;
+            if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); }
+            saw_eos = (cur_word == kEOS_);
+          }
+          has_some_history = true;
+          ++num_scored;
+          if (!context_complete) {
+            if (num_scored >= order_) context_complete = true;
+          }
+          if (context_complete) {
+            (*feats) += p;
+          } else {
+            if (remnant)
+              SetIthUnscoredWord(num_estimated, cur_word, remnant);
+            ++num_estimated;
+            (*est_feats) += p;
+          }
+        }
+        saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT);
+        if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007
+          state = RemnantLMState(astate);
+          context_complete = true;
+        }
+      } else {   // handle terminal
+        const WordID cur_word = e[j];
+        SparseVector<double> p;
+        if (cur_word == kSOS_) {
+          state = BeginSentenceState();
+          if (has_some_history) {  // this is immediately fully scored, and bad
+            p.set_value(FD::Convert("Malformed"), -100);
+            context_complete = true;
+          } else {  // this might be a real <s>
+            num_scored = max(0, order_ - 2);
+          }
+        } else {
+          FireFeatures(state, cur_word, &p);
+          const State<5> scopy = State<5>(state, order_, cur_word);
+          state = scopy;
+          if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); }
+          saw_eos = (cur_word == kEOS_);
+        }
+        has_some_history = true;
+        ++num_scored;
+        if (!context_complete) {
+          if (num_scored >= order_) context_complete = true;
+        }
+        if (context_complete) {
+          (*feats) += p;
+        } else {
+          if (remnant)
+            SetIthUnscoredWord(num_estimated, cur_word, remnant);
+          ++num_estimated;
+          (*est_feats) += p;
+        }
+      }
+    }
+    if (remnant) {
+      SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant);
+      SetRemnantLMState(state, remnant);
+      SetUnscoredSize(num_estimated, remnant);
+      SetHasFullContext(context_complete || (num_scored >= order_), remnant);
+    }
+  }
+
+  // this assumes no target words on final unary -> goal rule.  is that ok?
+  // for <s> (n-1 left words) and (n-1 right words) </s>
+  void FinalTraversal(const void* state, SparseVector<double>* feats) {
+    if (add_sos_eos_) {  // rules do not produce <s> </s>, so do it here
+      SetRemnantLMState(BeginSentenceState(), dummy_state_);
+      SetHasFullContext(1, dummy_state_);
+      SetUnscoredSize(0, dummy_state_);
+      dummy_ants_[1] = state;
+      LookupWords(*dummy_rule_, dummy_ants_, feats, NULL, NULL);
+    } else {  // rules DO produce <s> ... </s>
+#if 0
+      double p = 0;
+      if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; }
+      if (UnscoredSize(state) > 0) {  // are there unscored words
+        if (kSOS_ != IthUnscoredWord(0, state)) {
+          p -= 100 * UnscoredSize(state);
+        }
+      }
+      return p;
+#endif
+    }
+  }
+
+ public:
+  explicit NgramDetectorImpl(bool explicit_markers) :
+      kCDEC_UNK(TD::Convert("<unk>")) ,
+      add_sos_eos_(!explicit_markers) {
+    order_ = 3;
+    state_size_ = (order_ - 1) * sizeof(WordID) + 2 + (order_ - 1) * sizeof(WordID);
+    unscored_size_offset_ = (order_ - 1) * sizeof(WordID);
+    is_complete_offset_ = unscored_size_offset_ + 1;
+    unscored_words_offset_ = is_complete_offset_ + 1;
+
+    // special handling of beginning / ending sentence markers
+    dummy_state_ = new char[state_size_];
+    memset(dummy_state_, 0, state_size_);
+    dummy_ants_.push_back(dummy_state_);
+    dummy_ants_.push_back(NULL);
+    dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
+    kSOS_ = TD::Convert("<s>");
+    kEOS_ = TD::Convert("</s>");
+  }
+
+  ~NgramDetectorImpl() {
+    delete[] dummy_state_;
+  }
+
+  int ReserveStateSize() const { return state_size_; }
+
+ private:
+  const WordID kCDEC_UNK;
+  WordID kSOS_;  // <s> - requires special handling.
+  WordID kEOS_;  // </s>
+  const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s>
+                     // if this is true, FinalTransitionFeatures will "add" <s> and </s>
+                     // if false, FinalTransitionFeatures will score anything with the
+                     // markers in the right place (i.e., the beginning and end of
+                     // the sentence) with 0, and anything else with -100
+
+  int order_;
+  int state_size_;
+  int unscored_size_offset_;
+  int is_complete_offset_;
+  int unscored_words_offset_;
+  char* dummy_state_;
+  vector<const void*> dummy_ants_;
+  TRulePtr dummy_rule_;
+  struct FidTree {
+    map<WordID, int> fids;
+    map<WordID, FidTree> levels;
+  };
+  mutable FidTree fidroot_;
+};
+
+NgramDetector::NgramDetector(const string& param) {
+  string filename, mapfile, featname;
+  bool explicit_markers = (param == "-x");
+  pimpl_ = new NgramDetectorImpl(explicit_markers);
+  SetStateSize(pimpl_->ReserveStateSize());
+}
+
+NgramDetector::~NgramDetector() {
+  delete pimpl_;
+}
+
+void NgramDetector::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+                                          const Hypergraph::Edge& edge,
+                                          const vector<const void*>& ant_states,
+                                          SparseVector<double>* features,
+                                          SparseVector<double>* estimated_features,
+                                          void* state) const {
+  pimpl_->LookupWords(*edge.rule_, ant_states, features, estimated_features, state);
+}
+
+void NgramDetector::FinalTraversalFeatures(const void* ant_state,
+                                           SparseVector<double>* features) const {
+  pimpl_->FinalTraversal(ant_state, features);
+}
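
NgramDetectorImpl packs its dynamic-programming state into the raw bytes cdec hands each hypergraph node: (order-1) right-edge context WordIDs, one byte holding the unscored-word count, one flag byte, then (order-1) unscored left-edge WordIDs. A small sketch of the same offset arithmetic, assuming cdec's 32-bit WordID:

```cpp
#include <cassert>
#include <cstddef>

typedef int WordID;  // cdec's WordID is a 32-bit int

struct StateLayout {
  int order;
  std::size_t context_bytes;         // (order-1) right-edge context words
  std::size_t unscored_size_offset;  // 1 byte: count of unscored left-edge words
  std::size_t flags_offset;          // 1 byte: HAS_FULL_CONTEXT / HAS_EOS_ON_RIGHT
  std::size_t unscored_words_offset; // (order-1) unscored left-edge words
  std::size_t total;

  explicit StateLayout(int o) : order(o) {
    context_bytes = (order - 1) * sizeof(WordID);
    unscored_size_offset = context_bytes;
    flags_offset = unscored_size_offset + 1;
    unscored_words_offset = flags_offset + 1;
    total = context_bytes + 2 + (order - 1) * sizeof(WordID);
  }
};

int main() {
  StateLayout trigram(3);        // order_ = 3 is hard-coded in the constructor above
  assert(trigram.total == 18);   // 8 + 2 + 8 bytes per hypergraph node
  return 0;
}
```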
"filename.lm [-o n]" +  NgramDetector(const std::string& param); +  ~NgramDetector(); +  virtual void FinalTraversalFeatures(const void* context, +                                      SparseVector<double>* features) const; + protected: +  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, +                                     const Hypergraph::Edge& edge, +                                     const std::vector<const void*>& ant_contexts, +                                     SparseVector<double>* features, +                                     SparseVector<double>* estimated_features, +                                     void* out_context) const; + private: +  NgramDetectorImpl* pimpl_; +}; + +#endif diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc new file mode 100644 index 00000000..bd4c4cc0 --- /dev/null +++ b/decoder/ff_rules.cc @@ -0,0 +1,107 @@ +#include "ff_rules.h" + +#include <sstream> +#include <cassert> +#include <cmath> + +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" + +using namespace std; + +namespace { +  string Escape(const string& x) { +    string y = x; +    for (int i = 0; i < y.size(); ++i) { +      if (y[i] == '=') y[i]='_'; +      if (y[i] == ';') y[i]='_'; +    } +    return y; +  } +} + +RuleIdentityFeatures::RuleIdentityFeatures(const std::string& param) { +} + +void RuleIdentityFeatures::PrepareForInput(const SentenceMetadata& smeta) { +//  std::map<const TRule*, SparseVector<double> > +  rule2_fid_.clear(); +} + +void RuleIdentityFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, +                                         const Hypergraph::Edge& edge, +                                         const vector<const void*>& ant_contexts, +                                         SparseVector<double>* features, +                                         SparseVector<double>* estimated_features, +                                         void* context) const { +  map<const TRule*, int>::iterator it = rule2_fid_.find(edge.rule_.get()); +  if (it == rule2_fid_.end()) { +    const TRule& rule = *edge.rule_; +    ostringstream os; +    os << "R:"; +    if (rule.lhs_ < 0) os << TD::Convert(-rule.lhs_) << ':'; +    for (unsigned i = 0; i < rule.f_.size(); ++i) { +      if (i > 0) os << '_'; +      WordID w = rule.f_[i]; +      if (w < 0) { os << 'N'; w = -w; } +      assert(w > 0); +      os << TD::Convert(w); +    } +    os << ':'; +    for (unsigned i = 0; i < rule.e_.size(); ++i) { +      if (i > 0) os << '_'; +      WordID w = rule.e_[i]; +      if (w <= 0) { +        os << 'N' << (1-w); +      } else { +        os << TD::Convert(w); +      } +    } +    it = rule2_fid_.insert(make_pair(&rule, FD::Convert(Escape(os.str())))).first; +  } +  features->add_value(it->second, 1); +} + +RuleNgramFeatures::RuleNgramFeatures(const std::string& param) { +} + +void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) { +//  std::map<const TRule*, SparseVector<double> > +  rule2_feats_.clear(); +} + +void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, +                                         const Hypergraph::Edge& edge, +                                         const vector<const void*>& ant_contexts, +                                         SparseVector<double>* features, +                                         SparseVector<double>* estimated_features, +                                         void* context) const 
{ +  map<const TRule*, SparseVector<double> >::iterator it = rule2_feats_.find(edge.rule_.get()); +  if (it == rule2_feats_.end()) { +    const TRule& rule = *edge.rule_; +    it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first; +    SparseVector<double>& f = it->second; +    string prev = "<r>"; +    for (int i = 0; i < rule.f_.size(); ++i) { +      WordID w = rule.f_[i]; +      if (w < 0) w = -w; +      assert(w > 0); +      const string& cur = TD::Convert(w); +      ostringstream os; +      os << "RB:" << prev << '_' << cur; +      const int fid = FD::Convert(Escape(os.str())); +      if (fid <= 0) return; +      f.add_value(fid, 1.0); +      prev = cur; +    } +    ostringstream os; +    os << "RB:" << prev << '_' << "</r>"; +    f.set_value(FD::Convert(Escape(os.str())), 1.0); +  } +  (*features) += it->second; +} + diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h new file mode 100644 index 00000000..48d8bd05 --- /dev/null +++ b/decoder/ff_rules.h @@ -0,0 +1,40 @@ +#ifndef _FF_RULES_H_ +#define _FF_RULES_H_ + +#include <vector> +#include <map> +#include "ff.h" +#include "array2d.h" +#include "wordid.h" + +class RuleIdentityFeatures : public FeatureFunction { + public: +  RuleIdentityFeatures(const std::string& param); + protected: +  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, +                                     const Hypergraph::Edge& edge, +                                     const std::vector<const void*>& ant_contexts, +                                     SparseVector<double>* features, +                                     SparseVector<double>* estimated_features, +                                     void* context) const; +  virtual void PrepareForInput(const SentenceMetadata& smeta); + private: +  mutable std::map<const TRule*, int> rule2_fid_; +}; + +class RuleNgramFeatures : public FeatureFunction { + public: +  RuleNgramFeatures(const std::string& param); + protected: +  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, +                                     const Hypergraph::Edge& edge, +                                     const std::vector<const void*>& ant_contexts, +                                     SparseVector<double>* features, +                                     SparseVector<double>* estimated_features, +                                     void* context) const; +  virtual void PrepareForInput(const SentenceMetadata& smeta); + private: +  mutable std::map<const TRule*, SparseVector<double> > rule2_feats_; +}; + +#endif diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index e1da088d..0483517b 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -13,6 +13,17 @@  using namespace std; +namespace { +  string Escape(const string& x) { +    string y = x; +    for (int i = 0; i < y.size(); ++i) { +      if (y[i] == '=') y[i]='_'; +      if (y[i] == ';') y[i]='_'; +    } +    return y; +  } +} +  // log transform to make long spans cluster together  // but preserve differences  int SpanSizeTransform(unsigned span_size) { @@ -140,19 +151,19 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {      word = MapIfNecessary(word);      ostringstream sfid;      sfid << "ES:" << TD::Convert(word); -    end_span_ids_[i] = FD::Convert(sfid.str()); +    end_span_ids_[i] = FD::Convert(Escape(sfid.str()));      ostringstream esbiid;      esbiid << "EBI:" << TD::Convert(bword) << "_" << TD::Convert(word); -    end_bigram_ids_[i] = FD::Convert(esbiid.str()); +    end_bigram_ids_[i] = 
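
RuleIdentityFeatures fires one indicator feature per synchronous rule, naming it from the LHS, source side, and target side, with '=' and ';' escaped to '_' because both are meaningful inside cdec weight files. A sketch of that naming scheme with toy types standing in for cdec's TRule/TD/FD (the helper names here are illustrative):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

std::string Escape(std::string y) {
  for (std::size_t i = 0; i < y.size(); ++i)
    if (y[i] == '=' || y[i] == ';') y[i] = '_';
  return y;
}

// Tokens are plain strings here; a leading '[' marks a non-terminal, which the
// real code prints as 'N' plus its label (source side) or index (target side).
std::string RuleFeature(const std::string& lhs,
                        const std::vector<std::string>& f,
                        const std::vector<std::string>& e) {
  std::ostringstream os;
  os << "R:" << lhs << ':';
  for (std::size_t i = 0; i < f.size(); ++i)
    os << (i ? "_" : "") << (f[i][0] == '[' ? "N" + f[i].substr(1) : f[i]);
  os << ':';
  for (std::size_t i = 0; i < e.size(); ++i)
    os << (i ? "_" : "") << e[i];
  return Escape(os.str());
}

int main() {
  std::vector<std::string> f, e;
  f.push_back("ein"); f.push_back("[X"); f.push_back("Haus");
  e.push_back("a");   e.push_back("N1"); e.push_back("house");
  // prints: R:X:ein_NX_Haus:a_N1_house
  std::cout << RuleFeature("X", f, e) << "\n";
  return 0;
}
```

Caching the feature id per TRule pointer (rule2_fid_) means the string is built at most once per rule per sentence.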
diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc
index e1da088d..0483517b 100644
--- a/decoder/ff_spans.cc
+++ b/decoder/ff_spans.cc
@@ -13,6 +13,17 @@
 
 using namespace std;
 
+namespace {
+  string Escape(const string& x) {
+    string y = x;
+    for (int i = 0; i < y.size(); ++i) {
+      if (y[i] == '=') y[i]='_';
+      if (y[i] == ';') y[i]='_';
+    }
+    return y;
+  }
+}
+
 // log transform to make long spans cluster together
 // but preserve differences
 int SpanSizeTransform(unsigned span_size) {
@@ -140,19 +151,19 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
     word = MapIfNecessary(word);
     ostringstream sfid;
     sfid << "ES:" << TD::Convert(word);
-    end_span_ids_[i] = FD::Convert(sfid.str());
+    end_span_ids_[i] = FD::Convert(Escape(sfid.str()));
     ostringstream esbiid;
     esbiid << "EBI:" << TD::Convert(bword) << "_" << TD::Convert(word);
-    end_bigram_ids_[i] = FD::Convert(esbiid.str());
+    end_bigram_ids_[i] = FD::Convert(Escape(esbiid.str()));
     ostringstream bsbiid;
     bsbiid << "BBI:" << TD::Convert(bword) << "_" << TD::Convert(word);
-    beg_bigram_ids_[i] = FD::Convert(bsbiid.str());
+    beg_bigram_ids_[i] = FD::Convert(Escape(bsbiid.str()));
     ostringstream bfid;
     bfid << "BS:" << TD::Convert(bword);
-    beg_span_ids_[i] = FD::Convert(bfid.str());
+    beg_span_ids_[i] = FD::Convert(Escape(bfid.str()));
     if (use_collapsed_features_) {
-      end_span_vals_[i] = feat2val_[sfid.str()] + feat2val_[esbiid.str()];
-      beg_span_vals_[i] = feat2val_[bfid.str()] + feat2val_[bsbiid.str()];
+      end_span_vals_[i] = feat2val_[Escape(sfid.str())] + feat2val_[Escape(esbiid.str())];
+      beg_span_vals_[i] = feat2val_[Escape(bfid.str())] + feat2val_[Escape(bsbiid.str())];
     }
   }
   for (int i = 0; i <= lattice.size(); ++i) {
@@ -167,60 +178,21 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
       word = MapIfNecessary(word);
       ostringstream pf;
       pf << "S:" << TD::Convert(bword) << "_" << TD::Convert(word);
-      span_feats_(i,j).first = FD::Convert(pf.str());
-      span_feats_(i,j).second = FD::Convert("S_" + pf.str());
+      span_feats_(i,j).first = FD::Convert(Escape(pf.str()));
+      span_feats_(i,j).second = FD::Convert(Escape("S_" + pf.str()));
       ostringstream lf;
       const unsigned span_size = (i < j ? j - i : i - j);
       lf << "LS:" << SpanSizeTransform(span_size) << "_" << TD::Convert(bword) << "_" << TD::Convert(word);
-      len_span_feats_(i,j).first = FD::Convert(lf.str());
-      len_span_feats_(i,j).second = FD::Convert("S_" + lf.str());
+      len_span_feats_(i,j).first = FD::Convert(Escape(lf.str()));
+      len_span_feats_(i,j).second = FD::Convert(Escape("S_" + lf.str()));
       if (use_collapsed_features_) {
-        span_vals_(i,j).first = feat2val_[pf.str()] + feat2val_[lf.str()];
-        span_vals_(i,j).second = feat2val_["S_" + pf.str()] + feat2val_["S_" + lf.str()];
+        span_vals_(i,j).first = feat2val_[Escape(pf.str())] + feat2val_[Escape(lf.str())];
+        span_vals_(i,j).second = feat2val_[Escape("S_" + pf.str())] + feat2val_[Escape("S_" + lf.str())];
       }
     }
   }
 }
 
-RuleNgramFeatures::RuleNgramFeatures(const std::string& param) {
-}
-
-void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) {
-//  std::map<const TRule*, SparseVector<double> >
-  rule2_feats_.clear();
-}
-
-void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                         const Hypergraph::Edge& edge,
-                                         const vector<const void*>& ant_contexts,
-                                         SparseVector<double>* features,
-                                         SparseVector<double>* estimated_features,
-                                         void* context) const {
-  map<const TRule*, SparseVector<double> >::iterator it = rule2_feats_.find(edge.rule_.get());
-  if (it == rule2_feats_.end()) {
-    const TRule& rule = *edge.rule_;
-    it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first;
-    SparseVector<double>& f = it->second;
-    string prev = "<r>";
-    for (int i = 0; i < rule.f_.size(); ++i) {
-      WordID w = rule.f_[i];
-      if (w < 0) w = -w;
-      assert(w > 0);
-      const string& cur = TD::Convert(w);
-      ostringstream os;
-      os << "RB:" << prev << '_' << cur;
-      const int fid = FD::Convert(os.str());
-      if (fid <= 0) return;
-      f.add_value(fid, 1.0);
-      prev = cur;
-    }
-    ostringstream os;
-    os << "RB:" << prev << '_' << "</r>";
-    f.set_value(FD::Convert(os.str()), 1.0);
-  }
-  (*features) += it->second;
-}
-
 inline bool IsArity2RuleReordered(const TRule& rule) {
   const vector<WordID>& e = rule.e_;
   for (int i = 0; i < e.size(); ++i) {
diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h
index b22c4d03..24e0dede 100644
--- a/decoder/ff_spans.h
+++ b/decoder/ff_spans.h
@@ -44,21 +44,6 @@ class SpanFeatures : public FeatureFunction {
   WordID oov_;
 };
 
-class RuleNgramFeatures : public FeatureFunction {
- public:
-  RuleNgramFeatures(const std::string& param);
- protected:
-  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                     const Hypergraph::Edge& edge,
-                                     const std::vector<const void*>& ant_contexts,
-                                     SparseVector<double>* features,
-                                     SparseVector<double>* estimated_features,
-                                     void* context) const;
-  virtual void PrepareForInput(const SentenceMetadata& smeta);
- private:
-  mutable std::map<const TRule*, SparseVector<double> > rule2_feats_;
-};
-
 class CMR2008ReorderingFeatures : public FeatureFunction {
  public:
   CMR2008ReorderingFeatures(const std::string& param);
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index d58478a8..35996d6d 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -215,7 +215,7 @@ main( int argc, char** argv )
 
   // for the perceptron/SVM; TODO as params
   double eta = 0.0005;
-  double gamma = 0.01; // -> SVM
+  double gamma = 0.;//01; // -> SVM
   lambdas.add_value( FD::Convert("__bias"), 0 );
 
   // for random sampling
@@ -388,10 +388,8 @@
     if ( !noup ) {
       TrainingInstances pairs;
-
-      sample_all_rand(kb, pairs);
-      cout << pairs.size() << endl;
-      
+      sample_all( kb, pairs );
+      
       for ( TrainingInstances::iterator ti = pairs.begin();
             ti != pairs.end(); ti++ ) {
@@ -401,31 +399,29 @@
       //} else {
         //dv = ti->first - ti->second;
       //}
-        dv.add_value( FD::Convert("__bias"), -1 );
+          dv.add_value( FD::Convert("__bias"), -1 );
 
-        SparseVector<double> reg;
-        reg = lambdas * ( 2 * gamma );
-        dv -= reg;
-        lambdas += dv * eta;
-
-        if ( verbose ) {
-          cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
-          cout << " i  " << TD::GetString(kb->sents[ti->first_rank]) << endl;
-          cout << "    " << kb->feats[ti->first_rank] << endl;
-          cout << " j  " << TD::GetString(kb->sents[ti->second_rank]) << endl;
-          cout << "    " << kb->feats[ti->second_rank] << endl;
-          cout << " diff vec: " << dv << endl;
-          cout << " lambdas after update: " << lambdas << endl;
-          cout << "}}" << endl;
-        }
-
+          //SparseVector<double> reg;
+          //reg = lambdas * ( 2 * gamma );
+          //dv -= reg;
+          lambdas += dv * eta;
+
+          if ( verbose ) {
+            cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
+            cout << " i  " << TD::GetString(kb->sents[ti->first_rank]) << endl;
+            cout << "    " << kb->feats[ti->first_rank] << endl;
+            cout << " j  " << TD::GetString(kb->sents[ti->second_rank]) << endl;
+            cout << "    " << kb->feats[ti->second_rank] << endl;
+            cout << " diff vec: " << dv << endl;
+            cout << " lambdas after update: " << lambdas << endl;
+            cout << "}}" << endl;
+          }
         } else {
-            //if ( 0 ) {
-            SparseVector<double> reg;
-            reg = lambdas * ( gamma * 2 );
-            lambdas += reg * ( -eta );
-            //}
+          //SparseVector<double> reg;
+          //reg = lambdas * ( 2 * gamma );
+          //lambdas += reg * ( -eta );
         }
+        }
 
       //double l2 = lambdas.l2norm();
diff --git a/dtrain/run.sh b/dtrain/run.sh
index 16575c25..97123dfa 100755
--- a/dtrain/run.sh
+++ b/dtrain/run.sh
@@ -2,9 +2,10 @@
 
 #INI=test/blunsom08.dtrain.ini
 #INI=test/nc-wmt11/dtrain.ini
-#INI=test/EXAMPLE/dtrain.ini
-INI=test/EXAMPLE/dtrain.ruleids.ini
+INI=test/EXAMPLE/dtrain.ini
+#INI=test/EXAMPLE/dtrain.ruleids.ini
 #INI=test/toy.dtrain.ini
+#INI=test/EXAMPLE/dtrain.cdecrid.ini
 
 rm /tmp/dtrain-*
 ./dtrain -c $INI $1 $2 $3 $4
diff --git a/dtrain/sample.h b/dtrain/sample.h
index b6aa9abd..502901af 100644
--- a/dtrain/sample.h
+++ b/dtrain/sample.h
@@ -37,20 +37,20 @@ sample_all( KBestList* kb, TrainingInstances &training )
 }
 
 void
-sample_all_rand( KBestList* kb, TrainingInstances &training )
+sample_rand( KBestList* kb, TrainingInstances &training )
 {
   srand( time(NULL) );
   for ( size_t i = 0; i < kb->GetSize()-1; i++ ) {
     for ( size_t j = i+1; j < kb->GetSize(); j++ ) {
       if ( rand() % 2 ) {
-      TPair p;
-      p.first = kb->feats[i];
-      p.second = kb->feats[j];
-      p.first_rank = i;
-      p.second_rank = j;
-      p.first_score = kb->scores[i];
-      p.second_score = kb->scores[j];
-      training.push_back( p );
+        TPair p;
+        p.first = kb->feats[i];
+        p.second = kb->feats[j];
+        p.first_rank = i;
+        p.second_rank = j;
+        p.first_score = kb->scores[i];
+        p.second_score = kb->scores[j];
+        training.push_back( p );
       }
     }
   }
diff --git a/dtrain/test/EXAMPLE/dtrain.ini b/dtrain/test/EXAMPLE/dtrain.ini
index 7645921a..7221ba3f 100644
--- a/dtrain/test/EXAMPLE/dtrain.ini
+++ b/dtrain/test/EXAMPLE/dtrain.ini
@@ -5,6 +5,6 @@ epochs=3
 input=test/EXAMPLE/dtrain.nc-1k
 scorer=stupid_bleu
 output=test/EXAMPLE/weights.gz
-stop_after=100
+stop_after=1000
 wprint=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough
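
After this change the dtrain inner loop is a plain pairwise perceptron: sample_all enumerates k-best pairs, and whenever the model ranks hypothesis i above j while the gold scorer prefers j, the weights move along the feature difference scaled by eta. The commented-out reg lines are the disabled SVM-style L2 term (gamma is now 0). A dense sketch of the update; dtrain itself operates on SparseVector<double> and also adds a __bias feature of -1 to the difference vector:

```cpp
#include <cstddef>
#include <vector>

typedef std::vector<double> Vec;

// Perceptron update for one misranked pair: lambdas += (feats_i - feats_j) * eta.
void PerceptronPairUpdate(Vec& lambdas, const Vec& feats_i, const Vec& feats_j,
                          double eta /* 0.0005 in dtrain.cc above */) {
  for (std::size_t k = 0; k < lambdas.size(); ++k) {
    double dv = feats_i[k] - feats_j[k];  // the "diff vec" of the verbose dump
    lambdas[k] += dv * eta;
  }
}
```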
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am
index 395494bc..fae6b41a 100644
--- a/klm/lm/Makefile.am
+++ b/klm/lm/Makefile.am
@@ -12,6 +12,7 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz
 noinst_LIBRARIES = libklm.a
 
 libklm_a_SOURCES = \
+  bhiksha.cc \
   binary_format.cc \
   config.cc \
   lm_exception.cc \
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc
new file mode 100644
index 00000000..bf86fd4b
--- /dev/null
+++ b/klm/lm/bhiksha.cc
@@ -0,0 +1,93 @@
+#include "lm/bhiksha.hh"
+#include "lm/config.hh"
+
+#include <limits>
+
+namespace lm {
+namespace ngram {
+namespace trie {
+
+DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) :
+  next_(util::BitsMask::ByMax(max_next)) {}
+
+const uint8_t kArrayBhikshaVersion = 0;
+
+void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) {
+  uint8_t version;
+  uint8_t configured_bits;
+  if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) {
+    UTIL_THROW(util::ErrnoException, "Could not read from binary file");
+  }
+  if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);
+  config.pointer_bhiksha_bits = configured_bits;
+}
+
+namespace {
+
+// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset)
+uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) {
+  uint8_t required = util::RequiredBits(max_next);
+  uint8_t best_chop = 0;
+  int64_t lowest_change = std::numeric_limits<int64_t>::max();
+  // There are probably faster ways but I don't care because this is only done once per order at construction time.
+  for (uint8_t chop = 0; chop <= std::min(required, config.pointer_bhiksha_bits); ++chop) {
+    int64_t change = (max_next >> (required - chop)) * 64 /* table cost in bits */
+      - max_offset * static_cast<int64_t>(chop); /* savings in bits*/
+    if (change < lowest_change) {
+      lowest_change = change;
+      best_chop = chop;
+    }
+  }
+  return best_chop;
+}
+
+std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) {
+  uint8_t required = util::RequiredBits(max_next);
+  uint8_t chopping = ChopBits(max_offset, max_next, config);
+  return (max_next >> (required - chopping)) + 1 /* we store 0 too */;
+}
+} // namespace
+
+std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
+  return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */;
+}
+
+uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) {
+  return util::RequiredBits(max_next) - ChopBits(max_offset, max_next, config);
+}
+
+namespace {
+
+void *AlignTo8(void *from) {
+  uint8_t *val = reinterpret_cast<uint8_t*>(from);
+  std::size_t remainder = reinterpret_cast<std::size_t>(val) & 7;
+  if (!remainder) return val;
+  return val + 8 - remainder;
+}
+
+} // namespace
+
+ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config)
+  : next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))),
+    offset_begin_(reinterpret_cast<const uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */),
+    offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)),
+    write_to_(reinterpret_cast<uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */),
+    original_base_(base) {}
+
+void ArrayBhiksha::FinishedLoading(const Config &config) {
+  // *offset_begin_ = 0 but without a const_cast.
+  *(write_to_ - (write_to_ - offset_begin_)) = 0;
+
+  if (write_to_ != offset_end_) UTIL_THROW(util::Exception, "Did not get all the array entries that were expected.");
+
+  uint8_t *head_write = reinterpret_cast<uint8_t*>(original_base_);
+  *(head_write++) = kArrayBhikshaVersion;
+  *(head_write++) = config.pointer_bhiksha_bits;
+}
+
+void ArrayBhiksha::LoadedBinary() {
+}
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
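
ChopBits above decides how many high bits of each next-pointer move into the shared sorted offset array by minimizing a bit-cost delta: the array costs (max_next >> (required - chop)) 64-bit entries, while each of the max_offset inline pointers saves chop bits. The same search restated standalone; util::RequiredBits is re-implemented here purely for illustration:

```cpp
#include <stdint.h>
#include <algorithm>
#include <limits>

// Number of bits needed to store values in [0, max_value].
uint8_t RequiredBits(uint64_t max_value) {
  uint8_t ret = 0;
  while (max_value) { ++ret; max_value >>= 1; }
  return ret;
}

uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, uint8_t max_chop /* -a */) {
  uint8_t required = RequiredBits(max_next);
  uint8_t best_chop = 0;
  int64_t lowest_change = std::numeric_limits<int64_t>::max();
  for (uint8_t chop = 0; chop <= std::min(required, max_chop); ++chop) {
    int64_t change =
        static_cast<int64_t>(max_next >> (required - chop)) * 64  // table cost in bits
        - static_cast<int64_t>(max_offset) * chop;                // inline savings in bits
    if (change < lowest_change) {
      lowest_change = change;
      best_chop = chop;
    }
  }
  return best_chop;
}
```

Picking 255 for the -a parameter (see the build_binary usage text below) removes the cap, so memory is minimized unconditionally.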
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
new file mode 100644
index 00000000..cfb2b053
--- /dev/null
+++ b/klm/lm/bhiksha.hh
@@ -0,0 +1,108 @@
+/* Simple implementation of
+ * @inproceedings{bhikshacompression,
+ *  author={Bhiksha Raj and Ed Whittaker},
+ *  year={2003},
+ *  title={Lossless Compression of Language Model Structure and Word Identifiers},
+ *  booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing},
+ *  pages={388--391},
+ *  }
+ *
+ *  Currently only used for next pointers.
+ */
+
+#include <inttypes.h>
+
+#include "lm/binary_format.hh"
+#include "lm/trie.hh"
+#include "util/bit_packing.hh"
+#include "util/sorted_uniform.hh"
+
+namespace lm {
+namespace ngram {
+class Config;
+
+namespace trie {
+
+class DontBhiksha {
+  public:
+    static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
+
+    static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
+
+    static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
+
+    static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) {
+      return util::RequiredBits(max_next);
+    }
+
+    DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config);
+
+    void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const {
+      out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask);
+      out.end = util::ReadInt57(base, bit_offset + total_bits, next_.bits, next_.mask);
+      //assert(out.end >= out.begin);
+    }
+
+    void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) {
+      util::WriteInt57(base, bit_offset, next_.bits, value);
+    }
+
+    void FinishedLoading(const Config &/*config*/) {}
+
+    void LoadedBinary() {}
+
+    uint8_t InlineBits() const { return next_.bits; }
+
+  private:
+    util::BitsMask next_;
+};
+
+class ArrayBhiksha {
+  public:
+    static const ModelType kModelTypeAdd = kArrayAdd;
+
+    static void UpdateConfigFromBinary(int fd, Config &config);
+
+    static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
+
+    static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config);
+
+    ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config);
+
+    void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const {
+      const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor<uint64_t>(), offset_begin_, offset_end_, index);
+      const uint64_t *end_it;
+      for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {}
+      --end_it;
+      out.begin = ((begin_it - offset_begin_) << next_inline_.bits) |
+        util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask);
+      out.end = ((end_it - offset_begin_) << next_inline_.bits) |
+        util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask);
+    }
+
+    void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) {
+      uint64_t encode = value >> next_inline_.bits;
+      for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index;
+      util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask);
+    }
+
+    void FinishedLoading(const Config &config);
+
+    void LoadedBinary();
+
+    uint8_t InlineBits() const { return next_inline_.bits; }
+
+  private:
+    const util::BitsMask next_inline_;
+
+    const uint64_t *const offset_begin_;
+    const uint64_t *const offset_end_;
+
+    uint64_t *write_to_;
+
+    void *original_base_;
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 92b1008b..e02e621a 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -40,7 +40,7 @@ struct Sanity {
   }
 };
 
-const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"};
+const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
 
 std::size_t Align8(std::size_t in) {
   std::size_t off = in % 8;
@@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
   }
 }
 
-uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) {
+uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
+  std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
   if (config.write_mmap) {
     // Grow the file to accomodate the search, using zeros.
-    if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size))
-      UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed");
+    if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size))
+      UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed");
 
     // We're skipping over the header and vocab for the search space mmap.  mmap likes page aligned offsets, so some arithmetic to round the offset down.
     off_t page_size = sysconf(_SC_PAGE_SIZE);
-    off_t alignment_cruft = backing.vocab.size() % page_size;
-    backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
+    off_t alignment_cruft = adjusted_vocab % page_size;
+    backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
 
     return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
   } else {
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index 2b32b450..d28cb6c5 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -16,7 +16,12 @@
 namespace lm {
 namespace ngram {
 
-typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType;
+/* Not the best numbering system, but it grew this way for historical reasons
+ * and I want to preserve existing binary files. */
+typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType;
+
+const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE_SORTED - TRIE_SORTED);
+const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE_SORTED - TRIE_SORTED);
 
 /*Inspect a file to determine if it is a binary lm.  If not, return false.
  * If so, return true and set recognized to the type.  This is the only API in
@@ -55,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off);
 // Create just enough of a binary file to write vocabulary to it.
 uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
 
 // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
-uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing);
+uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing);
 
 // Write header to binary file.  This is done last to prevent incomplete files
 // from loading.
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index 4552c419..b7aee4de 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -15,12 +15,12 @@ namespace ngram {
 namespace {
 
 void Usage(const char *name) {
-  std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n"
-"-u sets the default log10 probability for <unk> if the ARPA file does not have\n"
-"one.\n"
+  std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n"
+"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
+"   Default is -100.  The ARPA file will always take precedence.\n"
 "-s allows models to be built even if they do not have <s> and </s>.\n"
-"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
-"type is either probing or trie:\n\n"
+"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n"
+"type is either probing or trie.  Default is probing.\n\n"
 "probing uses a probing hash table.  It is the fastest but uses the most memory.\n"
 "-p sets the space multiplier and must be >1.0.  The default is 1.5.\n\n"
 "trie is a straightforward trie with bit-level packing.  It uses the least\n"
@@ -29,10 +29,11 @@ void Usage(const char *name) {
 "-t is the temporary directory prefix.  Default is the output file name.\n"
 "-m limits memory use for sorting.  Measured in MB.  Default is 1024MB.\n"
 "-q turns quantization on and sets the number of bits (e.g. -q 8).\n"
-"-b sets backoff quantization bits.  Requires -q and defaults to that value.\n\n"
-"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n"
-"Passing only an input file will print memory usage of each data structure.\n"
-"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n";
+"-b sets backoff quantization bits.  Requires -q and defaults to that value.\n"
+"-a compresses pointers using an array of offsets.  The parameter is the\n"
+"   maximum number of bits encoded by the array.  Memory is minimized subject\n"
+"   to the maximum, so pick 255 to minimize memory.\n\n"
+"Get a memory estimate by passing an ARPA file without an output file name.\n";
   exit(1);
 }
 
@@ -63,12 +64,14 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
   std::vector<uint64_t> counts;
   util::FilePiece f(file);
   lm::ReadARPACounts(f, counts);
-  std::size_t sizes[3];
+  std::size_t sizes[5];
   sizes[0] = ProbingModel::Size(counts, config);
   sizes[1] = TrieModel::Size(counts, config);
   sizes[2] = QuantTrieModel::Size(counts, config);
-  std::size_t max_length = *std::max_element(sizes, sizes + 3);
-  std::size_t min_length = *std::max_element(sizes, sizes + 3);
+  sizes[3] = ArrayTrieModel::Size(counts, config);
+  sizes[4] = QuantArrayTrieModel::Size(counts, config);
+  std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
+  std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
   std::size_t divide;
   char prefix;
   if (min_length < (1 << 10) * 10) {
@@ -91,7 +94,9 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
   std::cout << prefix << "B\n"
     "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
     "trie    " << std::setw(length) << (sizes[1] / divide) << " without quantization\n"
-    "trie    " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n";
+    "trie    " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
+    "trie    " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
+    "trie    " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
 }
 
 void ProbingQuantizationUnsupported() {
@@ -106,11 +111,11 @@ void ProbingQuantizationUnsupported() {
 
 int main(int argc, char *argv[]) {
   using namespace lm::ngram;
 
-  bool quantize = false, set_backoff_bits = false;
   try {
+    bool quantize = false, set_backoff_bits = false, bhiksha = false;
     lm::ngram::Config config;
     int opt;
-    while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) {
+    while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) {
       switch(opt) {
         case 'q':
           config.prob_bits = ParseBitCount(optarg);
@@ -121,6 +126,9 @@ int main(int argc, char *argv[]) {
           config.backoff_bits = ParseBitCount(optarg);
           set_backoff_bits = true;
           break;
+        case 'a':
+          config.pointer_bhiksha_bits = ParseBitCount(optarg);
+          bhiksha = true;
         case 'u':
           config.unknown_missing_logprob = ParseFloat(optarg);
           break;
@@ -162,9 +170,17 @@ int main(int argc, char *argv[]) {
         ProbingModel(from_file, config);
       } else if (!strcmp(model_type, "trie")) {
         if (quantize) {
-          QuantTrieModel(from_file, config);
+          if (bhiksha) {
+            QuantArrayTrieModel(from_file, config);
+          } else {
+            QuantTrieModel(from_file, config);
+          }
         } else {
-          TrieModel(from_file, config);
+          if (bhiksha) {
+            ArrayTrieModel(from_file, config);
+          } else {
+            TrieModel(from_file, config);
+          }
         }
       } else {
         Usage(argv[0]);
@@ -173,9 +189,9 @@ int main(int argc, char *argv[]) {
       Usage(argv[0]);
     }
   }
-  catch (std::exception &e) {
+  catch (const std::exception &e) {
     std::cerr << e.what() << std::endl;
-    abort();
+    return 1;
   }
   return 0;
 }
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index 08e1af5c..297589a4 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -20,6 +20,7 @@ Config::Config() :
   include_vocab(true),
   prob_bits(8),
   backoff_bits(8),
+  pointer_bhiksha_bits(22),
   load_method(util::POPULATE_OR_READ) {}
 
 } // namespace ngram
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index dcc7cf35..227b8512 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -73,9 +73,12 @@ struct Config {
 
   // Quantization options.  Only effective for QuantTrieModel.  One value is
   // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
-  // to quantize.
+  // to quantize (and one of the remaining backoffs will be 0).
   uint8_t prob_bits, backoff_bits;
 
+  // Bhiksha compression (simple form).  Only works with trie.
+  uint8_t pointer_bhiksha_bits;
+
   // ONLY EFFECTIVE WHEN READING BINARY
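
The enum extension in binary_format.hh is designed so feature combinations compose arithmetically: quantization contributes kQuantAdd = 1 and Bhiksha arrays contribute kArrayAdd = 2 relative to TRIE_SORTED, which is how the trie search can derive its kModelType from its template parameters (DontBhiksha adds 0, ArrayBhiksha adds kArrayAdd). A quick check of that arithmetic:

```cpp
#include <cassert>

// Same values as the ModelType enum in binary_format.hh.
typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3,
              ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType;

int main() {
  int kQuantAdd = QUANT_TRIE_SORTED - TRIE_SORTED;  // 1
  int kArrayAdd = ARRAY_TRIE_SORTED - TRIE_SORTED;  // 2
  // Quantized + array-compressed trie is the base type plus both offsets.
  assert(TRIE_SORTED + kQuantAdd + kArrayAdd == QUANT_ARRAY_TRIE_SORTED);
  return 0;
}
```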
These will be fixed by search_. -  ReadARPACounts(f, counts); - -  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); -  if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); -  if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - -  std::size_t vocab_size = VocabularyT::Size(counts[0], config); -  // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.   -  vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); - -  if (config.write_mmap) { -    WriteWordsWrapper wrap(config.enumerate_vocab); -    vocab_.ConfigureEnumerate(&wrap, counts[0]); -    search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); -    wrap.Write(backing_.file.get()); -  } else { -    vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); -    search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); -  } +  try { +    std::vector<uint64_t> counts; +    // File counts do not include pruned trigrams that extend to quadgrams etc.   These will be fixed by search_. +    ReadARPACounts(f, counts); + +    if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); +    if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); +    if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); + +    std::size_t vocab_size = VocabularyT::Size(counts[0], config); +    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.   +    vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + +    if (config.write_mmap) { +      WriteWordsWrapper wrap(config.enumerate_vocab); +      vocab_.ConfigureEnumerate(&wrap, counts[0]); +      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); +      wrap.Write(backing_.file.get()); +    } else { +      vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); +      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); +    } -  if (!vocab_.SawUnk()) { -    assert(config.unknown_missing != THROW_UP); -    // Default probabilities for unknown.   -    search_.unigram.Unknown().backoff = 0.0; -    search_.unigram.Unknown().prob = config.unknown_missing_logprob; +    if (!vocab_.SawUnk()) { +      assert(config.unknown_missing != THROW_UP); +      // Default probabilities for unknown.   
+      search_.unigram.Unknown().backoff = 0.0; +      search_.unigram.Unknown().prob = config.unknown_missing_logprob; +    } +    FinishFile(config, kModelType, counts, backing_); +  } catch (util::Exception &e) { +    e << " Byte: " << f.Offset(); +    throw;    } -  FinishFile(config, kModelType, counts, backing_);  }  template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { @@ -225,8 +232,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,  }  template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;  // HASH_PROBING -template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>; // TRIE_SORTED -template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED +template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>; +template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;  } // namespace detail  } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 1f49a382..21595321 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -1,6 +1,7 @@  #ifndef LM_MODEL__  #define LM_MODEL__ +#include "lm/bhiksha.hh"  #include "lm/binary_format.hh"  #include "lm/config.hh"  #include "lm/facade.hh" @@ -71,6 +72,9 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod    private:      typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;    public: +    // This is the model type returned by RecognizeBinary. +    static const ModelType kModelType; +      /* Get the size of memory that will be mapped given ngram counts.  This       * does not include small non-mapped control structures, such as this class       * itself.   @@ -131,8 +135,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod      Backing &MutableBacking() { return backing_; } -    static const ModelType kModelType = Search::kModelType; -      Backing backing_;      VocabularyT vocab_; @@ -152,9 +154,11 @@ typedef ProbingModel Model;  // Smaller implementation.  
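The catch block just above is the rethrow-with-context idiom: util::Exception is streamable, so the handler appends the byte offset where parsing stopped and rethrows the same object. A minimal sketch of the pattern; ParseBody is hypothetical (the trie typedefs introduced by "// Smaller implementation." follow below):

    #include "util/exception.hh"
    #include "util/file_piece.hh"

    void ParseBody(util::FilePiece &f) {
      try {
        f.ReadLine();  // stands in for the real parsing work
      } catch (util::Exception &e) {
        e << " Byte: " << f.Offset();  // augment the message in place
        throw;                         // rethrow the original exception object
      }
    }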
typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary> ArrayTrieModel; -typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary> QuantArrayTrieModel;  } // namespace ngram  } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 8bf040ff..57c7291c 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -193,6 +193,14 @@ template <class M> void Stateless(const M &model) {    BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);  } +template <class M> void NoUnkCheck(const M &model) { +  WordIndex unk_index = 0; +  State state; + +  FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state); +  BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); +} +  template <class M> void Everything(const M &m) {    Starters(m);    Continuation(m); @@ -231,25 +239,38 @@ template <class ModelT> void LoadingTest() {    Config config;    config.arpa_complain = Config::NONE;    config.messages = NULL; -  ExpectEnumerateVocab enumerate; -  config.enumerate_vocab = &enumerate;    config.probing_multiplier = 2.0; -  ModelT m("test.arpa", config); -  enumerate.Check(m.GetVocabulary()); -  Everything(m); +  { +    ExpectEnumerateVocab enumerate; +    config.enumerate_vocab = &enumerate; +    ModelT m("test.arpa", config); +    enumerate.Check(m.GetVocabulary()); +    Everything(m); +  } +  { +    ExpectEnumerateVocab enumerate; +    config.enumerate_vocab = &enumerate; +    ModelT m("test_nounk.arpa", config); +    enumerate.Check(m.GetVocabulary()); +    NoUnkCheck(m); +  }  }  BOOST_AUTO_TEST_CASE(probing) {    LoadingTest<Model>();  } -  BOOST_AUTO_TEST_CASE(trie) {    LoadingTest<TrieModel>();  } - -BOOST_AUTO_TEST_CASE(quant) { +BOOST_AUTO_TEST_CASE(quant_trie) {    LoadingTest<QuantTrieModel>();  } +BOOST_AUTO_TEST_CASE(bhiksha_trie) { +  LoadingTest<ArrayTrieModel>(); +} +BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { +  LoadingTest<QuantArrayTrieModel>(); +}  template <class ModelT> void BinaryTest() {    Config config; @@ -267,10 +288,34 @@ template <class ModelT> void BinaryTest() {    config.write_mmap = NULL; -  ModelT binary("test.binary", config); -  enumerate.Check(binary.GetVocabulary()); -  Everything(binary); +  ModelType type; +  BOOST_REQUIRE(RecognizeBinary("test.binary", type)); +  BOOST_CHECK_EQUAL(ModelT::kModelType, type); + +  { +    ModelT binary("test.binary", config); +    enumerate.Check(binary.GetVocabulary()); +    Everything(binary); +  }    unlink("test.binary"); + +  // Now test without <unk>. 
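With four trie variants, callers pick the typedef at run time from the type byte stored in the binary file, exactly as the BinaryTest above checks via kModelType and as the ngram_query.cc switch below does. A compressed sketch, with default-constructed Config assumed acceptable (the no-<unk> test resumes below):

    #include "lm/model.hh"

    // Open whichever model class matches the file; unrecognized or ARPA input
    // falls back to the probing hash table.
    void OpenAny(const char *file) {
      using namespace lm::ngram;
      ModelType type;
      if (!RecognizeBinary(file, type)) type = HASH_PROBING;
      Config config;
      switch (type) {
        case TRIE_SORTED:             { TrieModel m(file, config);           break; }
        case QUANT_TRIE_SORTED:       { QuantTrieModel m(file, config);      break; }
        case ARRAY_TRIE_SORTED:       { ArrayTrieModel m(file, config);      break; }
        case QUANT_ARRAY_TRIE_SORTED: { QuantArrayTrieModel m(file, config); break; }
        default:                      { ProbingModel m(file, config);        break; }
      }
    }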
+  config.write_mmap = "test_nounk.binary"; +  config.messages = NULL; +  enumerate.Clear(); +  { +    ModelT copy_model("test_nounk.arpa", config); +    enumerate.Check(copy_model.GetVocabulary()); +    enumerate.Clear(); +    NoUnkCheck(copy_model); +  } +  config.write_mmap = NULL; +  { +    ModelT binary("test_nounk.binary", config); +    enumerate.Check(binary.GetVocabulary()); +    NoUnkCheck(binary); +  } +  unlink("test_nounk.binary");  }  BOOST_AUTO_TEST_CASE(write_and_read_probing) { @@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) {  BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {    BinaryTest<QuantTrieModel>();  } +BOOST_AUTO_TEST_CASE(write_and_read_array_trie) { +  BinaryTest<ArrayTrieModel>(); +} +BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { +  BinaryTest<QuantArrayTrieModel>(); +}  } // namespace  } // namespace ngram diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 9454a6d1..d9db4aa2 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -99,6 +99,15 @@ int main(int argc, char *argv[]) {        case lm::ngram::TRIE_SORTED:          Query<lm::ngram::TrieModel>(argv[1], sentence_context);          break; +      case lm::ngram::QUANT_TRIE_SORTED: +        Query<lm::ngram::QuantTrieModel>(argv[1], sentence_context); +        break; +      case lm::ngram::ARRAY_TRIE_SORTED: +        Query<lm::ngram::ArrayTrieModel>(argv[1], sentence_context); +        break; +      case lm::ngram::QUANT_ARRAY_TRIE_SORTED: +        Query<lm::ngram::QuantArrayTrieModel>(argv[1], sentence_context); +        break;        case lm::ngram::HASH_SORTED:        default:          std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index 4bb6b1b8..fd371cc8 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -43,6 +43,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64    if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1)       UTIL_THROW(util::ErrnoException, "Failed to read header for quantization.");    if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); +  AdvanceOrThrow(fd, -3);  }  void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index aae72b34..0b71d14a 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -21,7 +21,7 @@ class Config;  /* Store values directly and don't quantize. 
*/  class DontQuantize {    public: -    static const ModelType kModelType = TRIE_SORTED; +    static const ModelType kModelTypeAdd = static_cast<ModelType>(0);      static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}      static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }      static uint8_t MiddleBits(const Config &/*config*/) { return 63; } @@ -108,7 +108,7 @@ class SeparatelyQuantize {      };    public: -    static const ModelType kModelType = QUANT_TRIE_SORTED; +    static const ModelType kModelTypeAdd = kQuantAdd;      static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 060a97ea..455bc4ba 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -31,15 +31,15 @@ const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";  void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {    number.clear();    StringPiece line; -  if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { +  while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} +  if (line != "\\data\\") {      if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {        UTIL_THROW(FormatLoadException, "Looks like a gzip file.  If this is an ARPA file, pipe " << in.FileName() << " through zcat.  If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");      }      if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)         UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser.  Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); -    UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank"); +    UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\.");    } -  if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");    while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {      if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \"");      // So strtol doesn't go off the end of line.   diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index c56ba7b8..82c53ec8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -98,7 +98,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,  template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {    // TODO: fix sorted. 
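The switch from per-class kModelType to kModelTypeAdd above lets search_trie.hh (below) compute the final type as TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd instead of enumerating four constants. A sketch of the arithmetic; the enum ordering is assumed from lm/binary_format.hh, which this diff does not show (the search_hashed.cc hunk resumes below):

    #include <cassert>
    #include "lm/binary_format.hh"

    void CheckModelTypeComposition() {
      using namespace lm::ngram;
      const int quant_add = QUANT_TRIE_SORTED - TRIE_SORTED;  // kQuantAdd
      const int array_add = ARRAY_TRIE_SORTED - TRIE_SORTED;  // ArrayBhiksha's add
      assert(TRIE_SORTED + quant_add + array_add == QUANT_ARRAY_TRIE_SORTED);
    }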
-  SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); +  SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config);    PositiveProbWarn warn(config.positive_log_probability); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index f3acdefc..c62985e4 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -52,12 +52,11 @@ struct HashedSearch {    Unigram unigram; -  bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { +  void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {      const ProbBackoff &entry = unigram.Lookup(word);      prob = entry.prob;      backoff = entry.backoff;      next = static_cast<Node>(word); -    return true;    }  }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 91f87f1c..05059ffb 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,6 +1,7 @@  /* This is where the trie is built.  It's on-disk.  */  #include "lm/search_trie.hh" +#include "lm/bhiksha.hh"  #include "lm/blank.hh"  #include "lm/lm_exception.hh"  #include "lm/max_order.hh" @@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin      std::string unigram_name = file_prefix + "unigrams";      util::scoped_fd unigram_file;      // In case <unk> appears.   -    size_t extra_count = counts[0] + 1; -    util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); +    size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); +    util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out);      Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);      CheckSpecials(config, vocab);      if (!vocab.SawUnk()) ++counts[0]; @@ -610,9 +611,9 @@ class JustCount {  };  // Phase to actually write n-grams to the trie.   
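The renamed file_out above makes the sizing rule explicit: the unigram mmap reserves one spare ProbBackoff slot so <unk> can be appended when the ARPA lacks it. A self-contained sketch of the arithmetic, with the ProbBackoff layout assumed from lm/weights.hh (the WriteEntries class this comment introduces follows the sketch):

    #include <cstddef>
    #include <cstdint>

    struct ProbBackoff { float prob; float backoff; };  // assumed layout

    // e.g. test_nounk.arpa below declares "ngram 1=36", so the file is
    // sized for 37 entries: (36 + 1) * 8 = 296 bytes.
    std::size_t UnigramFileSize(std::uint64_t count0) {
      return (count0 + 1) * sizeof(ProbBackoff);
    }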
-template <class Quant> class WriteEntries { +template <class Quant, class Bhiksha> class WriteEntries {    public: -    WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :  +    WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :         contexts_(contexts),        unigrams_(unigrams),        middle_(middle), @@ -649,7 +650,7 @@ template <class Quant> class WriteEntries {    private:      ContextReader *contexts_;      UnigramValue *const unigrams_; -    BitPackedMiddle<typename Quant::Middle> *const middle_; +    BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_;      BitPackedLongest<typename Quant::Longest> &longest_;      BitPacked &bigram_pack_;  }; @@ -821,7 +822,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, So  } // namespace -template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {    std::vector<SortedFileReader> inputs(counts.size() - 1);    std::vector<ContextReader> contexts(counts.size() - 1); @@ -846,7 +847,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto    SanityCheckCounts(counts, fixed_counts);    counts = fixed_counts; -  out.SetupMemory(GrowForSearch(config, TrieSearch<Quant>::Size(fixed_counts, config), backing), fixed_counts, config); +  out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);    if (Quant::kTrain) {      util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); @@ -863,7 +864,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto    UnigramValue *unigrams = out.unigram.Raw();    // Fill entries except unigram probabilities.      { -    RecursiveInsert<WriteEntries<Quant> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); +    RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());      inserter.Apply(config.messages, "Building trie", fixed_counts[0]);    } @@ -901,14 +902,14 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto    /* Set ending offsets so the last entry will be sized properly */    // Last entry for unigrams was already set.      
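BuildTrie gains the vocab parameter above solely so GrowForSearch can be told how far the search section shifts: the vocabulary table was written from the pre-<unk> unigram count, and if <unk> then had to be added, every later section starts one uint64_t further along. A sketch of the rule, mirroring SortedVocabulary::UnkCountChangePadding in vocab.hh further down (BuildTrie resumes below):

    #include <cstddef>
    #include <cstdint>

    // No padding if the ARPA already contained <unk>; otherwise 8 bytes,
    // matching the count bump applied after the vocabulary was sized.
    std::size_t UnkCountChangePadding(bool saw_unk) {
      return saw_unk ? 0 : sizeof(std::uint64_t);
    }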
if (out.middle_begin_ != out.middle_end_) { -    for (typename TrieSearch<Quant>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { -      i->FinishedLoading((i+1)->InsertIndex()); +    for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { +      i->FinishedLoading((i+1)->InsertIndex(), config);      } -    (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); +    (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config);    }    } -template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { +template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {    quant_.SetupMemory(start, config);    start += Quant::Size(counts.size(), config);    unigram.Init(start); @@ -919,22 +920,24 @@ template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, c    std::vector<uint8_t*> middle_starts(counts.size() - 2);    for (unsigned char i = 2; i < counts.size(); ++i) {      middle_starts[i-2] = start; -    start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); +    start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config);    } -  // Crazy backwards thing so we initialize in the correct order.   +  // Crazy backwards thing so we initialize using pointers to ones that have already been initialized    for (unsigned char i = counts.size() - 1; i >= 2; --i) {      new (middle_begin_ + i - 2) Middle(          middle_starts[i-2],          quant_.Mid(i), +        counts[i-1],          counts[0],          counts[i], -        (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1])); +        (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]), +        config);    }    longest.Init(start, quant_.Long(counts.size()), counts[0]);    return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);  } -template <class Quant> void TrieSearch<Quant>::LoadedBinary() { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {    unigram.LoadedBinary();    for (Middle *i = middle_begin_; i != middle_end_; ++i) {      i->LoadedBinary(); @@ -942,7 +945,7 @@ template <class Quant> void TrieSearch<Quant>::LoadedBinary() {    longest.LoadedBinary();  } -template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {    std::string temporary_directory;    if (config.temporary_directory_prefix) {      temporary_directory = config.temporary_directory_prefix; @@ -966,14 +969,16 @@ template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *fi    // At least 1MB sorting memory.      
ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); -  BuildTrie(temporary_directory, counts, config, *this, quant_, backing); +  BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing);    if (rmdir(temporary_directory.c_str()) && config.messages) {      *config.messages << "Failed to delete " << temporary_directory << std::endl;    }  } -template class TrieSearch<DontQuantize>; -template class TrieSearch<SeparatelyQuantize>; +template class TrieSearch<DontQuantize, DontBhiksha>; +template class TrieSearch<DontQuantize, ArrayBhiksha>; +template class TrieSearch<SeparatelyQuantize, DontBhiksha>; +template class TrieSearch<SeparatelyQuantize, ArrayBhiksha>;  } // namespace trie  } // namespace ngram diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0a52acb5..2f39c09f 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,31 +13,33 @@ struct Backing;  class SortedVocabulary;  namespace trie { -template <class Quant> class TrieSearch; -template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); +template <class Quant, class Bhiksha> class TrieSearch; +template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); -template <class Quant> class TrieSearch { +template <class Quant, class Bhiksha> class TrieSearch {    public:      typedef NodeRange Node;      typedef ::lm::ngram::trie::Unigram Unigram;      Unigram unigram; -    typedef trie::BitPackedMiddle<typename Quant::Middle> Middle; +    typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;      typedef trie::BitPackedLongest<typename Quant::Longest> Longest;      Longest longest; -    static const ModelType kModelType = Quant::kModelType; +    static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);      static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {        Quant::UpdateConfigFromBinary(fd, counts, config); +      AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); +      Bhiksha::UpdateConfigFromBinary(fd, config);      }      static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {        std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);        for (unsigned char i = 1; i < counts.size() - 1; ++i) { -        ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); +        ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);        }        return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);      } @@ -55,8 +57,8 @@ template <class Quant> class TrieSearch {      void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); -    bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { -      return unigram.Find(word, prob, backoff, node); +    void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { +      unigram.Find(word, prob, backoff, node);      }      bool 
LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { @@ -83,7 +85,7 @@ template <class Quant> class TrieSearch {      }    private: -    friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); +    friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);      // Middles are managed manually so we can delay construction and they don't have to be copyable.        void FreeMiddles() { diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa new file mode 100644 index 00000000..060733d9 --- /dev/null +++ b/klm/lm/test_nounk.arpa @@ -0,0 +1,120 @@ + +\data\ +ngram 1=36 +ngram 2=45 +ngram 3=10 +ngram 4=6 +ngram 5=4 + +\1-grams: +-1.383514	,	-0.30103 +-1.139057	.	-0.845098 +-1.029493	</s> +-99	<s>	-0.4149733 +-1.285941	a	-0.69897 +-1.687872	also	-0.30103 +-1.687872	beyond	-0.30103 +-1.687872	biarritz	-0.30103 +-1.687872	call	-0.30103 +-1.687872	concerns	-0.30103 +-1.687872	consider	-0.30103 +-1.687872	considering	-0.30103 +-1.687872	for	-0.30103 +-1.509559	higher	-0.30103 +-1.687872	however	-0.30103 +-1.687872	i	-0.30103 +-1.687872	immediate	-0.30103 +-1.687872	in	-0.30103 +-1.687872	is	-0.30103 +-1.285941	little	-0.69897 +-1.383514	loin	-0.30103 +-1.687872	look	-0.30103 +-1.285941	looking	-0.4771212 +-1.206319	more	-0.544068 +-1.509559	on	-0.4771212 +-1.509559	screening	-0.4771212 +-1.687872	small	-0.30103 +-1.687872	the	-0.30103 +-1.687872	to	-0.30103 +-1.687872	watch	-0.30103 +-1.687872	watching	-0.30103 +-1.687872	what	-0.30103 +-1.687872	would	-0.30103 +-3.141592	foo +-2.718281	bar	3.0 +-6.535897	baz	-0.0 + +\2-grams: +-0.6925742	, . +-0.7522095	, however +-0.7522095	, is +-0.0602359	. </s> +-0.4846522	<s> looking	-0.4771214 +-1.051485	<s> screening +-1.07153	<s> the +-1.07153	<s> watching +-1.07153	<s> what +-0.09132547	a little	-0.69897 +-0.2922095	also call +-0.2922095	beyond immediate +-0.2705918	biarritz . +-0.2922095	call for +-0.2922095	concerns in +-0.2922095	consider watch +-0.2922095	considering consider +-0.2834328	for , +-0.5511513	higher more +-0.5845945	higher small +-0.2834328	however , +-0.2922095	i would +-0.2922095	immediate concerns +-0.2922095	in biarritz +-0.2922095	is to +-0.09021038	little more	-0.1998621 +-0.7273645	loin , +-0.6925742	loin . +-0.6708385	loin </s> +-0.2922095	look beyond +-0.4638903	looking higher +-0.4638903	looking on	-0.4771212 +-0.5136299	more .	-0.4771212 +-0.3561665	more loin +-0.1649931	on a	-0.4771213 +-0.1649931	screening a	-0.4771213 +-0.2705918	small . +-0.287799	the screening +-0.2922095	to look +-0.2622373	watch </s> +-0.2922095	watching considering +-0.2922095	what i +-0.2922095	would also +-2	also would	-6 +-6	foo bar + +\3-grams: +-0.01916512	more . 
</s> +-0.0283603	on a little	-0.4771212 +-0.0283603	screening a little	-0.4771212 +-0.01660496	a little more	-0.09409451 +-0.3488368	<s> looking higher +-0.3488368	<s> looking on	-0.4771212 +-0.1892331	little more loin +-0.04835128	looking on a	-0.4771212 +-3	also would consider	-7 +-7	to look good + +\4-grams: +-0.009249173	looking on a little	-0.4771212 +-0.005464747	on a little more	-0.4771212 +-0.005464747	screening a little more +-0.1453306	a little more loin +-0.01552657	<s> looking on a	-0.4771212 +-4	also would consider higher	-8 + +\5-grams: +-0.003061223	<s> looking on a little +-0.001813953	looking on a little more +-0.0432557	on a little more loin +-5	also would consider higher looking + +\end\ diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 63c2a612..8c536e66 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,5 +1,6 @@  #include "lm/trie.hh" +#include "lm/bhiksha.hh"  #include "lm/quantize.hh"  #include "util/bit_packing.hh"  #include "util/exception.hh" @@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)    max_vocab_ = max_vocab;  } -template <class Quant> std::size_t BitPackedMiddle<Quant>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { -  return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); +template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +  return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));  } -template <class Quant> BitPackedMiddle<Quant>::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { -  if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order.  Edit util/bit_packing.hh and fix the bit packing functions."); -  BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); +template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : +  BitPacked(), +  quant_(quant), +  // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. +  bhiksha_(base, entries + 1, max_next, config), +  next_source_(&next_source) { +  if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57)))  UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order.  
Edit util/bit_packing.hh and fix the bit packing functions."); +  BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits());  } -template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float prob, float backoff) { +template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) {    assert(word <= word_mask_);    uint64_t at_pointer = insert_index_ * total_bits_; @@ -75,47 +81,42 @@ template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float    quant_.Write(base_, at_pointer, prob, backoff);    at_pointer += quant_.TotalBits();    uint64_t next = next_source_->InsertIndex(); -  assert(next <= next_mask_); -  util::WriteInt57(base_, at_pointer, next_bits_, next); +  bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);    ++insert_index_;  } -template <class Quant> bool BitPackedMiddle<Quant>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {    uint64_t at_pointer;    if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {      return false;    } +  uint64_t index = at_pointer;    at_pointer *= total_bits_;    at_pointer += word_bits_;    quant_.Read(base_, at_pointer, prob, backoff);    at_pointer += quant_.TotalBits(); -  range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); -  // Read the next entry's pointer.   -  at_pointer += total_bits_; -  range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); +  bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); +    return true;  } -template <class Quant> bool BitPackedMiddle<Quant>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { -  uint64_t at_pointer; -  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; -  at_pointer *= total_bits_; +template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { +  uint64_t index; +  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; +  uint64_t at_pointer = index * total_bits_;    at_pointer += word_bits_;    quant_.ReadBackoff(base_, at_pointer, backoff);    at_pointer += quant_.TotalBits(); -  range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); -  // Read the next entry's pointer.   
-  at_pointer += total_bits_; -  range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); +  bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);    return true;  } -template <class Quant> void BitPackedMiddle<Quant>::FinishedLoading(uint64_t next_end) { -  assert(next_end <= next_mask_); -  uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; -  util::WriteInt57(base_, last_next_write, next_bits_, next_end); +template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) { +  uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); +  bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); +  bhiksha_.FinishedLoading(config);  }  template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) { @@ -135,8 +136,10 @@ template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float    return true;  } -template class BitPackedMiddle<DontQuantize::Middle>; -template class BitPackedMiddle<SeparatelyQuantize::Middle>; +template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>; +template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>; +template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>; +template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>;  template class BitPackedLongest<DontQuantize::Longest>;  template class BitPackedLongest<SeparatelyQuantize::Longest>; diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 8fa21aaf..53612064 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -10,6 +10,7 @@  namespace lm {  namespace ngram { +class Config;  namespace trie {  struct NodeRange { @@ -46,13 +47,12 @@ class Unigram {      void LoadedBinary() {} -    bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { +    void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {        UnigramValue *val = unigram_ + word;        prob = val->weights.prob;        backoff = val->weights.backoff;        next.begin = val->next;        next.end = (val+1)->next; -      return true;      }    private: @@ -67,8 +67,6 @@ class BitPacked {        return insert_index_;      } -    void LoadedBinary() {} -    protected:      static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); @@ -83,30 +81,30 @@ class BitPacked {      uint64_t insert_index_, max_vocab_;  }; -template <class Quant> class BitPackedMiddle : public BitPacked { +template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {    public: -    static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); +    static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);      // next_source need not be initialized.   
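Pulling the trie.cc pieces above together: each middle entry packs a word index, the (possibly quantized) weights, and only the inline low bits of its next-level pointer, while the Bhiksha offset table sits in front of the packed entries and so costs no per-entry bits. A sketch of the record width; quant_bits and inline_next_bits stand in for quant.TotalBits() and bhiksha_.InlineBits() (the BitPackedMiddle declaration resumes below):

    #include <cstdint>
    #include "util/bit_packing.hh"

    uint8_t MiddleEntryBits(uint64_t max_vocab, uint8_t quant_bits, uint8_t inline_next_bits) {
      return util::RequiredBits(max_vocab)  // word index
           + quant_bits                     // prob/backoff
           + inline_next_bits;              // low bits of the next pointer
    }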
-    BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source);
+    BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);

     void Insert(WordIndex word, float prob, float backoff);

+    void FinishedLoading(uint64_t next_end, const Config &config);
+
+    void LoadedBinary() { bhiksha_.LoadedBinary(); }
+
     bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;

     bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;

-    void FinishedLoading(uint64_t next_end);
-
   private:
     Quant quant_;
-    uint8_t next_bits_;
-    uint64_t next_mask_;
+    Bhiksha bhiksha_;

     const BitPacked *next_source_;
 };

-
 template <class Quant> class BitPackedLongest : public BitPacked {
   public:
     static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
@@ -120,6 +118,8 @@ template <class Quant> class BitPackedLongest : public BitPacked {
       BaseInit(base, max_vocab, quant_.TotalBits());
     }

+    void LoadedBinary() {}
+
     void Insert(WordIndex word, float prob);

     bool Find(WordIndex word, float &prob, const NodeRange &node) const;
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 7defd5c1..04979d51 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -37,14 +37,14 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
   WordIndex index = 0;
   while (true) {
     ssize_t got = read(fd, &buf[0], kInitialRead);
-    if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
+    UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words");
     if (got == 0) return index;
     buf.resize(got);
     while (buf[buf.size() - 1]) {
       char next_char;
       ssize_t ret = read(fd, &next_char, 1);
-      if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
-      if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word.");
+      UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words");
+      UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word.");
       buf.push_back(next_char);
     }
     // Ok now we have null terminated strings.
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index c92518e4..9d218fff 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary {
       }
     }

+    // Size for purposes of file writing
     static size_t Size(std::size_t entries, const Config &config);

     // Vocab words are [0, Bound())  Only valid after FinishedLoading/LoadedBinary.
@@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary {
     // Reorders reorder_vocab so that the IDs are sorted.
     void FinishedLoading(ProbBackoff *reorder_vocab);

+    // Trie stores the correct counts including <unk> in the header.  If this was previously sized based on a count excluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>.
+    std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }
+
     bool SawUnk() const { return saw_unk_; }

     void LoadedBinary(int fd, EnumerateVocab *to);
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh
index b35d80c8..9f47d559 100644
--- a/klm/util/bit_packing.hh
+++ b/klm/util/bit_packing.hh
@@ -107,9 +107,20 @@ void BitPackingSanity();
 uint8_t RequiredBits(uint64_t max_value);

 struct BitsMask {
+  static BitsMask ByMax(uint64_t max_value) {
+    BitsMask ret;
+    ret.FromMax(max_value);
+    return ret;
+  }
+  static BitsMask ByBits(uint8_t bits) {
+    BitsMask ret;
+    ret.bits = bits;
+    ret.mask = (1ULL << bits) - 1;
+    return ret;
+  }
   void FromMax(uint64_t max_value) {
     bits = RequiredBits(max_value);
-    mask = (1 << bits) - 1;
+    mask = (1ULL << bits) - 1;
   }
   uint8_t bits;
   uint64_t mask;
diff --git a/klm/util/exception.cc b/klm/util/exception.cc
index 84f9fe7c..62280970 100644
--- a/klm/util/exception.cc
+++ b/klm/util/exception.cc
@@ -1,5 +1,9 @@
 #include "util/exception.hh"
+#ifdef __GXX_RTTI
+#include <typeinfo>
+#endif
+
 #include <errno.h>
 #include <string.h>
@@ -22,6 +26,30 @@ const char *Exception::what() const throw() {
   return text_.c_str();
 }

+void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) {
+  /* The child class might have set some text, but we want this to come first.
+   * Another option would be passing this information to the constructor, but
+   * then child classes would have to accept constructor arguments and pass
+   * them down.
+   */
+  text_ = stream_.str();
+  stream_.str("");
+  stream_ << file << ':' << line;
+  if (func) stream_ << " in " << func << " threw ";
+  if (child_name) {
+    stream_ << child_name;
+  } else {
+#ifdef __GXX_RTTI
+    stream_ << typeid(*this).name();
+#else
+    stream_ << "an exception";
+#endif
+  }
+  if (condition) stream_ << " because `" << condition << "'";
+  stream_ << ".\n";
+  stream_ << text_;
+}
+
 namespace {
 // The XOPEN version.
 const char *HandleStrerror(int ret, const char *buf) {
diff --git a/klm/util/exception.hh b/klm/util/exception.hh
index c6936914..81675a57 100644
--- a/klm/util/exception.hh
+++ b/klm/util/exception.hh
@@ -1,8 +1,6 @@
 #ifndef UTIL_EXCEPTION__
 #define UTIL_EXCEPTION__

-#include "util/string_piece.hh"
-
 #include <exception>
 #include <sstream>
 #include <string>
@@ -22,6 +20,14 @@ class Exception : public std::exception {
     // Not threadsafe, but probably doesn't matter.  FWIW, Boost's exception guidance implies that what() isn't threadsafe.
     const char *what() const throw();

+    // For use by the UTIL_THROW macros.
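The FromMax fix above deserves a highlight: mask is a uint64_t, but `1 << bits` shifts a plain 32-bit int, which is undefined once bits reaches 32. A self-contained illustration (the SetLocation declaration resumes below):

    #include <cstdint>

    std::uint64_t MaskFor(std::uint8_t bits) {
      return (1ULL << bits) - 1;    // bits = 40 -> 0xFFFFFFFFFF, as intended
      // return (1 << bits) - 1;    // undefined for bits >= 32 (the old bug)
    }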
+    void SetLocation( +        const char *file, +        unsigned int line, +        const char *func, +        const char *child_name, +        const char *condition); +    private:      template <class Except, class Data> friend typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data); @@ -43,7 +49,49 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep    return e;  } -#define UTIL_THROW(Exception, Modify) { Exception UTIL_e; {UTIL_e << Modify;} throw UTIL_e; } +#ifdef __GNUC__ +#define UTIL_FUNC_NAME __PRETTY_FUNCTION__ +#else +#ifdef _WIN32 +#define UTIL_FUNC_NAME __FUNCTION__ +#else +#define UTIL_FUNC_NAME NULL +#endif +#endif + +#define UTIL_SET_LOCATION(UTIL_e, child, condition) do { \ +  (UTIL_e).SetLocation(__FILE__, __LINE__, UTIL_FUNC_NAME, (child), (condition)); \ +} while (0) + +/* Create an instance of Exception, add the message Modify, and throw it. + * Modify is appended to the what() message and can contain << for ostream + * operations.   + * + * do .. while kludge to swallow trailing ; character + * http://gcc.gnu.org/onlinedocs/cpp/Swallowing-the-Semicolon.html .   + */ +#define UTIL_THROW(Exception, Modify) do { \ +  Exception UTIL_e; \ +  UTIL_SET_LOCATION(UTIL_e, #Exception, NULL); \ +  UTIL_e << Modify; \ +  throw UTIL_e; \ +} while (0) + +#define UTIL_THROW_VAR(Var, Modify) do { \ +  Exception &UTIL_e = (Var); \ +  UTIL_SET_LOCATION(UTIL_e, NULL, NULL); \ +  UTIL_e << Modify; \ +  throw UTIL_e; \ +} while (0) + +#define UTIL_THROW_IF(Condition, Exception, Modify) do { \ +  if (Condition) { \ +    Exception UTIL_e; \ +    UTIL_SET_LOCATION(UTIL_e, #Exception, #Condition); \ +    UTIL_e << Modify; \ +    throw UTIL_e; \ +  } \ +} while (0)  class ErrnoException : public Exception {    public: @@ -51,7 +99,7 @@ class ErrnoException : public Exception {      virtual ~ErrnoException() throw(); -    int Error() { return errno_; } +    int Error() const throw() { return errno_; }    private:      int errno_; diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index f447a70c..cbe4234f 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -41,8 +41,8 @@ GZException::GZException(void *file) {  const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};  int OpenReadOrThrow(const char *name) { -  int ret = open(name, O_RDONLY); -  if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading"); +  int ret; +  UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name);    return ret;  } @@ -52,13 +52,13 @@ off_t SizeFile(int fd) {    return sb.st_size;  } -FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :  +FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) :     file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),    progress_(total_size_ == kBadSize ? 
NULL : show_progress, std::string("Reading ") + name, total_size_) {    Initialize(name, show_progress, min_buffer);  } -FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :  +FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer)  :     file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),    progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {    Initialize(name, show_progress, min_buffer); @@ -78,7 +78,7 @@ FilePiece::~FilePiece() {  #endif  } -StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) { +StringPiece FilePiece::ReadLine(char delim) {    size_t skip = 0;    while (true) {      for (const char *i = position_ + skip; i < position_end_; ++i) { @@ -97,20 +97,20 @@ StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileExcepti    }  } -float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) { +float FilePiece::ReadFloat() {    return ReadNumber<float>();  } -double FilePiece::ReadDouble() throw(GZException, EndOfFileException, ParseNumberException) { +double FilePiece::ReadDouble() {    return ReadNumber<double>();  } -long int FilePiece::ReadLong() throw(GZException, EndOfFileException, ParseNumberException) { +long int FilePiece::ReadLong() {    return ReadNumber<long int>();  } -unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException, ParseNumberException) { +unsigned long int FilePiece::ReadULong() {    return ReadNumber<unsigned long int>();  } -void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) { +void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer)  {  #ifdef HAVE_ZLIB    gz_file_ = NULL;  #endif @@ -163,7 +163,7 @@ void ParseNumber(const char *begin, char *&end, unsigned long int &out) {  }  } // namespace -template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileException, ParseNumberException) { +template <class T> T FilePiece::ReadNumber() {    SkipSpaces();    while (last_space_ < position_) {      if (at_end_) { @@ -186,7 +186,7 @@ template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileExcepti    return ret;  } -const char *FilePiece::FindDelimiterOrEOF(const bool *delim) throw (GZException, EndOfFileException) { +const char *FilePiece::FindDelimiterOrEOF(const bool *delim)  {    size_t skip = 0;    while (true) {      for (const char *i = position_ + skip; i < position_end_; ++i) { @@ -201,7 +201,7 @@ const char *FilePiece::FindDelimiterOrEOF(const bool *delim) throw (GZException,    }  } -void FilePiece::Shift() throw(GZException, EndOfFileException) { +void FilePiece::Shift() {    if (at_end_) {      progress_.Finished();      throw EndOfFileException(); @@ -217,7 +217,7 @@ void FilePiece::Shift() throw(GZException, EndOfFileException) {    }  } -void FilePiece::MMapShift(off_t desired_begin) throw() { +void FilePiece::MMapShift(off_t desired_begin) {    // Use mmap.      off_t ignore = desired_begin % page_;    // Duplicate request for Shift means give more data.   
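The UTIL_THROW_IF conversions in vocab.cc and throughout file_piece.cc pair with the SetLocation machinery above: the macro stringizes both the exception name and the failed condition, so what() now leads with file, line, and function. A usage sketch; the message shown is approximate, per SetLocation (the MMapShift hunk resumes below):

    #include <cstddef>
    #include <unistd.h>
    #include "util/exception.hh"

    void ReadExactly(int fd, char *buf, std::size_t amount) {
      ssize_t got = read(fd, buf, amount);
      UTIL_THROW_IF(got == -1, util::ErrnoException, "while reading " << amount << " bytes");
      // On failure what() begins roughly:
      //   reader.cc:7 in void ReadExactly(int, char*, size_t) threw
      //   util::ErrnoException because `got == -1'.
      //   while reading 4096 bytes ... (strerror text appended by ErrnoException)
    }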
@@ -259,25 +259,23 @@ void FilePiece::MMapShift(off_t desired_begin) throw() {    progress_.Set(desired_begin);  } -void FilePiece::TransitionToRead() throw (GZException) { +void FilePiece::TransitionToRead() {    assert(!fallback_to_read_);    fallback_to_read_ = true;    data_.reset();    data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); -  if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_); +  UTIL_THROW_IF(!data_.get(), ErrnoException, "malloc failed for " << default_map_size_);    position_ = data_.begin();    position_end_ = position_;  #ifdef HAVE_ZLIB    assert(!gz_file_);    gz_file_ = gzdopen(file_.get(), "r"); -  if (!gz_file_) { -    UTIL_THROW(GZException, "zlib failed to open " << file_name_); -  } +  UTIL_THROW_IF(!gz_file_, GZException, "zlib failed to open " << file_name_);  #endif  } -void FilePiece::ReadShift() throw(GZException, EndOfFileException) { +void FilePiece::ReadShift() {    assert(fallback_to_read_);    // Bytes [data_.begin(), position_) have been consumed.      // Bytes [position_, position_end_) have been read into the buffer.   @@ -297,7 +295,7 @@ void FilePiece::ReadShift() throw(GZException, EndOfFileException) {        std::size_t valid_length = position_end_ - position_;        default_map_size_ *= 2;        data_.call_realloc(default_map_size_); -      if (!data_.get()) UTIL_THROW(ErrnoException, "realloc failed for " << default_map_size_); +      UTIL_THROW_IF(!data_.get(), ErrnoException, "realloc failed for " << default_map_size_);        position_ = data_.begin();        position_end_ = position_ + valid_length;      } else { @@ -320,7 +318,7 @@ void FilePiece::ReadShift() throw(GZException, EndOfFileException) {    }  #else    read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); -  if (read_return == -1) UTIL_THROW(ErrnoException, "read failed"); +  UTIL_THROW_IF(read_return == -1, ErrnoException, "read failed");    progress_.Set(mapped_offset_);  #endif    if (read_return == 0) { diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index 870ae5a3..a5c00910 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -45,13 +45,13 @@ off_t SizeFile(int fd);  class FilePiece {    public:      // 32 MB default. -    explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException); +    explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);      // Takes ownership of fd.  name is used for messages.   -    explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException); +    explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);      ~FilePiece(); -    char get() throw(GZException, EndOfFileException) {  +    char get() {         if (position_ == position_end_) {          Shift();          if (at_end_) throw EndOfFileException(); @@ -60,22 +60,22 @@ class FilePiece {      }      // Leaves the delimiter, if any, to be returned by get().  Delimiters defined by isspace().   
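The blanket removal of throw(...) specifications in these hunks is not just cleanup: dynamic exception specifications are checked at run time, and a violation calls std::unexpected() rather than propagating to a caller, so they add cost without safety (C++11 later deprecated them). A small illustration of the failure mode being removed (the FilePiece declarations resume below):

    #include <new>
    #include <stdexcept>

    // If this throws anything but std::bad_alloc -- say the runtime_error
    // below -- the program goes to std::unexpected()/terminate instead of the
    // caller's catch block.
    void OldStyle() throw(std::bad_alloc) {
      throw std::runtime_error("not listed in the specification");
    }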
-    StringPiece ReadDelimited(const bool *delim = kSpaces) throw(GZException, EndOfFileException) { +    StringPiece ReadDelimited(const bool *delim = kSpaces) {        SkipSpaces(delim);        return Consume(FindDelimiterOrEOF(delim));      }      // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.      // It is similar to getline in that way.   -    StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException); +    StringPiece ReadLine(char delim = '\n'); -    float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException); -    double ReadDouble() throw(GZException, EndOfFileException, ParseNumberException); -    long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException); -    unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException); +    float ReadFloat(); +    double ReadDouble(); +    long int ReadLong(); +    unsigned long int ReadULong();      // Skip spaces defined by isspace.   -    void SkipSpaces(const bool *delim = kSpaces) throw (GZException, EndOfFileException) { +    void SkipSpaces(const bool *delim = kSpaces) {        for (; ; ++position_) {          if (position_ == position_end_) Shift();          if (!delim[static_cast<unsigned char>(*position_)]) return; @@ -89,9 +89,9 @@ class FilePiece {      const std::string &FileName() const { return file_name_; }    private: -    void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException); +    void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer); -    template <class T> T ReadNumber() throw(GZException, EndOfFileException, ParseNumberException); +    template <class T> T ReadNumber();      StringPiece Consume(const char *to) {        StringPiece ret(position_, to - position_); @@ -99,14 +99,14 @@ class FilePiece {        return ret;      } -    const char *FindDelimiterOrEOF(const bool *delim = kSpaces) throw (GZException, EndOfFileException); +    const char *FindDelimiterOrEOF(const bool *delim = kSpaces); -    void Shift() throw (EndOfFileException, GZException); +    void Shift();      // Backends to Shift(). -    void MMapShift(off_t desired_begin) throw (); +    void MMapShift(off_t desired_begin); -    void TransitionToRead() throw (GZException); -    void ReadShift() throw (GZException, EndOfFileException); +    void TransitionToRead(); +    void ReadShift();      const char *position_, *last_space_, *position_end_; diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index d58a0727..fec47fd9 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All
 - * code is released to the public domain. For business purposes, Murmurhash is
 - * under the MIT license."
 - * This is modified from the original:
 - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit.  
 - * length changed to unsigned int.  
 - * placed in namespace util
 - * add MurmurHashNative
 - * default option = 0 for seed
 - */
 -
 -#include "util/murmur_hash.hh"
 -
 -namespace util {
 -
 -//-----------------------------------------------------------------------------
 -// MurmurHash2, 64-bit versions, by Austin Appleby
 -
 -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment 
 -// and endian-ness issues if used across multiple platforms.
 -
 -// 64-bit hash for 64-bit platforms
 -
 -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
 -{
 -  const uint64_t m = 0xc6a4a7935bd1e995ULL;
 -  const int r = 47;
 -
 -  uint64_t h = seed ^ (len * m);
 -
 -  const uint64_t * data = (const uint64_t *)key;
 -  const uint64_t * end = data + (len/8);
 -
 -  while(data != end)
 -  {
 -    uint64_t k = *data++;
 -
 -    k *= m; 
 -    k ^= k >> r; 
 -    k *= m; 
 -    
 -    h ^= k;
 -    h *= m; 
 -  }
 -
 -  const unsigned char * data2 = (const unsigned char*)data;
 -
 -  switch(len & 7)
 -  {
 -  case 7: h ^= uint64_t(data2[6]) << 48;
 -  case 6: h ^= uint64_t(data2[5]) << 40;
 -  case 5: h ^= uint64_t(data2[4]) << 32;
 -  case 4: h ^= uint64_t(data2[3]) << 24;
 -  case 3: h ^= uint64_t(data2[2]) << 16;
 -  case 2: h ^= uint64_t(data2[1]) << 8;
 -  case 1: h ^= uint64_t(data2[0]);
 -          h *= m;
 -  };
 - 
 -  h ^= h >> r;
 -  h *= m;
 -  h ^= h >> r;
 -
 -  return h;
 -} 
 -
 -
 -// 64-bit hash for 32-bit platforms
 -
 -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
 -{
 -  const unsigned int m = 0x5bd1e995;
 -  const int r = 24;
 -
 -  unsigned int h1 = seed ^ len;
 -  unsigned int h2 = 0;
 -
 -  const unsigned int * data = (const unsigned int *)key;
 -
 -  while(len >= 8)
 -  {
 -    unsigned int k1 = *data++;
 -    k1 *= m; k1 ^= k1 >> r; k1 *= m;
 -    h1 *= m; h1 ^= k1;
 -    len -= 4;
 -
 -    unsigned int k2 = *data++;
 -    k2 *= m; k2 ^= k2 >> r; k2 *= m;
 -    h2 *= m; h2 ^= k2;
 -    len -= 4;
 -  }
 -
 -  if(len >= 4)
 -  {
 -    unsigned int k1 = *data++;
 -    k1 *= m; k1 ^= k1 >> r; k1 *= m;
 -    h1 *= m; h1 ^= k1;
 -    len -= 4;
 -  }
 -
 -  switch(len)
 -  {
 -  case 3: h2 ^= ((unsigned char*)data)[2] << 16;
 -  case 2: h2 ^= ((unsigned char*)data)[1] << 8;
 -  case 1: h2 ^= ((unsigned char*)data)[0];
 -      h2 *= m;
 -  };
 -
 -  h1 ^= h2 >> 18; h1 *= m;
 -  h2 ^= h1 >> 22; h2 *= m;
 -  h1 ^= h2 >> 17; h1 *= m;
 -  h2 ^= h1 >> 19; h2 *= m;
 -
 -  uint64_t h = h1;
 -
 -  h = (h << 32) | h2;
 -
 -  return h;
 -}
 -
 -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) {
 -  if (sizeof(int) == 4) {
 -    return MurmurHash64B(key, len, seed);
 -  } else {
 -    return MurmurHash64A(key, len, seed);
 -  }
 -}
 -
 -} // namespace util
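The removal above is paired with a byte-for-byte re-add below; the visible content is unchanged, so apparently only whitespace or line endings differ. One usage caveat worth noting for MurmurHashNative: it picks the variant by the build's word size, so hash values are not portable across 32- and 64-bit binaries:

    #include <cstdint>
    #include <string>
    #include "util/murmur_hash.hh"

    std::uint64_t HashWord(const std::string &word) {
      // Same bytes, but possibly a different value on 32- vs 64-bit builds.
      return util::MurmurHashNative(word.data(), word.size(), 0);
    }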
+/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All
+ * code is released to the public domain. For business purposes, Murmurhash is
+ * under the MIT license."
+ * This is modified from the original:
+ * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit.
+ * length changed to unsigned int.
+ * placed in namespace util
+ * add MurmurHashNative
+ * default option = 0 for seed
+ */
+
+#include "util/murmur_hash.hh"
+
+namespace util {
+
+//-----------------------------------------------------------------------------
+// MurmurHash2, 64-bit versions, by Austin Appleby
+
+// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment
+// and endian-ness issues if used across multiple platforms.
+
+// 64-bit hash for 64-bit platforms
+
+uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
+{
+  const uint64_t m = 0xc6a4a7935bd1e995ULL;
+  const int r = 47;
+
+  uint64_t h = seed ^ (len * m);
+
+  const uint64_t * data = (const uint64_t *)key;
+  const uint64_t * end = data + (len/8);
+
+  while(data != end)
+  {
+    uint64_t k = *data++;
+
+    k *= m;
+    k ^= k >> r;
+    k *= m;
+
+    h ^= k;
+    h *= m;
+  }
+
+  const unsigned char * data2 = (const unsigned char*)data;
+
+  switch(len & 7)
+  {
+  case 7: h ^= uint64_t(data2[6]) << 48;
+  case 6: h ^= uint64_t(data2[5]) << 40;
+  case 5: h ^= uint64_t(data2[4]) << 32;
+  case 4: h ^= uint64_t(data2[3]) << 24;
+  case 3: h ^= uint64_t(data2[2]) << 16;
+  case 2: h ^= uint64_t(data2[1]) << 8;
+  case 1: h ^= uint64_t(data2[0]);
+          h *= m;
+  };
+
+  h ^= h >> r;
+  h *= m;
+  h ^= h >> r;
+
+  return h;
+}
+
+
+// 64-bit hash for 32-bit platforms
+
+uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
+{
+  const unsigned int m = 0x5bd1e995;
+  const int r = 24;
+
+  unsigned int h1 = seed ^ len;
+  unsigned int h2 = 0;
+
+  const unsigned int * data = (const unsigned int *)key;
+
+  while(len >= 8)
+  {
+    unsigned int k1 = *data++;
+    k1 *= m; k1 ^= k1 >> r; k1 *= m;
+    h1 *= m; h1 ^= k1;
+    len -= 4;
+
+    unsigned int k2 = *data++;
+    k2 *= m; k2 ^= k2 >> r; k2 *= m;
+    h2 *= m; h2 ^= k2;
+    len -= 4;
+  }
+
+  if(len >= 4)
+  {
+    unsigned int k1 = *data++;
+    k1 *= m; k1 ^= k1 >> r; k1 *= m;
+    h1 *= m; h1 ^= k1;
+    len -= 4;
+  }
+
+  switch(len)
+  {
+  case 3: h2 ^= ((unsigned char*)data)[2] << 16;
+  case 2: h2 ^= ((unsigned char*)data)[1] << 8;
+  case 1: h2 ^= ((unsigned char*)data)[0];
+      h2 *= m;
+  };
+
+  h1 ^= h2 >> 18; h1 *= m;
+  h2 ^= h1 >> 22; h2 *= m;
+  h1 ^= h2 >> 17; h1 *= m;
+  h2 ^= h1 >> 19; h2 *= m;
+
+  uint64_t h = h1;
+
+  h = (h << 32) | h2;
+
+  return h;
+}
+
+uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) {
+  if (sizeof(int) == 4) {
+    return MurmurHash64B(key, len, seed);
+  } else {
+    return MurmurHash64A(key, len, seed);
+  }
+}
+
+} // namespace util
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 00be0ed7..2ec342a6 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -57,7 +57,7 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
         equal_(equal_func),
         entries_(0)
 #ifdef DEBUG
-        , initialized_(true),
+        , initialized_(true)
 #endif
     {}
diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh
index 84d7aa02..0d6ecbbd 100644
--- a/klm/util/sorted_uniform.hh
+++ b/klm/util/sorted_uniform.hh
@@ -12,7 +12,7 @@ namespace util {
 template <class T> class IdentityAccessor {
   public:
     typedef T Key;
-    T operator()(const uint64_t *in) const { return *in; }
+    T operator()(const T *in) const { return *in; }
 };

 struct Pivot64 {
@@ -101,6 +101,27 @@ template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(co
   return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);
 }

+// May return begin - 1.
+template <class Iterator, class Accessor> Iterator BinaryBelow(
+    const Accessor &accessor,
+    Iterator begin,
+    Iterator end,
+    const typename Accessor::Key key) {
+  while (end > begin) {
+    Iterator pivot(begin + (end - begin) / 2);
+    typename Accessor::Key mid(accessor(pivot));
+    if (mid < key) {
+      begin = pivot + 1;
+    } else if (mid > key) {
+      end = pivot;
+    } else {
+      for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {}
+      return pivot - 1;
+    }
+  }
+  return begin - 1;
+}
+
 // To use this template, you need to define a Pivot function to match Key.
 template <class PackingT> class SortedUniformMap {
   public:
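BinaryBelow above returns the last position whose key is <= the query, stepping forward through runs of equal keys, and returns begin - 1 when every key is larger (the documented sentinel). A usage sketch with the fixed IdentityAccessor:

    #include <cassert>
    #include <cstdint>
    #include "util/sorted_uniform.hh"

    void BinaryBelowExample() {
      const std::uint64_t keys[] = {2, 4, 4, 9};
      util::IdentityAccessor<std::uint64_t> access;
      // Last key <= 5 is the second 4, at index 2.
      const std::uint64_t *pos =
          util::BinaryBelow(access, &keys[0], &keys[4], std::uint64_t(5));
      assert(pos == &keys[2] && *pos == 4);
      // Everything exceeds 1, so the begin - 1 sentinel comes back.
      assert(util::BinaryBelow(access, &keys[0], &keys[4], std::uint64_t(1)) == &keys[0] - 1);
    }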
