diff options
author | Paul Baltescu <pauldb89@gmail.com> | 2013-06-19 16:46:52 +0100 |
---|---|---|
committer | Paul Baltescu <pauldb89@gmail.com> | 2013-06-19 16:46:52 +0100 |
commit | 22e6ab01aebca3e9012b07f9600153c7b593996e (patch) | |
tree | 844d1a650a302114ae619d37b8778ab66207a834 /klm/lm | |
parent | 02099a01350a41a99ec400e9b29df08a01d88979 (diff) | |
parent | 0dc7755f7fb1ef15db5a60c70866aa61b6367898 (diff) |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'klm/lm')
-rw-r--r-- | klm/lm/builder/lmplz_main.cc | 15 | ||||
-rw-r--r-- | klm/lm/builder/ngram.hh | 2 | ||||
-rw-r--r-- | klm/lm/model.cc | 21 | ||||
-rw-r--r-- | klm/lm/model.hh | 5 | ||||
-rw-r--r-- | klm/lm/search_hashed.cc | 29 | ||||
-rw-r--r-- | klm/lm/search_hashed.hh | 19 | ||||
-rw-r--r-- | klm/lm/search_trie.hh | 3 | ||||
-rw-r--r-- | klm/lm/state.hh | 2 | ||||
-rw-r--r-- | klm/lm/virtual_interface.hh | 3 | ||||
-rw-r--r-- | klm/lm/vocab.hh | 2 |
10 files changed, 63 insertions, 38 deletions
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc index 1e086dcc..c87abdb8 100644 --- a/klm/lm/builder/lmplz_main.cc +++ b/klm/lm/builder/lmplz_main.cc @@ -52,13 +52,14 @@ int main(int argc, char *argv[]) { std::cerr << "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n" "Please cite:\n" - "@inproceedings{kenlm,\n" - "author = {Kenneth Heafield},\n" - "title = {{KenLM}: Faster and Smaller Language Model Queries},\n" - "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n" - "month = {July}, year={2011},\n" - "address = {Edinburgh, UK},\n" - "publisher = {Association for Computational Linguistics},\n" + "@inproceedings{Heafield-estimate,\n" + " author = {Kenneth Heafield and Ivan Pouzyrevsky and Jonathan H. Clark and Philipp Koehn},\n" + " title = {Scalable Modified {Kneser-Ney} Language Model Estimation},\n" + " year = {2013},\n" + " month = {8},\n" + " booktitle = {Proceedings of the 51st Annual Meeting of the Association for Computational Linguistics},\n" + " address = {Sofia, Bulgaria},\n" + " url = {http://kheafield.com/professional/edinburgh/estimate\\_paper.pdf},\n" "}\n\n" "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n" "the model (-o) is the only mandatory option. As this is an on-disk program,\n" diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh index 2984ed0b..f5681516 100644 --- a/klm/lm/builder/ngram.hh +++ b/klm/lm/builder/ngram.hh @@ -53,7 +53,7 @@ class NGram { Payload &Value() { return *reinterpret_cast<Payload *>(end_); } uint64_t &Count() { return Value().count; } - const uint64_t Count() const { return Value().count; } + uint64_t Count() const { return Value().count; } std::size_t Order() const { return end_ - begin_; } diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a40fd2fb..a26654a6 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -304,5 +304,26 @@ template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiks template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>; } // namespace detail + +base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) { + RecognizeBinary(file_name, model_type); + switch (model_type) { + case PROBING: + return new ProbingModel(file_name, config); + case REST_PROBING: + return new RestProbingModel(file_name, config); + case TRIE: + return new TrieModel(file_name, config); + case QUANT_TRIE: + return new QuantTrieModel(file_name, config); + case ARRAY_TRIE: + return new ArrayTrieModel(file_name, config); + case QUANT_ARRAY_TRIE: + return new QuantArrayTrieModel(file_name, config); + default: + UTIL_THROW(FormatLoadException, "Confused by model type " << model_type); + } +} + } // namespace ngram } // namespace lm diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 13ff864e..60f55110 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -153,6 +153,11 @@ LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<Separat typedef ::lm::ngram::ProbingVocabulary Vocabulary; typedef ProbingModel Model; +/* Autorecognize the file type, load, and return the virtual base class. Don't + * use the virtual base class if you can avoid it. Instead, use the above + * classes as template arguments to your own virtual feature function.*/ +base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING); + } // namespace ngram } // namespace lm diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 2d6f15b2..62275d27 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -54,7 +54,7 @@ template <class Weights> class ActivateUnigram { Weights *modify_; }; -// Find the lower order entry, inserting blanks along the way as necessary. +// Find the lower order entry, inserting blanks along the way as necessary. template <class Value> void FindLower( const std::vector<uint64_t> &keys, typename Value::Weights &unigram, @@ -64,7 +64,7 @@ template <class Value> void FindLower( typename Value::ProbingEntry entry; // Backoff will always be 0.0. We'll get the probability and rest in another pass. entry.value.backoff = kNoExtensionBackoff; - // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. + // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. for (int lower = keys.size() - 2; ; --lower) { if (lower == -1) { between.push_back(&unigram); @@ -77,11 +77,11 @@ template <class Value> void FindLower( } } -// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. +// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. template <class Added, class Build> void AdjustLower( const Added &added, const Build &build, - std::vector<typename Build::Value::Weights *> &between, + std::vector<typename Build::Value::Weights *> &between, const unsigned int n, const std::vector<WordIndex> &vocab_ids, typename Build::Value::Weights *unigrams, @@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower( } typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; float prob = -fabs(between.back()->prob); - // Order of the n-gram on which probabilities are based. + // Order of the n-gram on which probabilities are based. unsigned char basis = n - between.size(); assert(basis != 0); typename Build::Value::Weights **change = &between.back(); // Skip the basis. --change; if (basis == 1) { - // Hallucinate a bigram based on a unigram's backoff and a unigram probability. + // Hallucinate a bigram based on a unigram's backoff and a unigram probability. float &backoff = unigrams[vocab_ids[1]].backoff; SetExtension(backoff); prob += backoff; @@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower( typename std::vector<typename Value::Weights *>::const_iterator i(between.begin()); build.MarkExtends(**i, added); const typename Value::Weights *longer = *i; - // Everything has probability but is not marked as extending. + // Everything has probability but is not marked as extending. for (++i; i != between.end(); ++i) { build.MarkExtends(**i, *longer); longer = *i; } } -// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds. +// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds. template <class Build> void MarkLower( const std::vector<uint64_t> &keys, const Build &build, @@ -144,15 +144,15 @@ template <class Build> void MarkLower( int start_order, const typename Build::Value::Weights &longer) { if (start_order == 0) return; - typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter; - // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code. + // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code. for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) { if (even_lower == -1) { build.MarkExtends(unigram, longer); return; } - middle[even_lower].UnsafeMutableFind(keys[even_lower], iter); - if (!build.MarkExtends(iter->value, longer)) return; + if (!build.MarkExtends( + middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value, + longer)) return; } } @@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams( Store &store, PositiveProbWarn &warn) { typedef typename Build::Value Value; - typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; assert(n >= 2); ReadNGramHeader(f, n); @@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams( for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } - // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. + // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. util::SetSign(entry.value.prob); entry.key = keys[n-2]; @@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams( } // namespace namespace detail { - + template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { std::size_t allocated = Unigram::Size(counts[0]); unigram_ = Unigram(start, counts[0], allocated); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 00595796..9d067bc2 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -71,7 +71,7 @@ template <class Value> class HashedSearch { static const bool kDifferentRest = Value::kDifferentRest; static const unsigned int kVersion = 0; - // TODO: move probing_multiplier here with next binary file format update. + // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { @@ -102,14 +102,9 @@ template <class Value> class HashedSearch { return ret; } -#pragma GCC diagnostic ignored "-Wuninitialized" MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { node = extend_pointer; - typename Middle::ConstIterator found; - bool got = middle_[extend_length - 2].Find(extend_pointer, found); - assert(got); - (void)got; - return MiddlePointer(found->value); + return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value); } MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const { @@ -126,14 +121,14 @@ template <class Value> class HashedSearch { } LongestPointer LookupLongest(WordIndex word, const Node &node) const { - // Sign bit is always on because longest n-grams do not extend left. + // Sign bit is always on because longest n-grams do not extend left. typename Longest::ConstIterator found; if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer(); return LongestPointer(found->value.prob); } - // Generate a node without necessarily checking that it actually exists. - // Optionally return false if it's know to not exist. + // Generate a node without necessarily checking that it actually exists. + // Optionally return false if it's know to not exist. bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { assert(begin != end); node = static_cast<Node>(*begin); @@ -144,7 +139,7 @@ template <class Value> class HashedSearch { } private: - // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. + // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); @@ -153,7 +148,7 @@ template <class Value> class HashedSearch { public: Unigram() {} - Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : + Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : unigram_(static_cast<typename Value::Weights*>(start)) #ifdef DEBUG , count_(count) diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 1264baf5..60be416b 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -41,7 +41,8 @@ template <class Quant, class Bhiksha> class TrieSearch { static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); - Bhiksha::UpdateConfigFromBinary(fd, config); + // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2. + if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config); } static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { diff --git a/klm/lm/state.hh b/klm/lm/state.hh index d8e6c132..a6b9accb 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -91,7 +91,7 @@ inline uint64_t hash_value(const Left &left) { } struct ChartState { - bool operator==(const ChartState &other) { + bool operator==(const ChartState &other) const { return (right == other.right) && (left == other.left); } diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 6a5a0196..17f064b2 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -6,6 +6,7 @@ #include "util/string_piece.hh" #include <string> +#include <string.h> namespace lm { namespace base { @@ -119,7 +120,9 @@ class Model { size_t StateSize() const { return state_size_; } const void *BeginSentenceMemory() const { return begin_sentence_memory_; } + void BeginSentenceWrite(void *to) const { memcpy(to, begin_sentence_memory_, StateSize()); } const void *NullContextMemory() const { return null_context_memory_; } + void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); } // Requires in_state != out_state virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 3902f117..226ae438 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -25,7 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len); inline uint64_t HashForVocab(const StringPiece &str) { return HashForVocab(str.data(), str.length()); } -class ProbingVocabularyHeader; +struct ProbingVocabularyHeader; } // namespace detail class WriteWordsWrapper : public EnumerateVocab { |