diff options
author | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-22 16:01:23 -0500 |
---|---|---|
committer | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-22 16:01:23 -0500 |
commit | 597d89c11db53e91bc011eab70fd613bbe6453e8 (patch) | |
tree | 83c87c07d1ff6d3ee4e3b1626f7eddd49c61095b /klm | |
parent | 65e958ff2678a41c22be7171456a63f002ef370b (diff) | |
parent | 201af2acd394415a05072fbd53d42584875aa4b4 (diff) |
Merge branch 'master' of git://github.com/redpony/cdec
Diffstat (limited to 'klm')
48 files changed, 1438 insertions, 610 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index efa67056..39c4a9b6 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -16,11 +16,11 @@ namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; -// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). +// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; const long int kMagicVersion = 5; -// Old binary files built on 32-bit machines have this header. +// Old binary files built on 32-bit machines have this header. // TODO: eliminate with next binary release. struct OldSanity { char magic[sizeof(kMagicBytes)]; @@ -39,7 +39,7 @@ struct OldSanity { }; -// Test values aligned to 8 bytes. +// Test values aligned to 8 bytes. struct Sanity { char magic[ALIGN8(sizeof(kMagicBytes))]; float zero_f, one_f, minus_half_f; @@ -101,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { - // Grow the file to accomodate the search, using zeros. + // Grow the file to accomodate the search, using zeros. try { util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); } catch (util::ErrnoException &e) { @@ -114,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t return reinterpret_cast<uint8_t*>(backing.search.get()); } // mmap it now. - // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. + // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. std::size_t page_size = util::SizePage(); std::size_t alignment_cruft = adjusted_vocab % page_size; backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); @@ -122,7 +122,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t } else { util::MapAnonymous(memory_size, backing.search); return reinterpret_cast<uint8_t*>(backing.search.get()); - } + } } void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { @@ -140,7 +140,7 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_ util::FSyncOrThrow(backing.file.get()); break; } - // header and vocab share the same mmap. The header is written here because we know the counts. + // header and vocab share the same mmap. The header is written here because we know the counts. Parameters params = Parameters(); params.counts = counts; params.fixed.order = counts.size(); @@ -160,7 +160,7 @@ namespace detail { bool IsBinaryFormat(int fd) { const uint64_t size = util::SizeFile(fd); if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; - // Try reading the header. + // Try reading the header. util::scoped_memory memory; try { util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory); @@ -214,7 +214,7 @@ void SeekPastHeader(int fd, const Parameters ¶ms) { uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { const uint64_t file_size = util::SizeFile(backing.file.get()); - // The header is smaller than a page, so we have to map the whole header as well. + // The header is smaller than a page, so we have to map the whole header as well. std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); @@ -233,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) { if (config.write_mmap || !config.messages) return; if (config.arpa_complain == Config::ALL) { *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; - } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { + } else if (config.arpa_complain == Config::EXPENSIVE && + (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; } } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index f9d988ca..9520c41c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -6,6 +6,7 @@ namespace lm { namespace ngram { Config::Config() : + show_progress(true), messages(&std::cerr), enumerate_vocab(NULL), unknown_missing(COMPLAIN), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 739cee9c..0de7b7c6 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -11,46 +11,52 @@ /* Configuration for ngram model. Separate header to reduce pollution. */ namespace lm { - + class EnumerateVocab; namespace ngram { struct Config { - // EFFECTIVE FOR BOTH ARPA AND BINARY READS + // EFFECTIVE FOR BOTH ARPA AND BINARY READS + + // (default true) print progress bar to messages + bool show_progress; // Where to log messages including the progress bar. Set to NULL for // silence. std::ostream *messages; + std::ostream *ProgressMessages() const { + return show_progress ? messages : 0; + } + // This will be called with every string in the vocabulary. See // enumerate_vocab.hh for more detail. Config does not take ownership; you - // are still responsible for deleting it (or stack allocating). + // are still responsible for deleting it (or stack allocating). EnumerateVocab *enumerate_vocab; - // ONLY EFFECTIVE WHEN READING ARPA - // What to do when <unk> isn't in the provided model. + // What to do when <unk> isn't in the provided model. WarningAction unknown_missing; - // What to do when <s> or </s> is missing from the model. - // If THROW_UP, the exception will be of type util::SpecialWordMissingException. + // What to do when <s> or </s> is missing from the model. + // If THROW_UP, the exception will be of type util::SpecialWordMissingException. WarningAction sentence_marker_missing; // What to do with a positive log probability. For COMPLAIN and SILENT, map - // to 0. + // to 0. WarningAction positive_log_probability; - // The probability to substitute for <unk> if it's missing from the model. + // The probability to substitute for <unk> if it's missing from the model. // No effect if the model has <unk> or unknown_missing == THROW_UP. float unknown_missing_logprob; // Size multiplier for probing hash table. Must be > 1. Space is linear in // this. Time is probing_multiplier / (probing_multiplier - 1). No effect - // for sorted variant. + // for sorted variant. // If you find yourself setting this to a low number, consider using the - // TrieModel which has lower memory consumption. + // TrieModel which has lower memory consumption. float probing_multiplier; // Amount of memory to use for building. The actual memory usage will be @@ -58,10 +64,10 @@ struct Config { // models. std::size_t building_memory; - // Template for temporary directory appropriate for passing to mkdtemp. + // Template for temporary directory appropriate for passing to mkdtemp. // The characters XXXXXX are appended before passing to mkdtemp. Only // applies to trie. If NULL, defaults to write_mmap. If that's NULL, - // defaults to input file name. + // defaults to input file name. const char *temporary_directory_prefix; // Level of complaining to do when loading from ARPA instead of binary format. @@ -69,49 +75,46 @@ struct Config { ARPALoadComplain arpa_complain; // While loading an ARPA file, also write out this binary format file. Set - // to NULL to disable. + // to NULL to disable. const char *write_mmap; enum WriteMethod { - WRITE_MMAP, // Map the file directly. - WRITE_AFTER // Write after we're done. + WRITE_MMAP, // Map the file directly. + WRITE_AFTER // Write after we're done. }; WriteMethod write_method; - // Include the vocab in the binary file? Only effective if write_mmap != NULL. + // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; - // Left rest options. Only used when the model includes rest costs. + // Left rest options. Only used when the model includes rest costs. enum RestFunction { REST_MAX, // Maximum of any score to the left - REST_LOWER, // Use lower-order files given below. + REST_LOWER, // Use lower-order files given below. }; RestFunction rest_function; - // Only used for REST_LOWER. + // Only used for REST_LOWER. std::vector<std::string> rest_lower_files; - // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used - // to quantize (and one of the remaining backoffs will be 0). + // to quantize (and one of the remaining backoffs will be 0). uint8_t prob_bits, backoff_bits; // Bhiksha compression (simple form). Only works with trie. uint8_t pointer_bhiksha_bits; - - + // ONLY EFFECTIVE WHEN READING BINARY - + // How to get the giant array into memory: lazy mmap, populate, read etc. - // See util/mmap.hh for details of MapMethod. + // See util/mmap.hh for details of MapMethod. util::LoadMethod load_method; - - // Set defaults. + // Set defaults. Config(); }; diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 8c27232e..85c1ea37 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -51,36 +51,36 @@ namespace ngram { template <class M> class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } void BeginSentence() { - out_.right = model_.BeginSentenceState(); - // out_.left is empty. + out_->right = model_.BeginSentenceState(); + // out_->left is empty. left_done_ = true; } void Terminal(WordIndex word) { - State copy(out_.right); - FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + State copy(out_->right); + FullScoreReturn ret(model_.FullScore(copy, word, out_->right)); if (left_done_) { prob_ += ret.prob; return; } if (ret.independent_left) { prob_ += ret.prob; left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; - if (out_.right.length != copy.length + 1) + if (out_->right.length != copy.length + 1) left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob = 0.0) { prob_ = prob; - out_ = in; + *out_ = in; left_done_ = in.left.full; } @@ -89,23 +89,23 @@ template <class M> class RuleScore { if (!in.left.length) { if (in.left.full) { - for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; } return; } - if (!out_.right.length) { - out_.right = in.right; + if (!out_->right.length) { + out_->right = in.right; if (left_done_) { prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1); return; } - if (out_.left.length) { + if (out_->left.length) { left_done_ = true; } else { - out_.left = in.left; + out_->left = in.left; left_done_ = in.left.full; } return; @@ -113,10 +113,10 @@ template <class M> class RuleScore { float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1]; float *back = backoffs, *back2 = backoffs2; - unsigned char next_use = out_.right.length; + unsigned char next_use = out_->right.length; // First word - if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; + if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return; // Words after the first, so extending a bigram to begin with for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { @@ -127,54 +127,58 @@ template <class M> class RuleScore { if (in.left.full) { for (const float *i = back; i != back + next_use; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; return; } // Right state was minimized, so it's already independent of the new words to the left. if (in.right.length < in.left.length) { - out_.right = in.right; + out_->right = in.right; return; } // Shift exisiting words down. - for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) { *(i + in.right.length) = *i; } // Add words from in.right. - std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + std::copy(in.right.words, in.right.words + in.right.length, out_->right.words); // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. - std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); - std::copy(back, back + next_use, out_.right.backoff + in.right.length); - out_.right.length = in.right.length + next_use; + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff); + std::copy(back, back + next_use, out_->right.backoff + in.right.length); + out_->right.length = in.right.length + next_use; } float Finish() { // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. - out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1); + out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1); return prob_; } void Reset() { prob_ = 0.0; left_done_ = false; - out_.left.length = 0; - out_.right.length = 0; + out_->left.length = 0; + out_->right.length = 0; + } + void Reset(ChartState &replacement) { + out_ = &replacement; + Reset(); } private: bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { ProcessRet(model_.ExtendLeft( - out_.right.words, out_.right.words + next_use, // Words to extend into + out_->right.words, out_->right.words + next_use, // Words to extend into back_in, // Backoffs to use in.left.pointers[extend_length - 1], extend_length, // Words to be extended back_out, // Backoffs for the next score next_use)); // Length of n-gram to use in next scoring. - if (next_use != out_.right.length) { + if (next_use != out_->right.length) { left_done_ = true; if (!next_use) { // Early exit. - out_.right = in.right; + out_->right = in.right; prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1); return true; } @@ -193,13 +197,13 @@ template <class M> class RuleScore { left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; } const M &model_; - ChartState &out_; + ChartState *out_; bool left_done_; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 989f8324..3eb97ccd 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -4,9 +4,6 @@ * (kMaxOrder - 1) * sizeof(float) bytes instead of * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead */ -#ifndef KENLM_MAX_ORDER -#define KENLM_MAX_ORDER 6 -#endif #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 2fd20481..a40fd2fb 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -37,7 +37,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { LoadLM(file, config, *this); - // g++ prints warnings unless these are fully initialized. + // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); begin_sentence.length = 1; begin_sentence.words[0] = vocab_.BeginSentence(); @@ -69,8 +69,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { - // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. - util::FilePiece f(backing_.file.release(), file, config.messages); + // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. + util::FilePiece f(backing_.file.release(), file, config.ProgressMessages()); try { std::vector<uint64_t> counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. @@ -80,14 +80,14 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); if (config.write_mmap) { WriteWordsWrapper wrap(config.enumerate_vocab); vocab_.ConfigureEnumerate(&wrap, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size()); + wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config)); } else { vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -95,7 +95,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (!vocab_.SawUnk()) { assert(config.unknown_missing != THROW_UP); - // Default probabilities for unknown. + // Default probabilities for unknown. search_.UnknownUnigram().backoff = 0.0; search_.UnknownUnigram().prob = config.unknown_missing_logprob; } @@ -147,7 +147,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { - // Generate a state from context. + // Generate a state from context. context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); if (context_rend == context_rbegin) { out_state.length = 0; @@ -191,7 +191,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, ret.rest = ptr.Rest(); ret.prob = ptr.Prob(); ret.extend_left = extend_pointer; - // If this function is called, then it does depend on left words. + // If this function is called, then it does depend on left words. ret.independent_left = false; } float subtract_me = ret.rest; @@ -199,7 +199,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, next_use = extend_length; ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret); next_use -= extend_length; - // Charge backoffs. + // Charge backoffs. for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; ret.prob -= subtract_me; ret.rest -= subtract_me; @@ -209,7 +209,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, namespace { // Do a paraonoid copy of history, assuming new_word has already been copied // (hence the -1). out_state.length could be zero so I avoided using -// std::copy. +// std::copy. void CopyRemainingHistory(const WordIndex *from, State &out_state) { WordIndex *out = out_state.words + 1; const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1; @@ -217,18 +217,19 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) { } } // namespace -/* Ugly optimized function. Produce a score excluding backoff. - * The search goes in increasing order of ngram length. +/* Ugly optimized function. Produce a score excluding backoff. + * The search goes in increasing order of ngram length. * Context goes backward, so context_begin is the word immediately preceeding - * new_word. + * new_word. */ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff( const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const { + assert(new_word < vocab_.Bound()); FullScoreReturn ret; - // ret.ngram_length contains the last known non-blank ngram length. + // ret.ngram_length contains the last known non-blank ngram length. ret.ngram_length = 1; typename Search::Node node; @@ -237,9 +238,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, ret.prob = uni.Prob(); ret.rest = uni.Rest(); - // This is the length of the context that should be used for continuation to the right. + // This is the length of the context that should be used for continuation to the right. out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; - // We'll write the word anyway since it will probably be used and does no harm being there. + // We'll write the word anyway since it will probably be used and does no harm being there. out_state.words[0] = new_word; if (context_rbegin == context_rend) return ret; diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index a1623834..2d6f15b2 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -231,7 +231,7 @@ template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { NoRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { @@ -239,19 +239,19 @@ template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, cons case Config::REST_MAX: { MaxRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; case Config::REST_LOWER: { LowerRestBuild<ProbingModel> build(config, counts.size(), vocab); - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; } } -template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { +template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { for (WordIndex i = 0; i < counts[0]; ++i) { build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]); } diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index a52f107b..00595796 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -147,7 +147,7 @@ template <class Value> class HashedSearch { // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); - template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); + template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); class Unigram { public: diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index debcfd07..1b0d9b26 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -55,7 +55,7 @@ struct ProbPointer { uint64_t index; }; -// Array of n-grams and float indices. +// Array of n-grams and float indices. class BackoffMessages { public: void Init(std::size_t entry_size) { @@ -100,7 +100,7 @@ class BackoffMessages { void Apply(float *const *const base, RecordReader &reader) { FinishedAdding(); if (current_ == allocated_) return; - // We'll also use the same buffer to record messages to blanks that they extend. + // We'll also use the same buffer to record messages to blanks that they extend. WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_); const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex); for (reader.Rewind(); reader && (current_ != allocated_); ) { @@ -109,7 +109,7 @@ class BackoffMessages { ++reader; break; case 1: - // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. + // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w; current_ += entry_size_; break; @@ -126,7 +126,7 @@ class BackoffMessages { break; } } - // Now this is a list of blanks that extend right. + // Now this is a list of blanks that extend right. entry_size_ = sizeof(WordIndex) * order; Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get())); current_ = (uint8_t*)backing_.get(); @@ -153,7 +153,7 @@ class BackoffMessages { private: void FinishedAdding() { Resize(current_ - (uint8_t*)backing_.get()); - // Sort requests in same order as files. + // Sort requests in same order as files. std::sort( util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)), util::SizedIterator(util::SizedProxy(current_, entry_size_)), @@ -220,7 +220,7 @@ class SRISucks { } private: - // This used to be one array. Then I needed to separate it by order for quantization to work. + // This used to be one array. Then I needed to separate it by order for quantization to work. std::vector<float> values_[KENLM_MAX_ORDER - 1]; BackoffMessages messages_[KENLM_MAX_ORDER - 1]; @@ -253,7 +253,7 @@ class FindBlanks { ++counts_.back(); } - // Unigrams wrote one past. + // Unigrams wrote one past. void Cleanup() { --counts_[0]; } @@ -270,15 +270,15 @@ class FindBlanks { SRISucks &sri_; }; -// Phase to actually write n-grams to the trie. +// Phase to actually write n-grams to the trie. template <class Quant, class Bhiksha> class WriteEntries { public: - WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : + WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), quant_(quant), unigrams_(unigrams), middle_(middle), - longest_(longest), + longest_(longest), bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)), order_(order), sri_(sri) {} @@ -328,7 +328,7 @@ struct Gram { const WordIndex *begin, *end; - // For queue, this is the direction we want. + // For queue, this is the direction we want. bool operator<(const Gram &other) const { return std::lexicographical_compare(other.begin, other.end, begin, end); } @@ -353,7 +353,7 @@ template <class Doing> class BlankManager { been_length_ = length; return; } - // There are blanks to insert starting with order blank. + // There are blanks to insert starting with order blank. unsigned char blank = cur - to + 1; UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); const float *lower_basis; @@ -363,7 +363,7 @@ template <class Doing> class BlankManager { assert(*lower_basis != kBadProb); doing_.MiddleBlank(blank, to, based_on, *lower_basis); *pre = *cur; - // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. + // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. basis_[blank - 1] = kBadProb; } *pre = *cur; @@ -377,7 +377,7 @@ template <class Doing> class BlankManager { unsigned char been_length_; float basis_[KENLM_MAX_ORDER]; - + Doing &doing_; }; @@ -451,7 +451,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re } void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { - // Fill unigram probabilities. + // Fill unigram probabilities. try { rewind(file); for (WordIndex i = 0; i < unigram_count; ++i) { @@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve util::scoped_memory unigrams; MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder); fixed_counts = finder.Counts(); } unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); @@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve inputs[i-2].Rewind(); } if (Quant::kTrain) { - util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); + util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), + config.ProgressMessages(), "Quantizing"); for (unsigned char i = 2; i < counts.size(); ++i) { TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } @@ -519,13 +520,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); } - // Fill entries except unigram probabilities. + // Fill entries except unigram probabilities. { WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); } - // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. + // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. for (unsigned char order = 2; order <= counts.size(); ++order) { const RecordReader &context = contexts[order - 2]; if (context) { @@ -541,13 +542,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve } /* Set ending offsets so the last entry will be sized properly */ - // Last entry for unigrams was already set. + // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { i->FinishedLoading((i+1)->InsertIndex(), config); } (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); - } + } } template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { @@ -595,7 +596,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ } else { temporary_prefix = file; } - // At least 1MB sorting memory. + // At least 1MB sorting memory. SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab); BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 11c27518..fd7f96dc 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { } *end_ = hashed; if (enumerate_) { - strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); + void *copied = string_backing_.Allocate(str.size()); + memcpy(copied, str.data(), str.size()); + strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size()); } ++end_; // This is 1 + the offset where it was inserted to make room for unk. @@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { if (enumerate_) { if (!strings_to_enumerate_.empty()) { - util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); util::JointSort(begin_, end_, values); } for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { @@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { enumerate_->Add(i + 1, strings_to_enumerate_[i]); } strings_to_enumerate_.clear(); + string_backing_.FreeAll(); } else { util::JointSort(begin_, end_, reorder_vocab + 1); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index de54eb06..3902f117 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,6 +4,7 @@ #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" +#include "util/pool.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" #include "util/string_piece.hh" @@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary { EnumerateVocab *enumerate_; // Actual strings. Used only when loading from ARPA and enumerate_ != NULL - std::vector<std::string> strings_to_enumerate_; + util::Pool string_backing_; + + std::vector<StringPiece> strings_to_enumerate_; }; #pragma pack(push) diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index ccc5b7f6..5aea33c2 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -2,10 +2,10 @@ noinst_LIBRARIES = libksearch.a libksearch_a_SOURCES = \ edge_generator.cc \ + nbest.cc \ rule.cc \ vertex.cc \ - vertex_generator.cc \ - weights.cc + vertex_generator.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/applied.hh b/klm/search/applied.hh new file mode 100644 index 00000000..bd659e5c --- /dev/null +++ b/klm/search/applied.hh @@ -0,0 +1,86 @@ +#ifndef SEARCH_APPLIED__ +#define SEARCH_APPLIED__ + +#include "search/edge.hh" +#include "search/header.hh" +#include "util/pool.hh" + +#include <math.h> + +namespace search { + +// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted. +template <class Below> class GenericApplied : public Header { + public: + GenericApplied() {} + + GenericApplied(void *location, PartialEdge partial) + : Header(location) { + memcpy(Base(), partial.Base(), kHeaderSize); + Below *child_out = Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = Below(part->End()); + } + + GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) { + SetScore(score); + SetNote(note); + } + + explicit GenericApplied(History from) : Header(from) {} + + + // These are arrays of length GetArity(). + Below *Children() { + return reinterpret_cast<Below*>(After()); + } + const Below *Children() const { + return reinterpret_cast<const Below*>(After()); + } + + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Below); + } +}; + +// Applied rule that references itself. +class Applied : public GenericApplied<Applied> { + private: + typedef GenericApplied<Applied> P; + + public: + Applied() {} + Applied(void *location, PartialEdge partial) : P(location, partial) {} + Applied(History from) : P(from) {} +}; + +// How to build single-best hypotheses. +class SingleBest { + public: + typedef PartialEdge Combine; + + void Add(PartialEdge &existing, PartialEdge add) const { + if (!existing.Valid() || existing.GetScore() < add.GetScore()) + existing = add; + } + + NBestComplete Complete(PartialEdge partial) { + if (!partial.Valid()) + return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY); + void *place_final = pool_.Allocate(Applied::Size(partial.GetArity())); + Applied(place_final, partial); + return NBestComplete( + place_final, + partial.CompletedState(), + partial.GetScore()); + } + + private: + util::Pool pool_; +}; + +} // namespace search + +#endif // SEARCH_APPLIED__ diff --git a/klm/search/config.hh b/klm/search/config.hh index ef8e2354..ba18c09e 100644 --- a/klm/search/config.hh +++ b/klm/search/config.hh @@ -1,23 +1,36 @@ #ifndef SEARCH_CONFIG__ #define SEARCH_CONFIG__ -#include "search/weights.hh" -#include "util/string_piece.hh" +#include "search/types.hh" namespace search { +struct NBestConfig { + explicit NBestConfig(unsigned int in_size) { + keep = in_size; + size = in_size; + } + + unsigned int keep, size; +}; + class Config { public: - Config(const Weights &weights, unsigned int pop_limit) : - weights_(weights), pop_limit_(pop_limit) {} + Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) : + lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {} - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return lm_weight_; } unsigned int PopLimit() const { return pop_limit_; } + const NBestConfig &GetNBest() const { return nbest_; } + private: - Weights weights_; + Score lm_weight_; + unsigned int pop_limit_; + + NBestConfig nbest_; }; } // namespace search diff --git a/klm/search/context.hh b/klm/search/context.hh index 62163144..08f21bbf 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -1,30 +1,16 @@ #ifndef SEARCH_CONTEXT__ #define SEARCH_CONTEXT__ -#include "lm/model.hh" #include "search/config.hh" -#include "search/final.hh" -#include "search/types.hh" #include "search/vertex.hh" -#include "util/exception.hh" -#include "util/pool.hh" #include <boost/pool/object_pool.hpp> -#include <boost/ptr_container/ptr_vector.hpp> - -#include <vector> namespace search { -class Weights; - class ContextBase { public: - explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - - util::Pool &FinalPool() { - return final_pool_; - } + explicit ContextBase(const Config &config) : config_(config) {} VertexNode *NewVertexNode() { VertexNode *ret = vertex_node_pool_.construct(); @@ -36,18 +22,16 @@ class ContextBase { vertex_node_pool_.destroy(node); } - unsigned int PopLimit() const { return pop_limit_; } + unsigned int PopLimit() const { return config_.PopLimit(); } - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return config_.LMWeight(); } - private: - util::Pool final_pool_; + const Config &GetConfig() const { return config_; } + private: boost::object_pool<VertexNode> vertex_node_pool_; - unsigned int pop_limit_; - - const Weights &weights_; + Config config_; }; template <class Model> class Context : public ContextBase { diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh new file mode 100644 index 00000000..7eaa3b95 --- /dev/null +++ b/klm/search/dedupe.hh @@ -0,0 +1,131 @@ +#ifndef SEARCH_DEDUPE__ +#define SEARCH_DEDUPE__ + +#include "lm/state.hh" +#include "search/edge_generator.hh" + +#include <boost/pool/object_pool.hpp> +#include <boost/unordered_map.hpp> + +namespace search { + +class Dedupe { + public: + Dedupe() {} + + PartialEdge AllocateEdge(Arity arity) { + return behind_.AllocateEdge(arity); + } + + void AddEdge(PartialEdge edge) { + edge.MutableFlags() = 0; + + uint64_t hash = 0; + const PartialVertex *v = edge.NT(); + const PartialVertex *v_end = v + edge.GetArity(); + for (; v != v_end; ++v) { + const void *ptr = v->Identify(); + hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash); + } + + const lm::ngram::ChartState *c = edge.Between(); + const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1; + for (; c != c_end; ++c) hash = hash_value(*c, hash); + + std::pair<Table::iterator, bool> ret(table_.insert(std::make_pair(hash, edge))); + if (!ret.second) FoundDupe(ret.first->second, edge); + } + + bool Empty() const { return behind_.Empty(); } + + template <class Model, class Output> void Search(Context<Model> &context, Output &output) { + for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) { + behind_.AddEdge(i->second); + } + Unpack<Output> unpack(output, *this); + behind_.Search(context, unpack); + } + + private: + void FoundDupe(PartialEdge &table, PartialEdge adding) { + if (table.GetFlags() & kPackedFlag) { + Packed &packed = *static_cast<Packed*>(table.GetNote().mut); + if (table.GetScore() >= adding.GetScore()) { + packed.others.push_back(adding); + return; + } + Note original(packed.original); + packed.original = adding.GetNote(); + adding.SetNote(table.GetNote()); + table.SetNote(original); + packed.others.push_back(table); + packed.starting = adding.GetScore(); + table = adding; + table.MutableFlags() |= kPackedFlag; + return; + } + PartialEdge loser; + if (adding.GetScore() > table.GetScore()) { + loser = table; + table = adding; + } else { + loser = adding; + } + // table is winner, loser is loser... + packed_.construct(table, loser); + } + + struct Packed { + Packed(PartialEdge winner, PartialEdge loser) + : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) { + winner.MutableNote().vp = this; + winner.MutableFlags() |= kPackedFlag; + loser.MutableFlags() &= ~kPackedFlag; + } + Note original; + Score starting; + std::vector<PartialEdge> others; + }; + + template <class Output> class Unpack { + public: + explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {} + + void NewHypothesis(PartialEdge edge) { + if (edge.GetFlags() & kPackedFlag) { + Packed &packed = *reinterpret_cast<Packed*>(edge.GetNote().mut); + edge.SetNote(packed.original); + edge.MutableFlags() = 0; + std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState); + for (std::vector<PartialEdge>::iterator i = packed.others.begin(); i != packed.others.end(); ++i) { + PartialEdge copy(owner_.AllocateEdge(edge.GetArity())); + copy.SetScore(edge.GetScore() - packed.starting + i->GetScore()); + copy.MutableFlags() = 0; + copy.SetNote(i->GetNote()); + memcpy(copy.NT(), edge.NT(), copy_size); + output_.NewHypothesis(copy); + } + } + output_.NewHypothesis(edge); + } + + void FinishedSearch() { + output_.FinishedSearch(); + } + + private: + Output &output_; + Dedupe &owner_; + }; + + EdgeGenerator behind_; + + typedef boost::unordered_map<uint64_t, PartialEdge> Table; + Table table_; + + boost::object_pool<Packed> packed_; + + static const uint16_t kPackedFlag = 1; +}; +} // namespace search +#endif // SEARCH_DEDUPE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 260159b1..eacf5de5 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -1,6 +1,7 @@ #include "search/edge_generator.hh" #include "lm/left.hh" +#include "lm/model.hh" #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" @@ -38,7 +39,7 @@ template <class Model> void FastScore(const Context<Model> &context, Arity victi *cover = *(cover + 1); } } - update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); + update.SetScore(update.GetScore() + adjustment * context.LMWeight()); } } // namespace diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 582c78b7..203942c6 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,6 @@ #define SEARCH_EDGE_GENERATOR__ #include "search/edge.hh" -#include "search/note.hh" #include "search/types.hh" #include <queue> diff --git a/klm/search/final.hh b/klm/search/final.hh deleted file mode 100644 index 50e62cf2..00000000 --- a/klm/search/final.hh +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SEARCH_FINAL__ -#define SEARCH_FINAL__ - -#include "search/header.hh" -#include "util/pool.hh" - -namespace search { - -// A full hypothesis with pointers to children. -class Final : public Header { - public: - Final() {} - - Final(util::Pool &pool, Score score, Arity arity, Note note) - : Header(pool.Allocate(Size(arity)), arity) { - SetScore(score); - SetNote(note); - } - - // These are arrays of length GetArity(). - Final *Children() { - return reinterpret_cast<Final*>(After()); - } - const Final *Children() const { - return reinterpret_cast<const Final*>(After()); - } - - private: - static std::size_t Size(Arity arity) { - return kHeaderSize + arity * sizeof(const Final); - } -}; - -} // namespace search - -#endif // SEARCH_FINAL__ diff --git a/klm/search/header.hh b/klm/search/header.hh index 25550dbe..69f0eed0 100644 --- a/klm/search/header.hh +++ b/klm/search/header.hh @@ -3,7 +3,6 @@ // Header consisting of Score, Arity, and Note -#include "search/note.hh" #include "search/types.hh" #include <stdint.h> @@ -24,6 +23,9 @@ class Header { bool operator<(const Header &other) const { return GetScore() < other.GetScore(); } + bool operator>(const Header &other) const { + return GetScore() > other.GetScore(); + } Arity GetArity() const { return *reinterpret_cast<const Arity*>(base_ + sizeof(Score)); @@ -36,9 +38,14 @@ class Header { *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to; } + uint8_t *Base() { return base_; } + const uint8_t *Base() const { return base_; } + protected: Header() : base_(NULL) {} + explicit Header(void *base) : base_(static_cast<uint8_t*>(base)) {} + Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) { *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity; } diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc new file mode 100644 index 00000000..ec3322c9 --- /dev/null +++ b/klm/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include <algorithm> +#include <functional> +#include <queue> + +#include <assert.h> +#include <math.h> + +namespace search { + +NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) { + assert(!partials.empty()); + std::vector<PartialEdge>::iterator end; + if (partials.size() > keep) { + end = partials.begin() + keep; + std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>()); + } else { + end = partials.end(); + } + for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) { + queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); + } +} + +Score NBestList::TopAfterConstructor() const { + assert(revealed_.empty()); + return queue_.top().GetScore(); +} + +const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) { + while (revealed_.size() < n && !queue_.empty()) { + MoveTop(pool); + } + return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { + if (index + 1 < revealed_.size()) + return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); + if (queue_.empty()) + return -INFINITY; + if (index + 1 == revealed_.size()) + return queue_.top().GetScore() - revealed_[index].GetScore(); + assert(index == revealed_.size()); + + MoveTop(pool); + + if (queue_.empty()) return -INFINITY; + return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { + assert(index <= revealed_.size()); + if (index == revealed_.size()) MoveTop(pool); + return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { + assert(!queue_.empty()); + QueueEntry entry(queue_.top()); + queue_.pop(); + RevealedRef *const children_begin = entry.Children(); + RevealedRef *const children_end = children_begin + entry.GetArity(); + Score basis = entry.GetScore(); + for (RevealedRef *child = children_begin; child != children_end; ++child) { + Score change = child->in_->Visit(pool, child->index_); + if (change != -INFINITY) { + assert(change < 0.001); + QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); + std::copy(children_begin, child, new_entry.Children()); + RevealedRef *update = new_entry.Children() + (child - children_begin); + update->in_ = child->in_; + update->index_ = child->index_ + 1; + std::copy(child + 1, children_end, update + 1); + queue_.push(new_entry); + } + // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. + if (child->index_) break; + } + + // Convert QueueEntry to Applied. This leaves some unused memory. + void *overwrite = entry.Children(); + for (unsigned int i = 0; i < entry.GetArity(); ++i) { + RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i)); + *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_); + } + revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) { + assert(!partials.empty()); + NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); + return NBestComplete( + list, + partials.front().CompletedState(), // All partials have the same state + list->TopAfterConstructor()); +} + +const std::vector<Applied> &NBest::Extract(History history) { + return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh new file mode 100644 index 00000000..cb7651bc --- /dev/null +++ b/klm/search/nbest.hh @@ -0,0 +1,81 @@ +#ifndef SEARCH_NBEST__ +#define SEARCH_NBEST__ + +#include "search/applied.hh" +#include "search/config.hh" +#include "search/edge.hh" + +#include <boost/pool/object_pool.hpp> + +#include <cstddef> +#include <queue> +#include <vector> + +#include <assert.h> + +namespace search { + +class NBestList; + +class NBestList { + private: + class RevealedRef { + public: + explicit RevealedRef(History history) + : in_(static_cast<NBestList*>(history)), index_(0) {} + + private: + friend class NBestList; + + NBestList *in_; + std::size_t index_; + }; + + typedef GenericApplied<RevealedRef> QueueEntry; + + public: + NBestList(std::vector<PartialEdge> &existing, util::Pool &entry_pool, std::size_t keep); + + Score TopAfterConstructor() const; + + const std::vector<Applied> &Extract(util::Pool &pool, std::size_t n); + + private: + Score Visit(util::Pool &pool, std::size_t index); + + Applied Get(util::Pool &pool, std::size_t index); + + void MoveTop(util::Pool &pool); + + typedef std::vector<Applied> Revealed; + Revealed revealed_; + + typedef std::priority_queue<QueueEntry> Queue; + Queue queue_; +}; + +class NBest { + public: + typedef std::vector<PartialEdge> Combine; + + explicit NBest(const NBestConfig &config) : config_(config) {} + + void Add(std::vector<PartialEdge> &existing, PartialEdge addition) const { + existing.push_back(addition); + } + + NBestComplete Complete(std::vector<PartialEdge> &partials); + + const std::vector<Applied> &Extract(History root); + + private: + const NBestConfig config_; + + boost::object_pool<NBestList> list_pool_; + + util::Pool entry_pool_; +}; + +} // namespace search + +#endif // SEARCH_NBEST__ diff --git a/klm/search/note.hh b/klm/search/note.hh deleted file mode 100644 index 50bed06e..00000000 --- a/klm/search/note.hh +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef SEARCH_NOTE__ -#define SEARCH_NOTE__ - -namespace search { - -union Note { - const void *vp; -}; - -} // namespace search - -#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 5b00207e..0244a09f 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -1,7 +1,7 @@ #include "search/rule.hh" +#include "lm/model.hh" #include "search/context.hh" -#include "search/final.hh" #include <ostream> @@ -9,35 +9,35 @@ namespace search { -template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) { - unsigned int oov_count = 0; - float prob = 0.0; - const Model &model = context.LanguageModel(); - const lm::WordIndex oov = model.GetVocabulary().NotFound(); - for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { - lm::ngram::RuleScore<Model> scorer(model, *(writing++)); - // TODO: optimize - if (prepend_bos && (word == words.begin())) { - scorer.BeginSentence(); - } - for (; ; ++word) { - if (word == words.end()) { - prob += scorer.Finish(); - return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); - } - if (*word == kNonTerminal) break; - if (*word == oov) ++oov_count; +template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) { + ScoreRuleRet ret; + ret.prob = 0.0; + ret.oov = 0; + const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence(); + lm::ngram::RuleScore<Model> scorer(model, *(writing++)); + std::vector<lm::WordIndex>::const_iterator word = words.begin(); + if (word != words.end() && *word == bos) { + scorer.BeginSentence(); + ++word; + } + for (; word != words.end(); ++word) { + if (*word == kNonTerminal) { + ret.prob += scorer.Finish(); + scorer.Reset(*(writing++)); + } else { + if (*word == oov) ++ret.oov; scorer.Terminal(*word); } - prob += scorer.Finish(); } + ret.prob += scorer.Finish(); + return ret; } -template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 0ce2794d..43ca6162 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -9,11 +9,16 @@ namespace search { -template <class Model> class Context; - const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; -template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out); +struct ScoreRuleRet { + Score prob; + unsigned int oov; +}; + +// Pass <s> and </s> normally. +// Indicate non-terminals with kNonTerminal. +template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *state_out); } // namespace search diff --git a/klm/search/types.hh b/klm/search/types.hh index 06eb5bfa..f9c849b3 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -3,12 +3,29 @@ #include <stdint.h> +namespace lm { namespace ngram { class ChartState; } } + namespace search { typedef float Score; typedef uint32_t Arity; +union Note { + const void *vp; +}; + +typedef void *History; + +struct NBestComplete { + NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score) + : history(in_history), state(&in_state), score(in_score) {} + + History history; + const lm::ngram::ChartState *state; + Score score; +}; + } // namespace search #endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 11f4631f..45842982 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve } // namespace -void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { +void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) { if (Complete()) { - assert(end_.Valid()); + assert(end_); assert(extend_.empty()); - bound_ = end_.GetScore(); return; } - if (extend_.size() == 1 && parent_ptr) { - *parent_ptr = extend_[0]; - extend_[0]->SortAndSet(context, parent_ptr); + if (extend_.size() == 1) { + parent_ptr = extend_[0]; + extend_[0]->RecursiveSortAndSet(context, parent_ptr); context.DeleteVertexNode(this); return; } for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->SortAndSet(context, &*i); + (*i)->RecursiveSortAndSet(context, *i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +void VertexNode::SortAndSet(ContextBase &context) { + // This is the root. The root might be empty. + if (extend_.empty()) { + bound_ = -INFINITY; + return; + } + // The root cannot be replaced. There's always one transition. + for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->RecursiveSortAndSet(context, *i); } std::sort(extend_.begin(), extend_.end(), GreaterByBound()); bound_ = extend_.front()->Bound(); diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 52bc1dfe..10b3339b 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -2,7 +2,6 @@ #define SEARCH_VERTEX__ #include "lm/left.hh" -#include "search/final.hh" #include "search/types.hh" #include <boost/unordered_set.hpp> @@ -10,6 +9,7 @@ #include <queue> #include <vector> +#include <math.h> #include <stdint.h> namespace search { @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() {} + VertexNode() : end_() {} void InitRoot() { extend_.clear(); @@ -26,7 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - end_ = Final(); + end_ = History(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -36,20 +36,21 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final end) { - assert(!end_.Valid()); + void SetEnd(History end, Score score) { + assert(!end_); end_ = end; + bound_ = score; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + void SortAndSet(ContextBase &context); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_.Valid() && extend_.empty(); + return !end_ && extend_.empty(); } bool Complete() const { - return end_.Valid(); + return end_; } const lm::ngram::ChartState &State() const { return state_; } @@ -64,7 +65,7 @@ class VertexNode { } // Will be invalid unless this is a leaf. - const Final End() const { return end_; } + const History End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -75,13 +76,15 @@ class VertexNode { } private: + void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + std::vector<VertexNode*> extend_; lm::ngram::ChartState state_; bool right_full_; Score bound_; - Final end_; + History end_; }; class PartialVertex { @@ -97,7 +100,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -121,7 +124,7 @@ class PartialVertex { return ret; } - const Final End() const { + const History End() const { return back_->End(); } @@ -130,16 +133,18 @@ class PartialVertex { unsigned int index_; }; +template <class Output> class VertexGenerator; + class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final BestChild() const { + const History BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return Final(); + return History(); } else { PartialVertex continuation; while (!top.Complete()) { @@ -150,8 +155,8 @@ class Vertex { } private: - friend class VertexGenerator; - + template <class Output> friend class VertexGenerator; + template <class Output> friend class RootVertexGenerator; VertexNode root_; }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0945fe55..73139ffc 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -4,26 +4,18 @@ #include "search/context.hh" #include "search/edge.hh" +#include <boost/unordered_map.hpp> +#include <boost/version.hpp> + #include <stdint.h> namespace search { -VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { - gen.root_.InitRoot(); -} - +#if BOOST_VERSION > 104200 namespace { const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); -// Parallel structure to VertexNode. -struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map<uint64_t, Trie> extend; -}; - Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { Trie &next = node.extend[added]; if (!next.under) { @@ -39,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n return next; } -void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { - Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); - Final *child_out = final.Children(); - const PartialVertex *part = partial.NT(); - const PartialVertex *const part_end_loop = part + partial.GetArity(); - for (; part != part_end_loop; ++part, ++child_out) - *child_out = part->End(); - - starter.under->SetEnd(final); -} +} // namespace -void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) { + const lm::ngram::ChartState &state = *end.state; unsigned char left = 0, right = 0; Trie *node = &root; @@ -77,18 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { } node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - CompleteTransition(context, *node, partial); + node->under->SetEnd(end.history, end.score); } -} // namespace - -void VertexGenerator::FinishedSearch() { - Trie root; - root.under = &gen_.root_; - for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { - AddHypothesis(context_, root, i->second); - } - root.under->SortAndSet(context_, NULL); -} +#endif // BOOST_VERSION } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 60e86112..da563c2d 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -2,9 +2,11 @@ #define SEARCH_VERTEX_GENERATOR__ #include "search/edge.hh" +#include "search/types.hh" #include "search/vertex.hh" #include <boost/unordered_map.hpp> +#include <boost/version.hpp> namespace lm { namespace ngram { @@ -15,21 +17,44 @@ class ChartState; namespace search { class ContextBase; -class Final; -class VertexGenerator { +#if BOOST_VERSION > 104200 +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map<uint64_t, Trie> extend; +}; + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); + +#endif // BOOST_VERSION + +// Output makes the single-best or n-best list. +template <class Output> class VertexGenerator { public: - VertexGenerator(ContextBase &context, Vertex &gen); + VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { + gen.root_.InitRoot(); + } void NewHypothesis(PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial))); - if (!ret.second && ret.first->second < partial) { - ret.first->second = partial; - } + nbest_.Add(existing_[hash_value(partial.CompletedState())], partial); } - void FinishedSearch(); + void FinishedSearch() { +#if BOOST_VERSION > 104200 + Trie root; + root.under = &gen_.root_; + for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, nbest_.Complete(i->second)); + } + existing_.clear(); + root.under->SortAndSet(context_); +#else + UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); +#endif + } const Vertex &Generating() const { return gen_; } @@ -38,8 +63,35 @@ class VertexGenerator { Vertex &gen_; - typedef boost::unordered_map<uint64_t, PartialEdge> Existing; + typedef boost::unordered_map<uint64_t, typename Output::Combine> Existing; Existing existing_; + + Output &nbest_; +}; + +// Special case for root vertex: everything should come together into the root +// node. In theory, this should happen naturally due to state collapsing with +// <s> and </s>. If that's the case, VertexGenerator is fine, though it will +// make one connection. +template <class Output> class RootVertexGenerator { + public: + RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {} + + void NewHypothesis(PartialEdge partial) { + out_.Add(combine_, partial); + } + + void FinishedSearch() { + gen_.root_.InitRoot(); + NBestComplete completed(out_.Complete(combine_)); + gen_.root_.SetEnd(completed.history, completed.score); + } + + private: + Vertex &gen_; + + typename Output::Combine combine_; + Output &out_; }; } // namespace search diff --git a/klm/search/weights.cc b/klm/search/weights.cc deleted file mode 100644 index d65471ad..00000000 --- a/klm/search/weights.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "search/weights.hh" -#include "util/tokenize_piece.hh" - -#include <cstdlib> - -namespace search { - -namespace { -struct Insert { - void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const { - std::string copy(name.data(), name.size()); - map[copy] = score; - } -}; - -struct DotProduct { - search::Score total; - DotProduct() : total(0.0) {} - - void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) { - boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name)); - if (i != map.end()) - total += score * i->second; - } -}; - -template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) { - for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) { - util::TokenIter<util::SingleCharacter> equals(*spaces, '='); - UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); - StringPiece name(*equals); - UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); - char *end; - // Assumes proper termination. - double value = std::strtod(equals->data(), &end); - UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); - UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); - op(map, name, value); - } -} - -} // namespace - -Weights::Weights(StringPiece text) { - Insert op; - Parse<Map, Insert>(text, map_, op); - lm_ = Steal("LanguageModel"); - oov_ = Steal("OOV"); - word_penalty_ = Steal("WordPenalty"); -} - -Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} - -search::Score Weights::DotNoLM(StringPiece text) const { - DotProduct dot; - Parse<const Map, DotProduct>(text, map_, dot); - return dot.total; -} - -float Weights::Steal(const std::string &str) { - Map::iterator i(map_.find(str)); - if (i == map_.end()) { - return 0.0; - } else { - float ret = i->second; - map_.erase(i); - return ret; - } -} - -} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh deleted file mode 100644 index df1c419f..00000000 --- a/klm/search/weights.hh +++ /dev/null @@ -1,52 +0,0 @@ -// For now, the individual features are not kept. -#ifndef SEARCH_WEIGHTS__ -#define SEARCH_WEIGHTS__ - -#include "search/types.hh" -#include "util/exception.hh" -#include "util/string_piece.hh" - -#include <boost/unordered_map.hpp> - -#include <string> - -namespace search { - -class WeightParseException : public util::Exception { - public: - WeightParseException() {} - ~WeightParseException() throw() {} -}; - -class Weights { - public: - // Parses weights, sets lm_weight_, removes it from map_. - explicit Weights(StringPiece text); - - // Just the three scores we care about adding. - Weights(Score lm, Score oov, Score word_penalty); - - Score DotNoLM(StringPiece text) const; - - Score LM() const { return lm_; } - - Score OOV() const { return oov_; } - - Score WordPenalty() const { return word_penalty_; } - - // Mostly for testing. - const boost::unordered_map<std::string, Score> &GetMap() const { return map_; } - - private: - float Steal(const std::string &str); - - typedef boost::unordered_map<std::string, Score> Map; - - Map map_; - - Score lm_, oov_, word_penalty_; -}; - -} // namespace search - -#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc deleted file mode 100644 index 4811ff06..00000000 --- a/klm/search/weights_test.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include "search/weights.hh" - -#define BOOST_TEST_MODULE WeightTest -#include <boost/test/unit_test.hpp> -#include <boost/test/floating_point_comparison.hpp> - -namespace search { -namespace { - -#define CHECK_WEIGHT(value, string) \ - i = parsed.find(string); \ - BOOST_REQUIRE(i != parsed.end()); \ - BOOST_CHECK_CLOSE((value), i->second, 0.001); - -BOOST_AUTO_TEST_CASE(parse) { - // These are not real feature weights. - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap(); - boost::unordered_map<std::string, search::Score>::const_iterator i; - CHECK_WEIGHT(0.0, "rarity"); - CHECK_WEIGHT(0.0, "phrase-SGT"); - CHECK_WEIGHT(9.45117, "phrase-TGS"); - CHECK_WEIGHT(2.33833, "lexical-SGT"); - BOOST_CHECK(parsed.end() == parsed.find("lm")); - BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); - CHECK_WEIGHT(-28.3317, "lexical-TGS"); - CHECK_WEIGHT(5.0, "glue?"); -} - -BOOST_AUTO_TEST_CASE(dot) { - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); -} - -} // namespace -} // namespace search diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5306850f..a676bdb3 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -27,6 +27,7 @@ libklm_util_a_SOURCES = \ mmap.cc \ murmur_hash.cc \ pool.cc \ + read_compressed.cc \ string_piece.cc \ usage.cc diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 053a850b..0165a7a3 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -87,8 +87,14 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep throw UTIL_e; \ } while (0) +#if __GNUC__ >= 3 +#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0) +#else +#define UTIL_UNLIKELY(x) (x) +#endif + #define UTIL_THROW_IF(Condition, Exception, Modify) do { \ - if (Condition) { \ + if (UTIL_UNLIKELY(Condition)) { \ Exception UTIL_e; \ UTIL_SET_LOCATION(UTIL_e, #Exception, #Condition); \ UTIL_e << Modify; \ diff --git a/klm/util/file.cc b/klm/util/file.cc index 6bf879ac..b9a77cf9 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -15,6 +15,8 @@ #if defined(_WIN32) || defined(_WIN64) #include <windows.h> #include <io.h> +#include <algorithm> +#include <limits.h> #else #include <unistd.h> #endif @@ -48,7 +50,7 @@ int OpenReadOrThrow(const char *name) { int CreateOrThrow(const char *name) { int ret; #if defined(_WIN32) || defined(_WIN64) - UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); + UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR | _O_BINARY, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); #else UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name); #endif @@ -74,16 +76,22 @@ void ResizeOrThrow(int fd, uint64_t to) { #endif } -#ifdef WIN32 -typedef int ssize_t; +std::size_t PartialRead(int fd, void *to, std::size_t amount) { +#if defined(_WIN32) || defined(_WIN64) + amount = min(static_cast<std::size_t>(INT_MAX), amount); + int ret = _read(fd, to, amount); +#else + ssize_t ret = read(fd, to, amount); #endif + UTIL_THROW_IF(ret < 0, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); + return static_cast<std::size_t>(ret); +} void ReadOrThrow(int fd, void *to_void, std::size_t amount) { uint8_t *to = static_cast<uint8_t*>(to_void); while (amount) { - ssize_t ret = read(fd, to, amount); - UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); - UTIL_THROW_IF(ret == 0, EndOfFileException, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); + std::size_t ret = PartialRead(fd, to, amount); + UTIL_THROW_IF(ret == 0, EndOfFileException, " in fd " << fd << " but there should be " << amount << " more bytes to read."); amount -= ret; to += ret; } @@ -93,8 +101,7 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { uint8_t *to = static_cast<uint8_t*>(to_void); std::size_t remaining = amount; while (remaining) { - ssize_t ret = read(fd, to, remaining); - UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << remaining << " from fd " << fd << " failed."); + std::size_t ret = PartialRead(fd, to, remaining); if (!ret) return amount - remaining; remaining -= ret; to += ret; @@ -105,7 +112,11 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { void WriteOrThrow(int fd, const void *data_void, std::size_t size) { const uint8_t *data = static_cast<const uint8_t*>(data_void); while (size) { +#if defined(_WIN32) || defined(_WIN64) + int ret = write(fd, data, min(static_cast<std::size_t>(INT_MAX), size)); +#else ssize_t ret = write(fd, data, size); +#endif if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); data += ret; size -= ret; @@ -114,7 +125,7 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) { void WriteOrThrow(FILE *to, const void *data, std::size_t size) { assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); + UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), util::ErrnoException, "Short write; requested size " << size); } void FSyncOrThrow(int fd) { @@ -149,14 +160,15 @@ void SeekEnd(int fd) { std::FILE *FDOpenOrThrow(scoped_fd &file) { std::FILE *ret = fdopen(file.get(), "r+b"); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen"); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get()); file.release(); return ret; } -std::FILE *FOpenOrThrow(const char *path, const char *mode) { - std::FILE *ret; - UTIL_THROW_IF(!(ret = fopen(path, mode)), util::ErrnoException, "Could not fopen " << path << " for " << mode); +std::FILE *FDOpenReadOrThrow(scoped_fd &file) { + std::FILE *ret = fdopen(file.get(), "rb"); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get()); + file.release(); return ret; } diff --git a/klm/util/file.hh b/klm/util/file.hh index 185cb1f3..c24580d6 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -32,8 +32,6 @@ class scoped_fd { return ret; } - operator bool() { return fd_ != -1; } - private: int fd_; @@ -76,8 +74,9 @@ uint64_t SizeFile(int fd); void ResizeOrThrow(int fd, uint64_t to); +std::size_t PartialRead(int fd, void *to, std::size_t size); void ReadOrThrow(int fd, void *to, std::size_t size); -std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount); +std::size_t ReadOrEOF(int fd, void *to_void, std::size_t size); void WriteOrThrow(int fd, const void *data_void, std::size_t size); void WriteOrThrow(FILE *to, const void *data, std::size_t size); @@ -90,8 +89,7 @@ void AdvanceOrThrow(int fd, int64_t off); void SeekEnd(int fd); std::FILE *FDOpenOrThrow(scoped_fd &file); - -std::FILE *FOpenOrThrow(const char *path, const char *mode); +std::FILE *FDOpenReadOrThrow(scoped_fd &file); class TempMaker { public: diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 280f438c..5a208eff 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -14,7 +14,6 @@ #include <limits> #include <assert.h> -#include <ctype.h> #include <fcntl.h> #include <stdlib.h> #include <sys/types.h> @@ -26,13 +25,6 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() { *this << "Could not parse \"" << value << "\" into a number"; } -#ifdef HAVE_ZLIB -GZException::GZException(gzFile file) { - int num; - *this << gzerror(file, &num) << " from zlib"; -} -#endif // HAVE_ZLIB - // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; @@ -48,19 +40,7 @@ FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std: Initialize(name, show_progress, min_buffer); } -FilePiece::~FilePiece() { -#ifdef HAVE_ZLIB - if (gz_file_) { - // zlib took ownership - file_.release(); - int ret; - if (Z_OK != (ret = gzclose(gz_file_))) { - std::cerr << "could not close file " << file_name_ << " using zlib" << std::endl; - abort(); - } - } -#endif -} +FilePiece::~FilePiece() {} StringPiece FilePiece::ReadLine(char delim) { std::size_t skip = 0; @@ -95,9 +75,6 @@ unsigned long int FilePiece::ReadULong() { } void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer) { -#ifdef HAVE_ZLIB - gz_file_ = NULL; -#endif file_name_ = name; default_map_size_ = page_ * std::max<std::size_t>((min_buffer / page_ + 1), 2); @@ -117,10 +94,7 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s } Shift(); // gzip detect. - if ((position_end_ - position_) > 2 && *position_ == 0x1f && static_cast<unsigned char>(*(position_ + 1)) == 0x8b) { -#ifndef HAVE_ZLIB - UTIL_THROW(GZException, "Looks like a gzip file but support was not compiled in."); -#endif + if ((position_end_ - position_) >= ReadCompressed::kMagicSize && ReadCompressed::DetectCompressedMagic(position_)) { if (!fallback_to_read_) { at_end_ = false; TransitionToRead(); @@ -197,7 +171,7 @@ void FilePiece::Shift() { if (fallback_to_read_) ReadShift(); for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) { - if (isspace(*last_space_)) break; + if (kSpaces[static_cast<unsigned char>(*last_space_)]) break; } } @@ -248,17 +222,14 @@ void FilePiece::TransitionToRead() { position_ = data_.begin(); position_end_ = position_; -#ifdef HAVE_ZLIB - assert(!gz_file_); - gz_file_ = gzdopen(file_.get(), "r"); - UTIL_THROW_IF(!gz_file_, GZException, "zlib failed to open " << file_name_); -#endif + try { + fell_back_.Reset(file_.release()); + } catch (util::Exception &e) { + e << " in file " << file_name_; + throw; + } } -#ifdef WIN32 -typedef int ssize_t; -#endif - void FilePiece::ReadShift() { assert(fallback_to_read_); // Bytes [data_.begin(), position_) have been consumed. @@ -283,7 +254,7 @@ void FilePiece::ReadShift() { position_ = data_.begin(); position_end_ = position_ + valid_length; } else { - size_t moving = position_end_ - position_; + std::size_t moving = position_end_ - position_; memmove(data_.get(), position_, moving); position_ = data_.begin(); position_end_ = position_ + moving; @@ -291,20 +262,9 @@ void FilePiece::ReadShift() { } } - ssize_t read_return; -#ifdef HAVE_ZLIB - read_return = gzread(gz_file_, static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); - if (read_return == -1) throw GZException(gz_file_); - if (total_size_ != kBadSize) { - // Just get the position, don't actually seek. Apparently this is how you do it. . . - off_t ret = lseek(file_.get(), 0, SEEK_CUR); - if (ret != -1) progress_.Set(ret); - } -#else - read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); - UTIL_THROW_IF(read_return == -1, ErrnoException, "read failed"); - progress_.Set(mapped_offset_); -#endif + std::size_t read_return = fell_back_.Read(static_cast<uint8_t*>(data_.get()) + already_read, default_map_size_ - already_read); + progress_.Set(fell_back_.RawAmount()); + if (read_return == 0) { at_end_ = true; } diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index af93d8aa..39bd1581 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -4,8 +4,8 @@ #include "util/ersatz_progress.hh" #include "util/exception.hh" #include "util/file.hh" -#include "util/have.hh" #include "util/mmap.hh" +#include "util/read_compressed.hh" #include "util/string_piece.hh" #include <cstddef> @@ -13,10 +13,6 @@ #include <stdint.h> -#ifdef HAVE_ZLIB -#include <zlib.h> -#endif - namespace util { class ParseNumberException : public Exception { @@ -25,28 +21,19 @@ class ParseNumberException : public Exception { ~ParseNumberException() throw() {} }; -class GZException : public Exception { - public: -#ifdef HAVE_ZLIB - explicit GZException(gzFile file); -#endif - GZException() throw() {} - ~GZException() throw() {} -}; - extern const bool kSpaces[256]; -// Memory backing the returned StringPiece may vanish on the next call. +// Memory backing the returned StringPiece may vanish on the next call. class FilePiece { public: - // 32 MB default. - explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); - // Takes ownership of fd. name is used for messages. - explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); + // 1 MB default. + explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); + // Takes ownership of fd. name is used for messages. + explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); ~FilePiece(); - - char get() { + + char get() { if (position_ == position_end_) { Shift(); if (at_end_) throw EndOfFileException(); @@ -54,14 +41,14 @@ class FilePiece { return *(position_++); } - // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace(). + // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace(). StringPiece ReadDelimited(const bool *delim = kSpaces) { SkipSpaces(delim); return Consume(FindDelimiterOrEOF(delim)); } // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. - // It is similar to getline in that way. + // It is similar to getline in that way. StringPiece ReadLine(char delim = '\n'); float ReadFloat(); @@ -69,7 +56,7 @@ class FilePiece { long int ReadLong(); unsigned long int ReadULong(); - // Skip spaces defined by isspace. + // Skip spaces defined by isspace. void SkipSpaces(const bool *delim = kSpaces) { for (; ; ++position_) { if (position_ == position_end_) Shift(); @@ -82,7 +69,7 @@ class FilePiece { } const std::string &FileName() const { return file_name_; } - + private: void Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer); @@ -122,9 +109,7 @@ class FilePiece { std::string file_name_; -#ifdef HAVE_ZLIB - gzFile gz_file_; -#endif // HAVE_ZLIB + ReadCompressed fell_back_; }; } // namespace util diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index f912e18a..e79ece7a 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -38,7 +38,7 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) { BOOST_CHECK_THROW(test.get(), EndOfFileException); } -#ifndef __APPLE__ +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__) /* Apple isn't happy with the popen, fileno, dup. And I don't want to * reimplement popen. This is an issue with the test. */ @@ -65,7 +65,7 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) { BOOST_CHECK_THROW(test.get(), EndOfFileException); BOOST_REQUIRE(!pclose(catter)); } -#endif // __APPLE__ +#endif #ifdef HAVE_ZLIB diff --git a/klm/util/have.hh b/klm/util/have.hh index b8181e99..85b838e4 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -2,22 +2,16 @@ #ifndef UTIL_HAVE__ #define UTIL_HAVE__ -#ifndef HAVE_ZLIB -#if !defined(_WIN32) && !defined(_WIN64) -#define HAVE_ZLIB -#endif -#endif - #ifndef HAVE_ICU //#define HAVE_ICU #endif #ifndef HAVE_BOOST -#define HAVE_BOOST +//#define HAVE_BOOST #endif -#ifndef HAVE_THREADS -//#define HAVE_THREADS +#ifdef HAVE_CONFIG_H +#include "config.h" #endif #endif // UTIL_HAVE__ diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh index cf3d8432..1b43ddcf 100644 --- a/klm/util/joint_sort.hh +++ b/klm/util/joint_sort.hh @@ -60,7 +60,7 @@ template <class KeyIter, class ValueIter> class JointProxy { JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {} JointProxy(const JointProxy<KeyIter, ValueIter> &other) : inner_(other.inner_) {} - operator const value_type() const { + operator value_type() const { value_type ret; ret.key = *inner_.key_; ret.value = *inner_.value_; @@ -121,7 +121,7 @@ template <class Proxy, class Less> class LessWrapper : public std::binary_functi template <class KeyIter, class ValueIter> class PairedIterator : public ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > { public: - PairedIterator(const KeyIter &key, const ValueIter &value) : + PairedIterator(const KeyIter &key, const ValueIter &value) : ProxyIterator<detail::JointProxy<KeyIter, ValueIter> >(detail::JointProxy<KeyIter, ValueIter>(key, value)) {} }; diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc new file mode 100644 index 00000000..4ec94c4e --- /dev/null +++ b/klm/util/read_compressed.cc @@ -0,0 +1,403 @@ +#include "util/read_compressed.hh" + +#include "util/file.hh" +#include "util/have.hh" +#include "util/scoped.hh" + +#include <algorithm> +#include <iostream> + +#include <assert.h> +#include <limits.h> +#include <stdlib.h> +#include <string.h> + +#ifdef HAVE_ZLIB +#include <zlib.h> +#endif + +#ifdef HAVE_BZLIB +#include <bzlib.h> +#endif + +#ifdef HAVE_XZLIB +#include <lzma.h> +#endif + +namespace util { + +CompressedException::CompressedException() throw() {} +CompressedException::~CompressedException() throw() {} + +GZException::GZException() throw() {} +GZException::~GZException() throw() {} + +BZException::BZException() throw() {} +BZException::~BZException() throw() {} + +XZException::XZException() throw() {} +XZException::~XZException() throw() {} + +class ReadBase { + public: + virtual ~ReadBase() {} + + virtual std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) = 0; + + protected: + static void ReplaceThis(ReadBase *with, ReadCompressed &thunk) { + thunk.internal_.reset(with); + } + + static uint64_t &ReadCount(ReadCompressed &thunk) { + return thunk.raw_amount_; + } +}; + +namespace { + +// Completed file that other classes can thunk to. +class Complete : public ReadBase { + public: + std::size_t Read(void *, std::size_t, ReadCompressed &) { + return 0; + } +}; + +class Uncompressed : public ReadBase { + public: + explicit Uncompressed(int fd) : fd_(fd) {} + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + std::size_t got = PartialRead(fd_.get(), to, amount); + ReadCount(thunk) += got; + return got; + } + + private: + scoped_fd fd_; +}; + +class UncompressedWithHeader : public ReadBase { + public: + UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) { + assert(already_size); + buf_.reset(malloc(already_size)); + if (!buf_.get()) throw std::bad_alloc(); + memcpy(buf_.get(), already_data, already_size); + remain_ = static_cast<uint8_t*>(buf_.get()); + end_ = remain_ + already_size; + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + assert(buf_.get()); + std::size_t sending = std::min<std::size_t>(amount, end_ - remain_); + memcpy(to, remain_, sending); + remain_ += sending; + if (remain_ == end_) { + ReplaceThis(new Uncompressed(fd_.release()), thunk); + } + return sending; + } + + private: + scoped_malloc buf_; + uint8_t *remain_; + uint8_t *end_; + + scoped_fd fd_; +}; + +#ifdef HAVE_ZLIB +class GZip : public ReadBase { + private: + static const std::size_t kInputBuffer = 16384; + public: + GZip(int fd, void *already_data, std::size_t already_size) + : file_(fd), in_buffer_(malloc(kInputBuffer)) { + if (!in_buffer_.get()) throw std::bad_alloc(); + assert(already_size < kInputBuffer); + if (already_size) { + memcpy(in_buffer_.get(), already_data, already_size); + stream_.next_in = static_cast<Bytef *>(in_buffer_.get()); + stream_.avail_in = already_size; + stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size); + } else { + stream_.avail_in = 0; + } + stream_.zalloc = Z_NULL; + stream_.zfree = Z_NULL; + stream_.opaque = Z_NULL; + stream_.msg = NULL; + // 32 for zlib and gzip decoding with automatic header detection. + // 15 for maximum window size. + UTIL_THROW_IF(Z_OK != inflateInit2(&stream_, 32 + 15), GZException, "Failed to initialize zlib."); + } + + ~GZip() { + if (Z_OK != inflateEnd(&stream_)) { + std::cerr << "zlib could not close properly." << std::endl; + abort(); + } + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (amount == 0) return 0; + stream_.next_out = static_cast<Bytef*>(to); + stream_.avail_out = std::min<std::size_t>(std::numeric_limits<uInt>::max(), amount); + do { + if (!stream_.avail_in) ReadInput(thunk); + int result = inflate(&stream_, 0); + switch (result) { + case Z_OK: + break; + case Z_STREAM_END: + { + std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); + ReplaceThis(new Complete(), thunk); + return ret; + } + case Z_ERRNO: + UTIL_THROW(ErrnoException, "zlib error"); + default: + UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result); + } + } while (stream_.next_out == to); + return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); + } + + private: + void ReadInput(ReadCompressed &thunk) { + assert(!stream_.avail_in); + stream_.next_in = static_cast<Bytef *>(in_buffer_.get()); + stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); + ReadCount(thunk) += stream_.avail_in; + } + + scoped_fd file_; + scoped_malloc in_buffer_; + z_stream stream_; +}; +#endif // HAVE_ZLIB + +#ifdef HAVE_BZLIB +class BZip : public ReadBase { + public: + explicit BZip(int fd, void *already_data, std::size_t already_size) { + scoped_fd hold(fd); + closer_.reset(FDOpenReadOrThrow(hold)); + int bzerror = BZ_OK; + file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size); + switch (bzerror) { + case BZ_OK: + return; + case BZ_CONFIG_ERROR: + UTIL_THROW(BZException, "Looks like bzip2 was miscompiled."); + case BZ_PARAM_ERROR: + UTIL_THROW(BZException, "Parameter error"); + case BZ_IO_ERROR: + UTIL_THROW(BZException, "IO error reading file"); + case BZ_MEM_ERROR: + throw std::bad_alloc(); + } + } + + ~BZip() { + int bzerror = BZ_OK; + BZ2_bzReadClose(&bzerror, file_); + if (bzerror != BZ_OK) { + std::cerr << "bz2 readclose error" << std::endl; + abort(); + } + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + int bzerror = BZ_OK; + int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount)); + long pos; + switch (bzerror) { + case BZ_STREAM_END: + pos = ftell(closer_.get()); + if (pos != -1) ReadCount(thunk) = pos; + ReplaceThis(new Complete(), thunk); + return ret; + case BZ_OK: + pos = ftell(closer_.get()); + if (pos != -1) ReadCount(thunk) = pos; + return ret; + default: + UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror); + } + } + + private: + scoped_FILE closer_; + BZFILE *file_; +}; +#endif // HAVE_BZLIB + +#ifdef HAVE_XZLIB +class XZip : public ReadBase { + private: + static const std::size_t kInputBuffer = 16384; + public: + XZip(int fd, void *already_data, std::size_t already_size) + : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) { + if (!in_buffer_.get()) throw std::bad_alloc(); + assert(already_size < kInputBuffer); + if (already_size) { + memcpy(in_buffer_.get(), already_data, already_size); + stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get()); + stream_.avail_in = already_size; + stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size); + } else { + stream_.avail_in = 0; + } + stream_.allocator = NULL; + lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED); + switch (ret) { + case LZMA_OK: + break; + case LZMA_MEM_ERROR: + UTIL_THROW(ErrnoException, "xz open error"); + default: + UTIL_THROW(XZException, "xz error code " << ret); + } + } + + ~XZip() { + lzma_end(&stream_); + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (amount == 0) return 0; + stream_.next_out = static_cast<uint8_t*>(to); + stream_.avail_out = amount; + do { + if (!stream_.avail_in) ReadInput(thunk); + lzma_ret status = lzma_code(&stream_, action_); + switch (status) { + case LZMA_OK: + break; + case LZMA_STREAM_END: + UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet."); + { + std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); + ReplaceThis(new Complete(), thunk); + return ret; + } + case LZMA_MEM_ERROR: + throw std::bad_alloc(); + case LZMA_FORMAT_ERROR: + UTIL_THROW(XZException, "xzlib says file format not recognized"); + case LZMA_OPTIONS_ERROR: + UTIL_THROW(XZException, "xzlib says unsupported compression options"); + case LZMA_DATA_ERROR: + UTIL_THROW(XZException, "xzlib says this file is corrupt"); + case LZMA_BUF_ERROR: + UTIL_THROW(XZException, "xzlib says unexpected end of input"); + default: + UTIL_THROW(XZException, "unrecognized xzlib error " << status); + } + } while (stream_.next_out == to); + return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); + } + + private: + void ReadInput(ReadCompressed &thunk) { + assert(!stream_.avail_in); + stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get()); + stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); + if (!stream_.avail_in) action_ = LZMA_FINISH; + ReadCount(thunk) += stream_.avail_in; + } + + scoped_fd file_; + scoped_malloc in_buffer_; + lzma_stream stream_; + + lzma_action action_; +}; +#endif // HAVE_XZLIB + +enum MagicResult { + UNKNOWN, GZIP, BZIP, XZIP +}; + +MagicResult DetectMagic(const void *from_void) { + const uint8_t *header = static_cast<const uint8_t*>(from_void); + if (header[0] == 0x1f && header[1] == 0x8b) { + return GZIP; + } + if (header[0] == 'B' && header[1] == 'Z') { + return BZIP; + } + const uint8_t xzmagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 }; + if (!memcmp(header, xzmagic, 6)) { + return XZIP; + } + return UNKNOWN; +} + +ReadBase *ReadFactory(int fd, uint64_t &raw_amount) { + scoped_fd hold(fd); + unsigned char header[ReadCompressed::kMagicSize]; + raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize); + if (!raw_amount) + return new Uncompressed(hold.release()); + if (raw_amount != ReadCompressed::kMagicSize) + return new UncompressedWithHeader(hold.release(), header, raw_amount); + switch (DetectMagic(header)) { + case GZIP: +#ifdef HAVE_ZLIB + return new GZip(hold.release(), header, ReadCompressed::kMagicSize); +#else + UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in."); +#endif + case BZIP: +#ifdef HAVE_BZLIB + return new BZip(hold.release(), header, ReadCompressed::kMagicSize); +#else + UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in."); +#endif + case XZIP: +#ifdef HAVE_XZLIB + return new XZip(hold.release(), header, ReadCompressed::kMagicSize); +#else + UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in."); +#endif + case UNKNOWN: + break; + } + try { + AdvanceOrThrow(fd, -ReadCompressed::kMagicSize); + } catch (const util::ErrnoException &e) { + return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize); + } + return new Uncompressed(hold.release()); +} + +} // namespace + +bool ReadCompressed::DetectCompressedMagic(const void *from_void) { + return DetectMagic(from_void) != UNKNOWN; +} + +ReadCompressed::ReadCompressed(int fd) { + Reset(fd); +} + +ReadCompressed::ReadCompressed() {} + +ReadCompressed::~ReadCompressed() {} + +void ReadCompressed::Reset(int fd) { + internal_.reset(); + internal_.reset(ReadFactory(fd, raw_amount_)); +} + +std::size_t ReadCompressed::Read(void *to, std::size_t amount) { + return internal_->Read(to, amount, *this); +} + +} // namespace util diff --git a/klm/util/read_compressed.hh b/klm/util/read_compressed.hh new file mode 100644 index 00000000..83ca9fb2 --- /dev/null +++ b/klm/util/read_compressed.hh @@ -0,0 +1,74 @@ +#ifndef UTIL_READ_COMPRESSED__ +#define UTIL_READ_COMPRESSED__ + +#include "util/exception.hh" +#include "util/scoped.hh" + +#include <cstddef> + +#include <stdint.h> + +namespace util { + +class CompressedException : public Exception { + public: + CompressedException() throw(); + virtual ~CompressedException() throw(); +}; + +class GZException : public CompressedException { + public: + GZException() throw(); + ~GZException() throw(); +}; + +class BZException : public CompressedException { + public: + BZException() throw(); + ~BZException() throw(); +}; + +class XZException : public CompressedException { + public: + XZException() throw(); + ~XZException() throw(); +}; + +class ReadBase; + +class ReadCompressed { + public: + static const std::size_t kMagicSize = 6; + // Must have at least kMagicSize bytes. + static bool DetectCompressedMagic(const void *from); + + // Takes ownership of fd. + explicit ReadCompressed(int fd); + + // Must call Reset later. + ReadCompressed(); + + ~ReadCompressed(); + + // Takes ownership of fd. + void Reset(int fd); + + std::size_t Read(void *to, std::size_t amount); + + uint64_t RawAmount() const { return raw_amount_; } + + private: + friend class ReadBase; + + scoped_ptr<ReadBase> internal_; + + uint64_t raw_amount_; + + // No copying. + ReadCompressed(const ReadCompressed &); + void operator=(const ReadCompressed &); +}; + +} // namespace util + +#endif // UTIL_READ_COMPRESSED__ diff --git a/klm/util/read_compressed_test.cc b/klm/util/read_compressed_test.cc new file mode 100644 index 00000000..6fd97e5e --- /dev/null +++ b/klm/util/read_compressed_test.cc @@ -0,0 +1,94 @@ +#include "util/read_compressed.hh" + +#include "util/file.hh" +#include "util/have.hh" + +#define BOOST_TEST_MODULE ReadCompressedTest +#include <boost/test/unit_test.hpp> +#include <boost/scoped_ptr.hpp> + +#include <fstream> +#include <string> + +#include <stdlib.h> + +namespace util { +namespace { + +void ReadLoop(ReadCompressed &reader, void *to_void, std::size_t amount) { + uint8_t *to = static_cast<uint8_t*>(to_void); + while (amount) { + std::size_t ret = reader.Read(to, amount); + BOOST_REQUIRE(ret); + to += ret; + amount -= ret; + } +} + +void TestRandom(const char *compressor) { + const uint32_t kSize4 = 100000 / 4; + char name[] = "tempXXXXXX"; + + // Write test file. + { + scoped_fd original(mkstemp(name)); + BOOST_REQUIRE(original.get() > 0); + for (uint32_t i = 0; i < kSize4; ++i) { + WriteOrThrow(original.get(), &i, sizeof(uint32_t)); + } + } + + char gzname[] = "tempXXXXXX"; + scoped_fd gzipped(mkstemp(gzname)); + + std::string command(compressor); +#ifdef __CYGWIN__ + command += ".exe"; +#endif + command += " <\""; + command += name; + command += "\" >\""; + command += gzname; + command += "\""; + BOOST_REQUIRE_EQUAL(0, system(command.c_str())); + + BOOST_CHECK_EQUAL(0, unlink(name)); + BOOST_CHECK_EQUAL(0, unlink(gzname)); + + ReadCompressed reader(gzipped.release()); + for (uint32_t i = 0; i < kSize4; ++i) { + uint32_t got; + ReadLoop(reader, &got, sizeof(uint32_t)); + BOOST_CHECK_EQUAL(i, got); + } + + char ignored; + BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); + // Test double EOF call. + BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); +} + +BOOST_AUTO_TEST_CASE(Uncompressed) { + TestRandom("cat"); +} + +#ifdef HAVE_ZLIB +BOOST_AUTO_TEST_CASE(ReadGZ) { + TestRandom("gzip"); +} +#endif // HAVE_ZLIB + +#ifdef HAVE_BZLIB +BOOST_AUTO_TEST_CASE(ReadBZ) { + TestRandom("bzip2"); +} +#endif // HAVE_BZLIB + +#ifdef HAVE_XZLIB +BOOST_AUTO_TEST_CASE(ReadXZ) { + TestRandom("xz"); +} +#endif + +} // namespace +} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index 93e2e817..d62c6df1 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -1,40 +1,13 @@ #ifndef UTIL_SCOPED__ #define UTIL_SCOPED__ +/* Other scoped objects in the style of scoped_ptr. */ #include "util/exception.hh" - -/* Other scoped objects in the style of scoped_ptr. */ #include <cstddef> #include <cstdlib> namespace util { -template <class T, class R, R (*Free)(T*)> class scoped_thing { - public: - explicit scoped_thing(T *c = static_cast<T*>(0)) : c_(c) {} - - ~scoped_thing() { if (c_) Free(c_); } - - void reset(T *c) { - if (c_) Free(c_); - c_ = c; - } - - T &operator*() { return *c_; } - const T&operator*() const { return *c_; } - T &operator->() { return *c_; } - const T&operator->() const { return *c_; } - - T *get() { return c_; } - const T *get() const { return c_; } - - private: - T *c_; - - scoped_thing(const scoped_thing &); - scoped_thing &operator=(const scoped_thing &); -}; - class scoped_malloc { public: scoped_malloc() : p_(NULL) {} @@ -77,9 +50,6 @@ template <class T> class scoped_array { T &operator*() { return *c_; } const T&operator*() const { return *c_; } - T &operator->() { return *c_; } - const T&operator->() const { return *c_; } - T &operator[](std::size_t idx) { return c_[idx]; } const T &operator[](std::size_t idx) const { return c_[idx]; } @@ -90,6 +60,39 @@ template <class T> class scoped_array { private: T *c_; + + scoped_array(const scoped_array &); + void operator=(const scoped_array &); +}; + +template <class T> class scoped_ptr { + public: + explicit scoped_ptr(T *content = NULL) : c_(content) {} + + ~scoped_ptr() { delete c_; } + + T *get() { return c_; } + const T* get() const { return c_; } + + T &operator*() { return *c_; } + const T&operator*() const { return *c_; } + + T *operator->() { return c_; } + const T*operator->() const { return c_; } + + T &operator[](std::size_t idx) { return c_[idx]; } + const T &operator[](std::size_t idx) const { return c_[idx]; } + + void reset(T *to = NULL) { + scoped_ptr<T> other(c_); + c_ = to; + } + + private: + T *c_; + + scoped_ptr(const scoped_ptr &); + void operator=(const scoped_ptr &); }; } // namespace util diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index be6a643d..51481646 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -1,6 +1,6 @@ /* If you use ICU in your program, then compile with -DHAVE_ICU -licui18n. If * you don't use ICU, then this will use the Google implementation from Chrome. - * This has been modified from the original version to let you choose. + * This has been modified from the original version to let you choose. */ // Copyright 2008, Google Inc. @@ -62,9 +62,9 @@ #include <unicode/stringpiece.h> #include <unicode/uversion.h> -// Old versions of ICU don't define operator== and operator!=. +// Old versions of ICU don't define operator== and operator!=. #if (U_ICU_VERSION_MAJOR_NUM < 4) || ((U_ICU_VERSION_MAJOR_NUM == 4) && (U_ICU_VERSION_MINOR_NUM < 4)) -#warning You are using an old version of ICU. Consider upgrading to ICU >= 4.6. +#warning You are using an old version of ICU. Consider upgrading to ICU >= 4.6. inline bool operator==(const StringPiece& x, const StringPiece& y) { if (x.size() != y.size()) return false; @@ -274,15 +274,28 @@ struct StringPieceCompatibleEquals : public std::binary_function<const StringPie } }; template <class T> typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 + std::string temp(key.data(), key.size()); + return t.find(temp); +#else return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif } + template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 + std::string temp(key.data(), key.size()); + return t.find(temp); +#else return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif } #endif #ifdef HAVE_ICU U_NAMESPACE_END +using U_NAMESPACE_QUALIFIER StringPiece; #endif + #endif // BASE_STRING_PIECE_H__ diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index 4a7f5460..a588c3fc 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -20,6 +20,7 @@ class OutOfTokens : public Exception { class SingleCharacter { public: + SingleCharacter() {} explicit SingleCharacter(char delim) : delim_(delim) {} StringPiece Find(const StringPiece &in) const { @@ -32,6 +33,8 @@ class SingleCharacter { class MultiCharacter { public: + MultiCharacter() {} + explicit MultiCharacter(const StringPiece &delimiter) : delimiter_(delimiter) {} StringPiece Find(const StringPiece &in) const { @@ -44,6 +47,7 @@ class MultiCharacter { class AnyCharacter { public: + AnyCharacter() {} explicit AnyCharacter(const StringPiece &chars) : chars_(chars) {} StringPiece Find(const StringPiece &in) const { @@ -56,6 +60,8 @@ class AnyCharacter { class AnyCharacterLast { public: + AnyCharacterLast() {} + explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {} StringPiece Find(const StringPiece &in) const { @@ -81,8 +87,8 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it return current_.data() != 0; } - static TokenIter<Find> end() { - return TokenIter<Find>(); + static TokenIter<Find, SkipEmpty> end() { + return TokenIter<Find, SkipEmpty>(); } private: @@ -100,8 +106,8 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it } while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false. } - bool equal(const TokenIter<Find> &other) const { - return after_.data() == other.after_.data(); + bool equal(const TokenIter<Find, SkipEmpty> &other) const { + return current_.data() == other.current_.data(); } const StringPiece &dereference() const { |