diff options
author | Kenneth Heafield <github@kheafield.com> | 2013-06-18 11:34:20 -0700 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2013-06-18 11:34:20 -0700 |
commit | 535d4016ec5179cb673b697c2e81500a2097924c (patch) | |
tree | 4ae43b02d23317f37017a93fd12552b55c8d2a06 /klm | |
parent | 5dc790adc222db09c25b8be1b7a443a142f70180 (diff) |
lazy dd880b4 including kenlm 6eef0f1
Diffstat (limited to 'klm')
-rw-r--r-- | klm/lm/builder/lmplz_main.cc | 15 | ||||
-rw-r--r-- | klm/lm/builder/ngram.hh | 2 | ||||
-rw-r--r-- | klm/lm/model.cc | 21 | ||||
-rw-r--r-- | klm/lm/model.hh | 5 | ||||
-rw-r--r-- | klm/lm/search_hashed.cc | 29 | ||||
-rw-r--r-- | klm/lm/search_hashed.hh | 19 | ||||
-rw-r--r-- | klm/lm/state.hh | 2 | ||||
-rw-r--r-- | klm/lm/virtual_interface.hh | 3 | ||||
-rw-r--r-- | klm/lm/vocab.hh | 2 | ||||
-rw-r--r-- | klm/search/Makefile.am | 17 | ||||
-rw-r--r-- | klm/search/context.hh | 12 | ||||
-rw-r--r-- | klm/search/edge_generator.cc | 12 | ||||
-rw-r--r-- | klm/search/vertex.cc | 204 | ||||
-rw-r--r-- | klm/search/vertex.hh | 121 | ||||
-rw-r--r-- | klm/search/vertex_generator.hh | 36 | ||||
-rw-r--r-- | klm/util/double-conversion/utils.h | 6 | ||||
-rw-r--r-- | klm/util/file.cc | 14 | ||||
-rw-r--r-- | klm/util/pool.cc | 4 | ||||
-rw-r--r-- | klm/util/probing_hash_table.hh | 23 | ||||
-rw-r--r-- | klm/util/proxy_iterator.hh | 25 | ||||
-rw-r--r-- | klm/util/sized_iterator.hh | 21 | ||||
-rw-r--r-- | klm/util/stream/chain.hh | 2 | ||||
-rw-r--r-- | klm/util/usage.cc | 15 |
23 files changed, 418 insertions, 192 deletions
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc index 1e086dcc..c87abdb8 100644 --- a/klm/lm/builder/lmplz_main.cc +++ b/klm/lm/builder/lmplz_main.cc @@ -52,13 +52,14 @@ int main(int argc, char *argv[]) { std::cerr << "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n" "Please cite:\n" - "@inproceedings{kenlm,\n" - "author = {Kenneth Heafield},\n" - "title = {{KenLM}: Faster and Smaller Language Model Queries},\n" - "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n" - "month = {July}, year={2011},\n" - "address = {Edinburgh, UK},\n" - "publisher = {Association for Computational Linguistics},\n" + "@inproceedings{Heafield-estimate,\n" + " author = {Kenneth Heafield and Ivan Pouzyrevsky and Jonathan H. Clark and Philipp Koehn},\n" + " title = {Scalable Modified {Kneser-Ney} Language Model Estimation},\n" + " year = {2013},\n" + " month = {8},\n" + " booktitle = {Proceedings of the 51st Annual Meeting of the Association for Computational Linguistics},\n" + " address = {Sofia, Bulgaria},\n" + " url = {http://kheafield.com/professional/edinburgh/estimate\\_paper.pdf},\n" "}\n\n" "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n" "the model (-o) is the only mandatory option. As this is an on-disk program,\n" diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh index 2984ed0b..f5681516 100644 --- a/klm/lm/builder/ngram.hh +++ b/klm/lm/builder/ngram.hh @@ -53,7 +53,7 @@ class NGram { Payload &Value() { return *reinterpret_cast<Payload *>(end_); } uint64_t &Count() { return Value().count; } - const uint64_t Count() const { return Value().count; } + uint64_t Count() const { return Value().count; } std::size_t Order() const { return end_ - begin_; } diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a40fd2fb..a26654a6 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -304,5 +304,26 @@ template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiks template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>; } // namespace detail + +base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) { + RecognizeBinary(file_name, model_type); + switch (model_type) { + case PROBING: + return new ProbingModel(file_name, config); + case REST_PROBING: + return new RestProbingModel(file_name, config); + case TRIE: + return new TrieModel(file_name, config); + case QUANT_TRIE: + return new QuantTrieModel(file_name, config); + case ARRAY_TRIE: + return new ArrayTrieModel(file_name, config); + case QUANT_ARRAY_TRIE: + return new QuantArrayTrieModel(file_name, config); + default: + UTIL_THROW(FormatLoadException, "Confused by model type " << model_type); + } +} + } // namespace ngram } // namespace lm diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 13ff864e..60f55110 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -153,6 +153,11 @@ LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<Separat typedef ::lm::ngram::ProbingVocabulary Vocabulary; typedef ProbingModel Model; +/* Autorecognize the file type, load, and return the virtual base class. Don't + * use the virtual base class if you can avoid it. Instead, use the above + * classes as template arguments to your own virtual feature function.*/ +base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING); + } // namespace ngram } // namespace lm diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 2d6f15b2..62275d27 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -54,7 +54,7 @@ template <class Weights> class ActivateUnigram { Weights *modify_; }; -// Find the lower order entry, inserting blanks along the way as necessary. +// Find the lower order entry, inserting blanks along the way as necessary. template <class Value> void FindLower( const std::vector<uint64_t> &keys, typename Value::Weights &unigram, @@ -64,7 +64,7 @@ template <class Value> void FindLower( typename Value::ProbingEntry entry; // Backoff will always be 0.0. We'll get the probability and rest in another pass. entry.value.backoff = kNoExtensionBackoff; - // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. + // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. for (int lower = keys.size() - 2; ; --lower) { if (lower == -1) { between.push_back(&unigram); @@ -77,11 +77,11 @@ template <class Value> void FindLower( } } -// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. +// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. template <class Added, class Build> void AdjustLower( const Added &added, const Build &build, - std::vector<typename Build::Value::Weights *> &between, + std::vector<typename Build::Value::Weights *> &between, const unsigned int n, const std::vector<WordIndex> &vocab_ids, typename Build::Value::Weights *unigrams, @@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower( } typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; float prob = -fabs(between.back()->prob); - // Order of the n-gram on which probabilities are based. + // Order of the n-gram on which probabilities are based. unsigned char basis = n - between.size(); assert(basis != 0); typename Build::Value::Weights **change = &between.back(); // Skip the basis. --change; if (basis == 1) { - // Hallucinate a bigram based on a unigram's backoff and a unigram probability. + // Hallucinate a bigram based on a unigram's backoff and a unigram probability. float &backoff = unigrams[vocab_ids[1]].backoff; SetExtension(backoff); prob += backoff; @@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower( typename std::vector<typename Value::Weights *>::const_iterator i(between.begin()); build.MarkExtends(**i, added); const typename Value::Weights *longer = *i; - // Everything has probability but is not marked as extending. + // Everything has probability but is not marked as extending. for (++i; i != between.end(); ++i) { build.MarkExtends(**i, *longer); longer = *i; } } -// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds. +// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds. template <class Build> void MarkLower( const std::vector<uint64_t> &keys, const Build &build, @@ -144,15 +144,15 @@ template <class Build> void MarkLower( int start_order, const typename Build::Value::Weights &longer) { if (start_order == 0) return; - typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter; - // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code. + // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code. for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) { if (even_lower == -1) { build.MarkExtends(unigram, longer); return; } - middle[even_lower].UnsafeMutableFind(keys[even_lower], iter); - if (!build.MarkExtends(iter->value, longer)) return; + if (!build.MarkExtends( + middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value, + longer)) return; } } @@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams( Store &store, PositiveProbWarn &warn) { typedef typename Build::Value Value; - typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; assert(n >= 2); ReadNGramHeader(f, n); @@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams( for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } - // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. + // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. util::SetSign(entry.value.prob); entry.key = keys[n-2]; @@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams( } // namespace namespace detail { - + template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { std::size_t allocated = Unigram::Size(counts[0]); unigram_ = Unigram(start, counts[0], allocated); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 00595796..9d067bc2 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -71,7 +71,7 @@ template <class Value> class HashedSearch { static const bool kDifferentRest = Value::kDifferentRest; static const unsigned int kVersion = 0; - // TODO: move probing_multiplier here with next binary file format update. + // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { @@ -102,14 +102,9 @@ template <class Value> class HashedSearch { return ret; } -#pragma GCC diagnostic ignored "-Wuninitialized" MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { node = extend_pointer; - typename Middle::ConstIterator found; - bool got = middle_[extend_length - 2].Find(extend_pointer, found); - assert(got); - (void)got; - return MiddlePointer(found->value); + return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value); } MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const { @@ -126,14 +121,14 @@ template <class Value> class HashedSearch { } LongestPointer LookupLongest(WordIndex word, const Node &node) const { - // Sign bit is always on because longest n-grams do not extend left. + // Sign bit is always on because longest n-grams do not extend left. typename Longest::ConstIterator found; if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer(); return LongestPointer(found->value.prob); } - // Generate a node without necessarily checking that it actually exists. - // Optionally return false if it's know to not exist. + // Generate a node without necessarily checking that it actually exists. + // Optionally return false if it's know to not exist. bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { assert(begin != end); node = static_cast<Node>(*begin); @@ -144,7 +139,7 @@ template <class Value> class HashedSearch { } private: - // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. + // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); @@ -153,7 +148,7 @@ template <class Value> class HashedSearch { public: Unigram() {} - Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : + Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : unigram_(static_cast<typename Value::Weights*>(start)) #ifdef DEBUG , count_(count) diff --git a/klm/lm/state.hh b/klm/lm/state.hh index d8e6c132..a6b9accb 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -91,7 +91,7 @@ inline uint64_t hash_value(const Left &left) { } struct ChartState { - bool operator==(const ChartState &other) { + bool operator==(const ChartState &other) const { return (right == other.right) && (left == other.left); } diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 6a5a0196..17f064b2 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -6,6 +6,7 @@ #include "util/string_piece.hh" #include <string> +#include <string.h> namespace lm { namespace base { @@ -119,7 +120,9 @@ class Model { size_t StateSize() const { return state_size_; } const void *BeginSentenceMemory() const { return begin_sentence_memory_; } + void BeginSentenceWrite(void *to) const { memcpy(to, begin_sentence_memory_, StateSize()); } const void *NullContextMemory() const { return null_context_memory_; } + void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); } // Requires in_state != out_state virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 3902f117..226ae438 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -25,7 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len); inline uint64_t HashForVocab(const StringPiece &str) { return HashForVocab(str.data(), str.length()); } -class ProbingVocabularyHeader; +struct ProbingVocabularyHeader; } // namespace detail class WriteWordsWrapper : public EnumerateVocab { diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index 03554276..b8c8a050 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -1,23 +1,10 @@ noinst_LIBRARIES = libksearch.a libksearch_a_SOURCES = \ - applied.hh \ - config.hh \ - context.hh \ - dedupe.hh \ - edge.hh \ - edge_generator.hh \ - header.hh \ - nbest.hh \ - rule.hh \ - types.hh \ - vertex.hh \ - vertex_generator.hh \ edge_generator.cc \ nbest.cc \ rule.cc \ - vertex.cc \ - vertex_generator.cc + vertex.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/context.hh b/klm/search/context.hh index 08f21bbf..c3c8e53b 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -12,16 +12,6 @@ class ContextBase { public: explicit ContextBase(const Config &config) : config_(config) {} - VertexNode *NewVertexNode() { - VertexNode *ret = vertex_node_pool_.construct(); - assert(ret); - return ret; - } - - void DeleteVertexNode(VertexNode *node) { - vertex_node_pool_.destroy(node); - } - unsigned int PopLimit() const { return config_.PopLimit(); } Score LMWeight() const { return config_.LMWeight(); } @@ -29,8 +19,6 @@ class ContextBase { const Config &GetConfig() const { return config_; } private: - boost::object_pool<VertexNode> vertex_node_pool_; - Config config_; }; diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index eacf5de5..dd9d61e4 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -54,20 +54,20 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) { Arity victim = 0; Arity victim_completed; Arity incomplete; + unsigned char lowest_niceness = 255; // Select victim or return if complete. { Arity completed = 0; - unsigned char lowest_length = 255; for (Arity i = 0; i != arity; ++i) { if (top_nt[i].Complete()) { ++completed; - } else if (top_nt[i].Length() < lowest_length) { - lowest_length = top_nt[i].Length(); + } else if (top_nt[i].Niceness() < lowest_niceness) { + lowest_niceness = top_nt[i].Niceness(); victim = i; victim_completed = completed; } } - if (lowest_length == 255) { + if (lowest_niceness == 255) { return top; } incomplete = arity - completed; @@ -92,10 +92,14 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) { generate_.push(alternate); } +#ifndef NDEBUG + Score before = top.GetScore(); +#endif // top is now the continuation. FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); // TODO: dedupe? generate_.push(top); + assert(lowest_niceness != 254 || top.GetScore() == before); // Invalid indicates no new hypothesis generated. return PartialEdge(); diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 45842982..bf40810e 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -2,6 +2,8 @@ #include "search/context.hh" +#include <boost/unordered_map.hpp> + #include <algorithm> #include <functional> @@ -11,45 +13,193 @@ namespace search { namespace { -struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { - bool operator()(const VertexNode *first, const VertexNode *second) const { - return first->Bound() > second->Bound(); +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); + +class DivideLeft { + public: + explicit DivideLeft(unsigned char index) + : index_(index) {} + + uint64_t operator()(const lm::ngram::ChartState &state) const { + return (index_ < state.left.length) ? + state.left.pointers[index_] : + (kCompleteAdd - state.left.full); + } + + private: + unsigned char index_; +}; + +class DivideRight { + public: + explicit DivideRight(unsigned char index) + : index_(index) {} + + uint64_t operator()(const lm::ngram::ChartState &state) const { + return (index_ < state.right.length) ? + static_cast<uint64_t>(state.right.words[index_]) : + (kCompleteAdd - state.left.full); + } + + private: + unsigned char index_; +}; + +template <class Divider> void Split(const Divider ÷r, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) { + // Map from divider to index in extend. + typedef boost::unordered_map<uint64_t, std::size_t> Lookup; + Lookup lookup; + for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) { + uint64_t key = divider(i->state); + std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size()))); + if (res.second) { + extend.resize(extend.size() + 1); + extend.back().AppendHypothesis(*i); + } else { + extend[res.first->second].AppendHypothesis(*i); + } } + //assert((extend.size() != 1) || (hypos.size() == 1)); +} + +lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) { + return right.words[index]; +} + +uint64_t Identify(const lm::ngram::Left &left, unsigned char index) { + return left.pointers[index]; +} + +template <class Side> class DetermineSame { + public: + DetermineSame(const Side &side, unsigned char guaranteed) + : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {} + + void Consider(const Side &other) { + if (shared_ != other.length) { + complete_ = false; + if (shared_ > other.length) + shared_ = other.length; + } + for (unsigned char i = guaranteed_; i < shared_; ++i) { + if (Identify(side_, i) != Identify(other, i)) { + shared_ = i; + complete_ = false; + return; + } + } + } + + unsigned char Shared() const { return shared_; } + + bool Complete() const { return complete_; } + + private: + const Side &side_; + unsigned char guaranteed_, shared_; + bool complete_; }; +// Custom enum to save memory: valid values of policy_. +// Alternate and there is still alternation to do. +const unsigned char kPolicyAlternate = 0; +// Branch based on left state only, because right ran out or this is a left tree. +const unsigned char kPolicyOneLeft = 1; +// Branch based on right state only. +const unsigned char kPolicyOneRight = 2; +// Reveal everything in the next branch. Used to terminate the left/right policies. +// static const unsigned char kPolicyEverything = 3; + +} // namespace + +namespace { +struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> { + bool operator()(const HypoState &first, const HypoState &second) const { + return first.score > second.score; + } +}; } // namespace -void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) { - if (Complete()) { - assert(end_); - assert(extend_.empty()); - return; +void VertexNode::FinishRoot() { + std::sort(hypos_.begin(), hypos_.end(), GreaterByScore()); + extend_.clear(); + // HACK: extend to one hypo so that root can be blank. + state_.left.full = false; + state_.left.length = 0; + state_.right.length = 0; + right_full_ = false; + niceness_ = 0; + policy_ = kPolicyAlternate; + if (hypos_.size() == 1) { + extend_.resize(1); + extend_.front().AppendHypothesis(hypos_.front()); + extend_.front().FinishedAppending(0, 0); + } + if (hypos_.empty()) { + bound_ = -INFINITY; + } else { + bound_ = hypos_.front().score; } - if (extend_.size() == 1) { - parent_ptr = extend_[0]; - extend_[0]->RecursiveSortAndSet(context, parent_ptr); - context.DeleteVertexNode(this); - return; +} + +void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) { + assert(!hypos_.empty()); + assert(extend_.empty()); + bound_ = hypos_.front().score; + state_ = hypos_.front().state; + bool all_full = state_.left.full; + bool all_non_full = !state_.left.full; + DetermineSame<lm::ngram::Left> left(state_.left, common_left); + DetermineSame<lm::ngram::Right> right(state_.right, common_right); + for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) { + all_full &= i->state.left.full; + all_non_full &= !i->state.left.full; + left.Consider(i->state.left); + right.Consider(i->state.right); } - for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->RecursiveSortAndSet(context, *i); + state_.left.full = all_full && left.Complete(); + right_full_ = all_full && right.Complete(); + state_.left.length = left.Shared(); + state_.right.length = right.Shared(); + + if (!all_full && !all_non_full) { + policy_ = kPolicyAlternate; + } else if (left.Complete()) { + policy_ = kPolicyOneRight; + } else if (right.Complete()) { + policy_ = kPolicyOneLeft; + } else { + policy_ = kPolicyAlternate; } - std::sort(extend_.begin(), extend_.end(), GreaterByBound()); - bound_ = extend_.front()->Bound(); + niceness_ = state_.left.length + state_.right.length; } -void VertexNode::SortAndSet(ContextBase &context) { - // This is the root. The root might be empty. - if (extend_.empty()) { - bound_ = -INFINITY; - return; +void VertexNode::BuildExtend() { + // Already built. + if (!extend_.empty()) return; + // Nothing to build since this is a leaf. + if (hypos_.size() <= 1) return; + bool left_branch = true; + switch (policy_) { + case kPolicyAlternate: + left_branch = (state_.left.length <= state_.right.length); + break; + case kPolicyOneLeft: + left_branch = true; + break; + case kPolicyOneRight: + left_branch = false; + break; + } + if (left_branch) { + Split(DivideLeft(state_.left.length), hypos_, extend_); + } else { + Split(DivideRight(state_.right.length), hypos_, extend_); } - // The root cannot be replaced. There's always one transition. - for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->RecursiveSortAndSet(context, *i); + for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) { + // TODO: provide more here for branching? + i->FinishedAppending(state_.left.length, state_.right.length); } - std::sort(extend_.begin(), extend_.end(), GreaterByBound()); - bound_ = extend_.front()->Bound(); } } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index ca9a4fcd..81c3cfed 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,59 +16,74 @@ namespace search { class ContextBase; +struct HypoState { + History history; + lm::ngram::ChartState state; + Score score; +}; + class VertexNode { public: - VertexNode() : end_() {} - - void InitRoot() { - extend_.clear(); - state_.left.full = false; - state_.left.length = 0; - state_.right.length = 0; - right_full_ = false; - end_ = History(); + VertexNode() {} + + void InitRoot() { hypos_.clear(); } + + /* The steps of building a VertexNode: + * 1. Default construct. + * 2. AppendHypothesis at least once, possibly multiple times. + * 3. FinishAppending with the number of words on left and right guaranteed + * to be common. + * 4. If !Complete(), call BuildExtend to construct the extensions + */ + // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending. + void AppendHypothesis(const NBestComplete &best) { + assert(hypos_.empty() || !(hypos_.front().state == *best.state)); + HypoState hypo; + hypo.history = best.history; + hypo.state = *best.state; + hypo.score = best.score; + hypos_.push_back(hypo); + } + void AppendHypothesis(const HypoState &hypo) { + hypos_.push_back(hypo); } - lm::ngram::ChartState &MutableState() { return state_; } - bool &MutableRightFull() { return right_full_; } + // Sort hypotheses for the root. + void FinishRoot(); - void AddExtend(VertexNode *next) { - extend_.push_back(next); - } + void FinishedAppending(const unsigned char common_left, const unsigned char common_right); - void SetEnd(History end, Score score) { - assert(!end_); - end_ = end; - bound_ = score; - } - - void SortAndSet(ContextBase &context); + void BuildExtend(); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_ && extend_.empty(); + return hypos_.empty() && extend_.empty(); } bool Complete() const { - return end_; + // HACK: prevent root from being complete. TODO: allow root to be complete. + return hypos_.size() == 1 && extend_.empty(); } const lm::ngram::ChartState &State() const { return state_; } bool RightFull() const { return right_full_; } + // Priority relative to other non-terminals. 0 is highest. + unsigned char Niceness() const { return niceness_; } + Score Bound() const { return bound_; } - unsigned char Length() const { - return state_.left.length + state_.right.length; - } - // Will be invalid unless this is a leaf. - History End() const { return end_; } + History End() const { + assert(hypos_.size() == 1); + return hypos_.front().history; + } - const VertexNode &operator[](size_t index) const { - return *extend_[index]; + VertexNode &operator[](size_t index) { + assert(!extend_.empty()); + return extend_[index]; } size_t Size() const { @@ -76,22 +91,26 @@ class VertexNode { } private: - void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + // Hypotheses to be split. + std::vector<HypoState> hypos_; - std::vector<VertexNode*> extend_; + std::vector<VertexNode> extend_; lm::ngram::ChartState state_; bool right_full_; + unsigned char niceness_; + + unsigned char policy_; + Score bound_; - History end_; }; class PartialVertex { public: PartialVertex() {} - explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {} bool Empty() const { return back_->Empty(); } @@ -100,17 +119,14 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } - - unsigned char Length() const { return back_->Length(); } + Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); } - bool HasAlternative() const { - return index_ + 1 < back_->Size(); - } + unsigned char Niceness() const { return back_->Niceness(); } // Split into continuation and alternative, rendering this the continuation. bool Split(PartialVertex &alternative) { assert(!Complete()); + back_->BuildExtend(); bool ret; if (index_ + 1 < back_->Size()) { alternative.index_ = index_ + 1; @@ -129,7 +145,7 @@ class PartialVertex { } private: - const VertexNode *back_; + VertexNode *back_; unsigned int index_; }; @@ -139,10 +155,21 @@ class Vertex { public: Vertex() {} - PartialVertex RootPartial() const { return PartialVertex(root_); } + //PartialVertex RootFirst() const { return PartialVertex(right_); } + PartialVertex RootAlternate() { return PartialVertex(root_); } + //PartialVertex RootLast() const { return PartialVertex(left_); } + + bool Empty() const { + return root_.Empty(); + } + + Score Bound() const { + return root_.Bound(); + } - History BestChild() const { - PartialVertex top(RootPartial()); + History BestChild() { + // left_ and right_ are not set at the root. + PartialVertex top(RootAlternate()); if (top.Empty()) { return History(); } else { @@ -158,6 +185,12 @@ class Vertex { template <class Output> friend class VertexGenerator; template <class Output> friend class RootVertexGenerator; VertexNode root_; + + // These will not be set for the root vertex. + // Branches only on left state. + //VertexNode left_; + // Branches only on right state. + //VertexNode right_; }; } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 646b8189..91000012 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -4,10 +4,8 @@ #include "search/edge.hh" #include "search/types.hh" #include "search/vertex.hh" -#include "util/exception.hh" #include <boost/unordered_map.hpp> -#include <boost/version.hpp> namespace lm { namespace ngram { @@ -19,45 +17,25 @@ namespace search { class ContextBase; -#if BOOST_VERSION > 104200 -// Parallel structure to VertexNode. -struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map<uint64_t, Trie> extend; -}; - -void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); - -#endif // BOOST_VERSION - // Output makes the single-best or n-best list. template <class Output> class VertexGenerator { public: - VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { - gen.root_.InitRoot(); - } + VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {} void NewHypothesis(PartialEdge partial) { nbest_.Add(existing_[hash_value(partial.CompletedState())], partial); } void FinishedSearch() { -#if BOOST_VERSION > 104200 - Trie root; - root.under = &gen_.root_; + gen_.root_.InitRoot(); for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { - AddHypothesis(context_, root, nbest_.Complete(i->second)); + gen_.root_.AppendHypothesis(nbest_.Complete(i->second)); } existing_.clear(); - root.under->SortAndSet(context_); -#else - UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); -#endif + gen_.root_.FinishRoot(); } - const Vertex &Generating() const { return gen_; } + Vertex &Generating() { return gen_; } private: ContextBase &context_; @@ -84,8 +62,8 @@ template <class Output> class RootVertexGenerator { void FinishedSearch() { gen_.root_.InitRoot(); - NBestComplete completed(out_.Complete(combine_)); - gen_.root_.SetEnd(completed.history, completed.score); + gen_.root_.AppendHypothesis(out_.Complete(combine_)); + gen_.root_.FinishRoot(); } private: diff --git a/klm/util/double-conversion/utils.h b/klm/util/double-conversion/utils.h index 2bd71605..9ccb3b65 100644 --- a/klm/util/double-conversion/utils.h +++ b/klm/util/double-conversion/utils.h @@ -299,7 +299,11 @@ template <class Dest, class Source> inline Dest BitCast(const Source& source) { // Compile time assertion: sizeof(Dest) == sizeof(Source) // A compile error here means your Dest and Source have different sizes. - typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1]; + typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1] +#if __GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ >= 8 + __attribute__((unused)) +#endif + ; Dest dest; memmove(&dest, &source, sizeof(dest)); diff --git a/klm/util/file.cc b/klm/util/file.cc index c7d8e23b..bef04cb1 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -116,7 +116,7 @@ std::size_t GuardLarge(std::size_t size) { // The following operating systems have broken read/write/pread/pwrite that // only supports up to 2^31. #if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID) - return std::min(static_cast<std::size_t>(INT_MAX), size); + return std::min(static_cast<std::size_t>(static_cast<unsigned>(-1)), size); #else return size; #endif @@ -209,7 +209,7 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) { #endif errno = 0; do { - ret = + ret = #if defined(_WIN32) || defined(_WIN64) _write #else @@ -229,7 +229,7 @@ void WriteOrThrow(FILE *to, const void *data, std::size_t size) { } void FSyncOrThrow(int fd) { -// Apparently windows doesn't have fsync? +// Apparently windows doesn't have fsync? #if !defined(_WIN32) && !defined(_WIN64) UTIL_THROW_IF_ARG(-1 == fsync(fd), FDException, (fd), "while syncing"); #endif @@ -248,7 +248,7 @@ template <> struct CheckOffT<8> { typedef CheckOffT<sizeof(off_t)>::True IgnoredType; #endif -// Can't we all just get along? +// Can't we all just get along? void InternalSeek(int fd, int64_t off, int whence) { if ( #if defined(_WIN32) || defined(_WIN64) @@ -457,9 +457,9 @@ bool TryName(int fd, std::string &out) { std::ostringstream convert; convert << fd; name += convert.str(); - + struct stat sb; - if (-1 == lstat(name.c_str(), &sb)) + if (-1 == lstat(name.c_str(), &sb)) return false; out.resize(sb.st_size + 1); ssize_t ret = readlink(name.c_str(), &out[0], sb.st_size + 1); @@ -471,7 +471,7 @@ bool TryName(int fd, std::string &out) { } out.resize(ret); // Don't use the non-file names. - if (!out.empty() && out[0] != '/') + if (!out.empty() && out[0] != '/') return false; return true; #endif diff --git a/klm/util/pool.cc b/klm/util/pool.cc index 429ba158..db72a8ec 100644 --- a/klm/util/pool.cc +++ b/klm/util/pool.cc @@ -25,7 +25,9 @@ void Pool::FreeAll() { } void *Pool::More(std::size_t size) { - std::size_t amount = std::max(static_cast<size_t>(32) << free_list_.size(), size); + // Double until we hit 2^21 (2 MB). Then grow in 2 MB blocks. + std::size_t desired_size = static_cast<size_t>(32) << std::min(static_cast<std::size_t>(16), free_list_.size()); + std::size_t amount = std::max(desired_size, size); uint8_t *ret = static_cast<uint8_t*>(MallocOrThrow(amount)); free_list_.push_back(ret); current_ = ret + size; diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 57866ff9..51a2944d 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -109,9 +109,20 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry if (equal_(got, key)) { out = i; return true; } if (equal_(got, invalid_)) return false; if (++i == end_) i = begin_; - } + } + } + + // Like UnsafeMutableFind, but the key must be there. + template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) { + for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) { + Key got(i->GetKey()); + if (equal_(got, key)) { return i; } + assert(!equal_(got, invalid_)); + if (++i == end_) i = begin_; + } } + template <class Key> bool Find(const Key key, ConstIterator &out) const { #ifdef DEBUG assert(initialized_); @@ -124,6 +135,16 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry } } + // Like Find but we're sure it must be there. + template <class Key> ConstIterator MustFind(const Key key) const { + for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) { + Key got(i->GetKey()); + if (equal_(got, key)) { return i; } + assert(!equal_(got, invalid_)); + if (++i == end_) i = begin_; + } + } + void Clear() { Entry invalid; invalid.SetKey(invalid_); diff --git a/klm/util/proxy_iterator.hh b/klm/util/proxy_iterator.hh index 121a45fa..0ee1716f 100644 --- a/klm/util/proxy_iterator.hh +++ b/klm/util/proxy_iterator.hh @@ -6,11 +6,11 @@ /* This is a RandomAccessIterator that uses a proxy to access the underlying * data. Useful for packing data at bit offsets but still using STL - * algorithms. + * algorithms. * * Normally I would use boost::iterator_facade but some people are too lazy to * install boost and still want to use my language model. It's amazing how - * many operators an iterator has. + * many operators an iterator has. * * The Proxy needs to provide: * class InnerIterator; @@ -22,15 +22,15 @@ * operator<(InnerIterator) * operator+=(std::ptrdiff_t) * operator-(InnerIterator) - * and of course whatever Proxy needs to dereference it. + * and of course whatever Proxy needs to dereference it. * - * It's also a good idea to specialize std::swap for Proxy. + * It's also a good idea to specialize std::swap for Proxy. */ namespace util { template <class Proxy> class ProxyIterator { private: - // Self. + // Self. typedef ProxyIterator<Proxy> S; typedef typename Proxy::InnerIterator InnerIterator; @@ -38,16 +38,21 @@ template <class Proxy> class ProxyIterator { typedef std::random_access_iterator_tag iterator_category; typedef typename Proxy::value_type value_type; typedef std::ptrdiff_t difference_type; - typedef Proxy reference; + typedef Proxy & reference; typedef Proxy * pointer; ProxyIterator() {} - // For cast from non const to const. + // For cast from non const to const. template <class AlternateProxy> ProxyIterator(const ProxyIterator<AlternateProxy> &in) : p_(*in) {} explicit ProxyIterator(const Proxy &p) : p_(p) {} - // p_'s operator= does value copying, but here we want iterator copying. + // p_'s swap does value swapping, but here we want iterator swapping + friend inline void swap(ProxyIterator<Proxy> &first, ProxyIterator<Proxy> &second) { + swap(first.I(), second.I()); + } + + // p_'s operator= does value copying, but here we want iterator copying. S &operator=(const S &other) { I() = other.I(); return *this; @@ -72,8 +77,8 @@ template <class Proxy> class ProxyIterator { std::ptrdiff_t operator-(const S &other) const { return I() - other.I(); } - Proxy operator*() { return p_; } - const Proxy operator*() const { return p_; } + Proxy &operator*() { return p_; } + const Proxy &operator*() const { return p_; } Proxy *operator->() { return &p_; } const Proxy *operator->() const { return &p_; } Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); } diff --git a/klm/util/sized_iterator.hh b/klm/util/sized_iterator.hh index cf998953..dce8f229 100644 --- a/klm/util/sized_iterator.hh +++ b/klm/util/sized_iterator.hh @@ -36,6 +36,11 @@ class SizedInnerIterator { void *Data() { return ptr_; } std::size_t EntrySize() const { return size_; } + friend inline void swap(SizedInnerIterator &first, SizedInnerIterator &second) { + std::swap(first.ptr_, second.ptr_); + std::swap(first.size_, second.size_); + } + private: uint8_t *ptr_; std::size_t size_; @@ -64,9 +69,19 @@ class SizedProxy { const void *Data() const { return inner_.Data(); } void *Data() { return inner_.Data(); } + /** + // TODO: this (deep) swap was recently added. why? if any std heap sort etc + // algs are using swap, that's going to be worse performance than using + // =. i'm not sure why we *want* a deep swap. if C++11 compilers are + // choosing between move constructor and swap, then we'd better implement a + // (deep) move constructor. it may also be that this is moot since i made + // ProxyIterator a reference and added a shallow ProxyIterator swap? (I + // need Ken or someone competent to judge whether that's correct also. - + // let me know at graehl@gmail.com + */ friend void swap(SizedProxy &first, SizedProxy &second) { std::swap_ranges( - static_cast<char*>(first.inner_.Data()), + static_cast<char*>(first.inner_.Data()), static_cast<char*>(first.inner_.Data()) + first.inner_.EntrySize(), static_cast<char*>(second.inner_.Data())); } @@ -87,7 +102,7 @@ typedef ProxyIterator<SizedProxy> SizedIterator; inline SizedIterator SizedIt(void *ptr, std::size_t size) { return SizedIterator(SizedProxy(ptr, size)); } -// Useful wrapper for a comparison function i.e. sort. +// Useful wrapper for a comparison function i.e. sort. template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public std::binary_function<const Proxy &, const Proxy &, bool> { public: explicit SizedCompare(const Delegate &delegate = Delegate()) : delegate_(delegate) {} @@ -106,7 +121,7 @@ template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public } const Delegate &GetDelegate() const { return delegate_; } - + private: const Delegate delegate_; }; diff --git a/klm/util/stream/chain.hh b/klm/util/stream/chain.hh index 154b9b33..0cc83a85 100644 --- a/klm/util/stream/chain.hh +++ b/klm/util/stream/chain.hh @@ -122,7 +122,7 @@ class Chain { threads_.push_back(new Thread(Complete(), kRecycle)); } - Chain &operator>>(const Recycler &recycle) { + Chain &operator>>(const Recycler &) { CompleteLoop(); return *this; } diff --git a/klm/util/usage.cc b/klm/util/usage.cc index 5fa3cc9a..8db375e1 100644 --- a/klm/util/usage.cc +++ b/klm/util/usage.cc @@ -21,6 +21,21 @@ namespace util { #if !defined(_WIN32) && !defined(_WIN64) namespace { + +// On Mac OS X, clock_gettime is not implemented. +// CLOCK_MONOTONIC is not defined either. +#ifdef __MACH__ +#define CLOCK_MONOTONIC 0 + +int clock_gettime(int clk_id, struct timespec *tp) { + struct timeval tv; + gettimeofday(&tv, NULL); + tp->tv_sec = tv.tv_sec; + tp->tv_nsec = tv.tv_usec * 1000; + return 0; +} +#endif // __MACH__ + float FloatSec(const struct timeval &tv) { return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000.0); } |