| author | Paul Baltescu <pauldb89@gmail.com> | 2013-06-19 15:06:34 +0100 |
|---|---|---|
| committer | Paul Baltescu <pauldb89@gmail.com> | 2013-06-19 15:06:34 +0100 |
| commit | 459775095b46b4625ce26ea5a34001ec74ab3aa8 | |
| tree | 844d1a650a302114ae619d37b8778ab66207a834 /klm/lm | |
| parent | 02099a01350a41a99ec400e9b29df08a01d88979 | |
| parent | 0dc7755f7fb1ef15db5a60c70866aa61b6367898 | |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'klm/lm')
| -rw-r--r-- | klm/lm/builder/lmplz_main.cc | 15 |
| -rw-r--r-- | klm/lm/builder/ngram.hh | 2 |
| -rw-r--r-- | klm/lm/model.cc | 21 |
| -rw-r--r-- | klm/lm/model.hh | 5 |
| -rw-r--r-- | klm/lm/search_hashed.cc | 29 |
| -rw-r--r-- | klm/lm/search_hashed.hh | 19 |
| -rw-r--r-- | klm/lm/search_trie.hh | 3 |
| -rw-r--r-- | klm/lm/state.hh | 2 |
| -rw-r--r-- | klm/lm/virtual_interface.hh | 3 |
| -rw-r--r-- | klm/lm/vocab.hh | 2 |
10 files changed, 63 insertions, 38 deletions
```diff
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
index 1e086dcc..c87abdb8 100644
--- a/klm/lm/builder/lmplz_main.cc
+++ b/klm/lm/builder/lmplz_main.cc
@@ -52,13 +52,14 @@ int main(int argc, char *argv[]) {
       std::cerr <<
         "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
         "Please cite:\n"
-        "@inproceedings{kenlm,\n"
-        "author    = {Kenneth Heafield},\n"
-        "title     = {{KenLM}: Faster and Smaller Language Model Queries},\n"
-        "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
-        "month     = {July}, year={2011},\n"
-        "address   = {Edinburgh, UK},\n"
-        "publisher = {Association for Computational Linguistics},\n"
+        "@inproceedings{Heafield-estimate,\n"
+        "  author = {Kenneth Heafield and Ivan Pouzyrevsky and Jonathan H. Clark and Philipp Koehn},\n"
+        "  title = {Scalable Modified {Kneser-Ney} Language Model Estimation},\n"
+        "  year = {2013},\n"
+        "  month = {8},\n"
+        "  booktitle = {Proceedings of the 51st Annual Meeting of the Association for Computational Linguistics},\n"
+        "  address = {Sofia, Bulgaria},\n"
+        "  url = {http://kheafield.com/professional/edinburgh/estimate\\_paper.pdf},\n"
         "}\n\n"
         "Provide the corpus on stdin.  The ARPA file will be written to stdout.  Order of\n"
         "the model (-o) is the only mandatory option.  As this is an on-disk program,\n"
diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh
index 2984ed0b..f5681516 100644
--- a/klm/lm/builder/ngram.hh
+++ b/klm/lm/builder/ngram.hh
@@ -53,7 +53,7 @@ class NGram {
     Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
 
     uint64_t &Count() { return Value().count; }
-    const uint64_t Count() const { return Value().count; }
+    uint64_t Count() const { return Value().count; }
 
     std::size_t Order() const { return end_ - begin_; }
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index a40fd2fb..a26654a6 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -304,5 +304,26 @@ template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>;
 template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
 
 } // namespace detail
+
+base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) {
+  RecognizeBinary(file_name, model_type);
+  switch (model_type) {
+    case PROBING:
+      return new ProbingModel(file_name, config);
+    case REST_PROBING:
+      return new RestProbingModel(file_name, config);
+    case TRIE:
+      return new TrieModel(file_name, config);
+    case QUANT_TRIE:
+      return new QuantTrieModel(file_name, config);
+    case ARRAY_TRIE:
+      return new ArrayTrieModel(file_name, config);
+    case QUANT_ARRAY_TRIE:
+      return new QuantArrayTrieModel(file_name, config);
+    default:
+      UTIL_THROW(FormatLoadException, "Confused by model type " << model_type);
+  }
+}
+
 } // namespace ngram
 } // namespace lm
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 13ff864e..60f55110 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -153,6 +153,11 @@ LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>);
 typedef ::lm::ngram::ProbingVocabulary Vocabulary;
 typedef ProbingModel Model;
 
+/* Autorecognize the file type, load, and return the virtual base class.  Don't
+ * use the virtual base class if you can avoid it.  Instead, use the above
+ * classes as template arguments to your own virtual feature function. */
+base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING);
+
 } // namespace ngram
 } // namespace lm
```
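The `LoadVirtual` factory added in `model.cc`/`model.hh` is the main functional change of this merge: it recognizes the binary format on disk and returns the matching concrete model behind the `lm::base::Model` interface. A minimal usage sketch (the caller `ScoreOneWord` is hypothetical, error handling and the vocabulary lookup that produces the `WordIndex` are elided, and only calls visible in this diff are used):

```cpp
#include "lm/model.hh"

#include <string.h>
#include <vector>

// Hypothetical caller: load a model of unknown binary format and score a
// single word transition through the virtual interface.
float ScoreOneWord(const char *file_name, lm::WordIndex word) {
  lm::base::Model *model = lm::ngram::LoadVirtual(file_name);
  // Behind this interface, states are opaque blobs of StateSize() bytes.
  std::vector<char> in(model->StateSize()), out(model->StateSize());
  memcpy(&in[0], model->BeginSentenceMemory(), model->StateSize());
  float ret = model->Score(&in[0], word, &out[0]);
  delete model;
  return ret;
}
```

As the new comment in `model.hh` warns, callers that know the model type at compile time should prefer the concrete classes; the virtual interface costs an indirection per query.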
```diff
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 2d6f15b2..62275d27 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -54,7 +54,7 @@ template <class Weights> class ActivateUnigram {
     Weights *modify_;
 };
 
-// Find the lower order entry, inserting blanks along the way as necessary.   
+// Find the lower order entry, inserting blanks along the way as necessary.
 template <class Value> void FindLower(
     const std::vector<uint64_t> &keys,
     typename Value::Weights &unigram,
@@ -64,7 +64,7 @@ template <class Value> void FindLower(
   typename Value::ProbingEntry entry;
   // Backoff will always be 0.0.  We'll get the probability and rest in another pass.
   entry.value.backoff = kNoExtensionBackoff;
-  // Go back and find the longest right-aligned entry, informing it that it extends left.  Normally this will match immediately, but sometimes SRI is dumb.   
+  // Go back and find the longest right-aligned entry, informing it that it extends left.  Normally this will match immediately, but sometimes SRI is dumb.
   for (int lower = keys.size() - 2; ; --lower) {
     if (lower == -1) {
       between.push_back(&unigram);
@@ -77,11 +77,11 @@ template <class Value> void FindLower(
   }
 }
 
-// Between usually has  single entry, the value to adjust.  But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.   
+// Between usually has  single entry, the value to adjust.  But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
 template <class Added, class Build> void AdjustLower(
     const Added &added,
     const Build &build,
-    std::vector<typename Build::Value::Weights *> &between, 
+    std::vector<typename Build::Value::Weights *> &between,
     const unsigned int n,
     const std::vector<WordIndex> &vocab_ids,
     typename Build::Value::Weights *unigrams,
@@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower(
   }
   typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
   float prob = -fabs(between.back()->prob);
-  // Order of the n-gram on which probabilities are based.   
+  // Order of the n-gram on which probabilities are based.
   unsigned char basis = n - between.size();
   assert(basis != 0);
   typename Build::Value::Weights **change = &between.back();
   // Skip the basis.
   --change;
   if (basis == 1) {
-    // Hallucinate a bigram based on a unigram's backoff and a unigram probability.   
+    // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
     float &backoff = unigrams[vocab_ids[1]].backoff;
     SetExtension(backoff);
     prob += backoff;
@@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower(
   typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());
   build.MarkExtends(**i, added);
   const typename Value::Weights *longer = *i;
-  // Everything has probability but is not marked as extending.   
+  // Everything has probability but is not marked as extending.
   for (++i; i != between.end(); ++i) {
     build.MarkExtends(**i, *longer);
     longer = *i;
   }
 }
 
-// Continue marking lower entries even they know that they extend left.  This is used for upper/lower bounds.   
+// Continue marking lower entries even they know that they extend left.  This is used for upper/lower bounds.
 template <class Build> void MarkLower(
     const std::vector<uint64_t> &keys,
     const Build &build,
@@ -144,15 +144,15 @@ template <class Build> void MarkLower(
     int start_order,
     const typename Build::Value::Weights &longer) {
   if (start_order == 0) return;
-  typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
-  // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.   
+  // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
   for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) {
     if (even_lower == -1) {
       build.MarkExtends(unigram, longer);
       return;
     }
-    middle[even_lower].UnsafeMutableFind(keys[even_lower], iter);
-    if (!build.MarkExtends(iter->value, longer)) return;
+    if (!build.MarkExtends(
+          middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value,
+          longer)) return;
   }
 }
 
@@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams(
     Store &store,
     PositiveProbWarn &warn) {
   typedef typename Build::Value Value;
-  typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
 
   assert(n >= 2);
   ReadNGramHeader(f, n);
@@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
     for (unsigned int h = 1; h < n - 1; ++h) {
       keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
     }
-    // Initially the sign bit is on, indicating it does not extend left.  Most already have this but there might +0.0.   
+    // Initially the sign bit is on, indicating it does not extend left.  Most already have this but there might +0.0.
     util::SetSign(entry.value.prob);
     entry.key = keys[n-2];
@@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
 } // namespace
 
 namespace detail {
- 
+
 template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
   std::size_t allocated = Unigram::Size(counts[0]);
   unigram_ = Unigram(start, counts[0], allocated);
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 00595796..9d067bc2 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -71,7 +71,7 @@ template <class Value> class HashedSearch {
     static const bool kDifferentRest = Value::kDifferentRest;
     static const unsigned int kVersion = 0;
 
-    // TODO: move probing_multiplier here with next binary file format update.   
+    // TODO: move probing_multiplier here with next binary file format update.
     static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
 
     static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -102,14 +102,9 @@ template <class Value> class HashedSearch {
       return ret;
     }
 
-#pragma GCC diagnostic ignored "-Wuninitialized"
     MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
       node = extend_pointer;
-      typename Middle::ConstIterator found;
-      bool got = middle_[extend_length - 2].Find(extend_pointer, found);
-      assert(got);
-      (void)got;
-      return MiddlePointer(found->value);
+      return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value);
     }
 
     MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
@@ -126,14 +121,14 @@ template <class Value> class HashedSearch {
     }
 
     LongestPointer LookupLongest(WordIndex word, const Node &node) const {
-      // Sign bit is always on because longest n-grams do not extend left.   
+      // Sign bit is always on because longest n-grams do not extend left.
       typename Longest::ConstIterator found;
       if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
       return LongestPointer(found->value.prob);
     }
 
-    // Generate a node without necessarily checking that it actually exists.   
-    // Optionally return false if it's know to not exist.   
+    // Generate a node without necessarily checking that it actually exists.
+    // Optionally return false if it's know to not exist.
     bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
       assert(begin != end);
       node = static_cast<Node>(*begin);
@@ -144,7 +139,7 @@ template <class Value> class HashedSearch {
     }
 
   private:
-    // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.   
+    // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
     void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
 
     template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
 
@@ -153,7 +148,7 @@ template <class Value> class HashedSearch {
       public:
         Unigram() {}
 
-        Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : 
+        Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
           unigram_(static_cast<typename Value::Weights*>(start))
 #ifdef DEBUG
          ,  count_(count)
```
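Several hunks in `search_hashed.cc` and `search_hashed.hh` replace the Find-then-assert pattern with the terser `MustFind` / `UnsafeMutableMustFind` calls. A sketch of the idiom being factored out (illustrative only; `Table` is assumed to mirror the `util::ProbingHashTable` interface used above, and the real helpers live in `util/probing_hash_table.hh`, which is not part of this diff):

```cpp
#include <cassert>
#include <stdint.h>

// Illustrative MustFind-style wrapper over a Find-style hash table API:
// look a key up, assert it is present, and return the iterator directly.
template <class Table>
typename Table::ConstIterator MustFindSketch(const Table &table, uint64_t key) {
  typename Table::ConstIterator found;
  bool got = table.Find(key, found);
  assert(got);  // callers guarantee the key exists
  (void)got;    // avoid an unused-variable warning in NDEBUG builds
  return found;
}
```

Folding the assertion into the table also lets the `Unpack` hunk drop its `#pragma GCC diagnostic ignored "-Wuninitialized"` workaround, since no iterator crosses the call site uninitialized anymore.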
```diff
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 1264baf5..60be416b 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -41,7 +41,8 @@ template <class Quant, class Bhiksha> class TrieSearch {
     static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
       Quant::UpdateConfigFromBinary(fd, counts, config);
       util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
-      Bhiksha::UpdateConfigFromBinary(fd, config);
+      // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2.
+      if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config);
     }
 
     static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
diff --git a/klm/lm/state.hh b/klm/lm/state.hh
index d8e6c132..a6b9accb 100644
--- a/klm/lm/state.hh
+++ b/klm/lm/state.hh
@@ -91,7 +91,7 @@ inline uint64_t hash_value(const Left &left) {
 }
 
 struct ChartState {
-  bool operator==(const ChartState &other) {
+  bool operator==(const ChartState &other) const {
     return (right == other.right) && (left == other.left);
   }
diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh
index 6a5a0196..17f064b2 100644
--- a/klm/lm/virtual_interface.hh
+++ b/klm/lm/virtual_interface.hh
@@ -6,6 +6,7 @@
 #include "util/string_piece.hh"
 
 #include <string>
+#include <string.h>
 
 namespace lm {
 namespace base {
@@ -119,7 +120,9 @@ class Model {
     size_t StateSize() const { return state_size_; }
     const void *BeginSentenceMemory() const { return begin_sentence_memory_; }
+    void BeginSentenceWrite(void *to) const { memcpy(to, begin_sentence_memory_, StateSize()); }
     const void *NullContextMemory() const { return null_context_memory_; }
+    void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }
 
     // Requires in_state != out_state
     virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 3902f117..226ae438 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -25,7 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len);
 inline uint64_t HashForVocab(const StringPiece &str) {
   return HashForVocab(str.data(), str.length());
 }
-class ProbingVocabularyHeader;
+struct ProbingVocabularyHeader;
 } // namespace detail
 
 class WriteWordsWrapper : public EnumerateVocab {
```
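The `virtual_interface.hh` hunk adds write-style accessors so callers can initialize opaque state buffers without doing the `memcpy` themselves, which is why `<string.h>` joins the includes. A small sketch (`InitStates` is hypothetical; the model would come from `LoadVirtual` above):

```cpp
#include "lm/virtual_interface.hh"

#include <vector>

void InitStates(const lm::base::Model &model) {
  std::vector<char> begin(model.StateSize()), null_ctx(model.StateSize());
  // Before this change, callers copied the opaque state by hand:
  //   memcpy(&begin[0], model.BeginSentenceMemory(), model.StateSize());
  // The new helpers encapsulate that copy:
  model.BeginSentenceWrite(&begin[0]);   // state for the <s> context
  model.NullContextWrite(&null_ctx[0]);  // state for the empty context
}
```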

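Also worth a note: the `state.hh` hunk const-qualifies `ChartState::operator==`, a const-correctness fix. Without the qualifier, comparing two states through const references does not compile, as in this minimal hypothetical snippet:

```cpp
#include "lm/state.hh"

// Compiles only with the const-qualified operator== from this merge.
bool SameState(const lm::ngram::ChartState &a, const lm::ngram::ChartState &b) {
  return a == b;
}
```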