From 8882e9ebe158aef382bb5544559ef7f2a553db62 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Sep 2012 14:23:39 +0100 Subject: Update kenlm and build system --- klm/lm/max_order.hh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'klm/lm/max_order.hh') diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index bc8687cd..989f8324 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -8,5 +8,5 @@ #define KENLM_MAX_ORDER 6 #endif #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "Edit klm/lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif -- cgit v1.2.3 From de53e2e98acd0e2d07efb39bef430bd598908aa8 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 14 Dec 2012 12:39:04 -0800 Subject: Updated incremental, updated kenlm. Incremental assumes --- decoder/incremental.cc | 44 +++++++------- klm/lm/left.hh | 66 +++++++++++---------- klm/lm/max_order.hh | 7 +-- klm/lm/model.cc | 3 +- klm/lm/search_hashed.cc | 8 +-- klm/lm/search_hashed.hh | 2 +- klm/lm/vocab.cc | 7 ++- klm/lm/vocab.hh | 5 +- klm/search/Makefile.am | 4 +- klm/search/applied.hh | 86 +++++++++++++++++++++++++++ klm/search/config.hh | 25 ++++++-- klm/search/context.hh | 28 ++------- klm/search/dedupe.hh | 131 +++++++++++++++++++++++++++++++++++++++++ klm/search/edge_generator.cc | 3 +- klm/search/edge_generator.hh | 1 - klm/search/final.hh | 36 ----------- klm/search/header.hh | 9 ++- klm/search/nbest.cc | 106 +++++++++++++++++++++++++++++++++ klm/search/nbest.hh | 81 +++++++++++++++++++++++++ klm/search/note.hh | 12 ---- klm/search/rule.cc | 52 ++++++++-------- klm/search/rule.hh | 11 +++- klm/search/types.hh | 17 ++++++ klm/search/vertex.cc | 27 ++++++--- klm/search/vertex.hh | 37 +++++++----- klm/search/vertex_generator.cc | 44 +++----------- klm/search/vertex_generator.hh | 72 ++++++++++++++++++---- klm/search/weights.cc | 71 ---------------------- klm/search/weights.hh | 52 ---------------- klm/search/weights_test.cc | 38 ------------ 30 files changed, 680 insertions(+), 405 deletions(-) create mode 100644 klm/search/applied.hh create mode 100644 klm/search/dedupe.hh delete mode 100644 klm/search/final.hh create mode 100644 klm/search/nbest.cc create mode 100644 klm/search/nbest.hh delete mode 100644 klm/search/note.hh delete mode 100644 klm/search/weights.cc delete mode 100644 klm/search/weights.hh delete mode 100644 klm/search/weights_test.cc (limited to 'klm/lm/max_order.hh') diff --git a/decoder/incremental.cc b/decoder/incremental.cc index 46615b0b..85647a44 100644 --- a/decoder/incremental.cc +++ b/decoder/incremental.cc @@ -6,6 +6,7 @@ #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "search/applied.hh" #include "search/config.hh" #include "search/context.hh" #include "search/edge.hh" @@ -48,16 +49,16 @@ template class Incremental : public IncrementalBase { Incremental(const char *model_file, const std::vector &weights) : IncrementalBase(weights), m_(model_file, GetConfig()), - weights_( - weights[FD::Convert("KLanguageModel")], - weights[FD::Convert("KLanguageModel_OOV")], - weights[FD::Convert("WordPenalty")]) { - std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; + lm_(weights[FD::Convert("KLanguageModel")]), + oov_(weights[FD::Convert("KLanguageModel_OOV")]), + word_penalty_(weights[FD::Convert("WordPenalty")]) { + std::cerr << "Weights KLanguageModel " << lm_ << " KLanguageModel_OOV " << oov_ << " WordPenalty " << word_penalty_ << std::endl; } + void Search(unsigned int pop_limit, const Hypergraph &hg) const; private: - void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; + void ConvertEdge(const search::Context &context, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; lm::ngram::Config GetConfig() { lm::ngram::Config ret; @@ -69,46 +70,47 @@ template class Incremental : public IncrementalBase { const Model m_; - const search::Weights weights_; + const float lm_, oov_, word_penalty_; }; -void PrintFinal(const Hypergraph &hg, const search::Final final) { +void PrintApplied(const Hypergraph &hg, const search::Applied final) { const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); - const search::Final *child(final.Children()); + const search::Applied *child(final.Children()); for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { if (*i > 0) { std::cout << TD::Convert(*i) << ' '; } else { - PrintFinal(hg, *child++); + PrintApplied(hg, *child++); } } } template void Incremental::Search(unsigned int pop_limit, const Hypergraph &hg) const { boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); - search::Config config(weights_, pop_limit); + search::Config config(lm_, pop_limit, search::NBestConfig(1)); search::Context context(config, m_); + search::SingleBest best; for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { search::EdgeGenerator gen; const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; for (unsigned int j = 0; j < down_edges.size(); ++j) { unsigned int edge_index = down_edges[j]; - ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], gen); + ConvertEdge(context, out_vertices.get(), hg.edges_[edge_index], gen); } - search::VertexGenerator vertex_gen(context, out_vertices[i]); + search::VertexGenerator vertex_gen(context, out_vertices[i], best); gen.Search(context, vertex_gen); } - const search::Final top = out_vertices[hg.nodes_.size() - 2].BestChild(); + const search::Applied top = out_vertices[hg.nodes_.size() - 2].BestChild(); if (!top.Valid()) { std::cout << "NO PATH FOUND" << std::endl; } else { - PrintFinal(hg, top); + PrintApplied(hg, top); std::cout << "||| " << top.GetScore() << std::endl; } } -template void Incremental::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { +template void Incremental::ConvertEdge(const search::Context &context, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { const std::vector &e = in.rule_->e(); std::vector words; words.reserve(e.size()); @@ -127,10 +129,6 @@ template void Incremental::ConvertEdge(const search::Contex } } - if (final) { - words.push_back(m_.GetVocabulary().EndSentence()); - } - search::PartialEdge out(gen.AllocateEdge(nts.size())); memcpy(out.NT(), &nts[0], sizeof(search::PartialVertex) * nts.size()); @@ -140,8 +138,10 @@ template void Incremental::ConvertEdge(const search::Contex out.SetNote(note); score += in.rule_->GetFeatureValues().dot(cdec_weights_); - score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; - score += search::ScoreRule(context, words, final, out.Between()); + score -= static_cast(terminals) * word_penalty_ / M_LN10; + search::ScoreRuleRet res(search::ScoreRule(context.LanguageModel(), words, out.Between())); + score += res.prob * lm_ + static_cast(res.oov) * oov_; + out.SetScore(score); gen.AddEdge(out); diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 8c27232e..85c1ea37 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -51,36 +51,36 @@ namespace ngram { template class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } void BeginSentence() { - out_.right = model_.BeginSentenceState(); - // out_.left is empty. + out_->right = model_.BeginSentenceState(); + // out_->left is empty. left_done_ = true; } void Terminal(WordIndex word) { - State copy(out_.right); - FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + State copy(out_->right); + FullScoreReturn ret(model_.FullScore(copy, word, out_->right)); if (left_done_) { prob_ += ret.prob; return; } if (ret.independent_left) { prob_ += ret.prob; left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; - if (out_.right.length != copy.length + 1) + if (out_->right.length != copy.length + 1) left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob = 0.0) { prob_ = prob; - out_ = in; + *out_ = in; left_done_ = in.left.full; } @@ -89,23 +89,23 @@ template class RuleScore { if (!in.left.length) { if (in.left.full) { - for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; } return; } - if (!out_.right.length) { - out_.right = in.right; + if (!out_->right.length) { + out_->right = in.right; if (left_done_) { prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1); return; } - if (out_.left.length) { + if (out_->left.length) { left_done_ = true; } else { - out_.left = in.left; + out_->left = in.left; left_done_ = in.left.full; } return; @@ -113,10 +113,10 @@ template class RuleScore { float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1]; float *back = backoffs, *back2 = backoffs2; - unsigned char next_use = out_.right.length; + unsigned char next_use = out_->right.length; // First word - if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; + if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return; // Words after the first, so extending a bigram to begin with for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { @@ -127,54 +127,58 @@ template class RuleScore { if (in.left.full) { for (const float *i = back; i != back + next_use; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; return; } // Right state was minimized, so it's already independent of the new words to the left. if (in.right.length < in.left.length) { - out_.right = in.right; + out_->right = in.right; return; } // Shift exisiting words down. - for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) { *(i + in.right.length) = *i; } // Add words from in.right. - std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + std::copy(in.right.words, in.right.words + in.right.length, out_->right.words); // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. - std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); - std::copy(back, back + next_use, out_.right.backoff + in.right.length); - out_.right.length = in.right.length + next_use; + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff); + std::copy(back, back + next_use, out_->right.backoff + in.right.length); + out_->right.length = in.right.length + next_use; } float Finish() { // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. - out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1); + out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1); return prob_; } void Reset() { prob_ = 0.0; left_done_ = false; - out_.left.length = 0; - out_.right.length = 0; + out_->left.length = 0; + out_->right.length = 0; + } + void Reset(ChartState &replacement) { + out_ = &replacement; + Reset(); } private: bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { ProcessRet(model_.ExtendLeft( - out_.right.words, out_.right.words + next_use, // Words to extend into + out_->right.words, out_->right.words + next_use, // Words to extend into back_in, // Backoffs to use in.left.pointers[extend_length - 1], extend_length, // Words to be extended back_out, // Backoffs for the next score next_use)); // Length of n-gram to use in next scoring. - if (next_use != out_.right.length) { + if (next_use != out_->right.length) { left_done_ = true; if (!next_use) { // Early exit. - out_.right = in.right; + out_->right = in.right; prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1); return true; } @@ -193,13 +197,13 @@ template class RuleScore { left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; } const M &model_; - ChartState &out_; + ChartState *out_; bool left_done_; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 989f8324..ea0dea46 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -4,9 +4,8 @@ * (kMaxOrder - 1) * sizeof(float) bytes instead of * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead */ -#ifndef KENLM_MAX_ORDER -#define KENLM_MAX_ORDER 6 -#endif #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif + +#define KENLM_MAX_ORDER 5 diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 2fd20481..fc61efee 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -87,7 +87,7 @@ template void GenericModel FullScoreReturn GenericModel void HashedSearch::InitializeFromARPA(const char * template <> void HashedSearch::DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { NoRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } template <> void HashedSearch::DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { @@ -239,19 +239,19 @@ template <> void HashedSearch::DispatchBuild(util::FilePiece &f, cons case Config::REST_MAX: { MaxRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; case Config::REST_LOWER: { LowerRestBuild build(config, counts.size(), vocab); - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; } } -template template void HashedSearch::ApplyBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { +template template void HashedSearch::ApplyBuild(util::FilePiece &f, const std::vector &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { for (WordIndex i = 0; i < counts[0]; ++i) { build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]); } diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index a52f107b..00595796 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -147,7 +147,7 @@ template class HashedSearch { // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. void DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); - template void ApplyBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); + template void ApplyBuild(util::FilePiece &f, const std::vector &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); class Unigram { public: diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 11c27518..fd7f96dc 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { } *end_ = hashed; if (enumerate_) { - strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); + void *copied = string_backing_.Allocate(str.size()); + memcpy(copied, str.data(), str.size()); + strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast(copied), str.size()); } ++end_; // This is 1 + the offset where it was inserted to make room for unk. @@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { if (enumerate_) { if (!strings_to_enumerate_.empty()) { - util::PairedIterator values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::PairedIterator values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); util::JointSort(begin_, end_, values); } for (WordIndex i = 0; i < static_cast(end_ - begin_); ++i) { @@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { enumerate_->Add(i + 1, strings_to_enumerate_[i]); } strings_to_enumerate_.clear(); + string_backing_.FreeAll(); } else { util::JointSort(begin_, end_, reorder_vocab + 1); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index de54eb06..3902f117 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,6 +4,7 @@ #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" +#include "util/pool.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" #include "util/string_piece.hh" @@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary { EnumerateVocab *enumerate_; // Actual strings. Used only when loading from ARPA and enumerate_ != NULL - std::vector strings_to_enumerate_; + util::Pool string_backing_; + + std::vector strings_to_enumerate_; }; #pragma pack(push) diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index ccc5b7f6..5aea33c2 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -2,10 +2,10 @@ noinst_LIBRARIES = libksearch.a libksearch_a_SOURCES = \ edge_generator.cc \ + nbest.cc \ rule.cc \ vertex.cc \ - vertex_generator.cc \ - weights.cc + vertex_generator.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/applied.hh b/klm/search/applied.hh new file mode 100644 index 00000000..bd659e5c --- /dev/null +++ b/klm/search/applied.hh @@ -0,0 +1,86 @@ +#ifndef SEARCH_APPLIED__ +#define SEARCH_APPLIED__ + +#include "search/edge.hh" +#include "search/header.hh" +#include "util/pool.hh" + +#include + +namespace search { + +// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted. +template class GenericApplied : public Header { + public: + GenericApplied() {} + + GenericApplied(void *location, PartialEdge partial) + : Header(location) { + memcpy(Base(), partial.Base(), kHeaderSize); + Below *child_out = Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = Below(part->End()); + } + + GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) { + SetScore(score); + SetNote(note); + } + + explicit GenericApplied(History from) : Header(from) {} + + + // These are arrays of length GetArity(). + Below *Children() { + return reinterpret_cast(After()); + } + const Below *Children() const { + return reinterpret_cast(After()); + } + + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Below); + } +}; + +// Applied rule that references itself. +class Applied : public GenericApplied { + private: + typedef GenericApplied P; + + public: + Applied() {} + Applied(void *location, PartialEdge partial) : P(location, partial) {} + Applied(History from) : P(from) {} +}; + +// How to build single-best hypotheses. +class SingleBest { + public: + typedef PartialEdge Combine; + + void Add(PartialEdge &existing, PartialEdge add) const { + if (!existing.Valid() || existing.GetScore() < add.GetScore()) + existing = add; + } + + NBestComplete Complete(PartialEdge partial) { + if (!partial.Valid()) + return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY); + void *place_final = pool_.Allocate(Applied::Size(partial.GetArity())); + Applied(place_final, partial); + return NBestComplete( + place_final, + partial.CompletedState(), + partial.GetScore()); + } + + private: + util::Pool pool_; +}; + +} // namespace search + +#endif // SEARCH_APPLIED__ diff --git a/klm/search/config.hh b/klm/search/config.hh index ef8e2354..ba18c09e 100644 --- a/klm/search/config.hh +++ b/klm/search/config.hh @@ -1,23 +1,36 @@ #ifndef SEARCH_CONFIG__ #define SEARCH_CONFIG__ -#include "search/weights.hh" -#include "util/string_piece.hh" +#include "search/types.hh" namespace search { +struct NBestConfig { + explicit NBestConfig(unsigned int in_size) { + keep = in_size; + size = in_size; + } + + unsigned int keep, size; +}; + class Config { public: - Config(const Weights &weights, unsigned int pop_limit) : - weights_(weights), pop_limit_(pop_limit) {} + Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) : + lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {} - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return lm_weight_; } unsigned int PopLimit() const { return pop_limit_; } + const NBestConfig &GetNBest() const { return nbest_; } + private: - Weights weights_; + Score lm_weight_; + unsigned int pop_limit_; + + NBestConfig nbest_; }; } // namespace search diff --git a/klm/search/context.hh b/klm/search/context.hh index 62163144..08f21bbf 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -1,30 +1,16 @@ #ifndef SEARCH_CONTEXT__ #define SEARCH_CONTEXT__ -#include "lm/model.hh" #include "search/config.hh" -#include "search/final.hh" -#include "search/types.hh" #include "search/vertex.hh" -#include "util/exception.hh" -#include "util/pool.hh" #include -#include - -#include namespace search { -class Weights; - class ContextBase { public: - explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - - util::Pool &FinalPool() { - return final_pool_; - } + explicit ContextBase(const Config &config) : config_(config) {} VertexNode *NewVertexNode() { VertexNode *ret = vertex_node_pool_.construct(); @@ -36,18 +22,16 @@ class ContextBase { vertex_node_pool_.destroy(node); } - unsigned int PopLimit() const { return pop_limit_; } + unsigned int PopLimit() const { return config_.PopLimit(); } - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return config_.LMWeight(); } - private: - util::Pool final_pool_; + const Config &GetConfig() const { return config_; } + private: boost::object_pool vertex_node_pool_; - unsigned int pop_limit_; - - const Weights &weights_; + Config config_; }; template class Context : public ContextBase { diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh new file mode 100644 index 00000000..7eaa3b95 --- /dev/null +++ b/klm/search/dedupe.hh @@ -0,0 +1,131 @@ +#ifndef SEARCH_DEDUPE__ +#define SEARCH_DEDUPE__ + +#include "lm/state.hh" +#include "search/edge_generator.hh" + +#include +#include + +namespace search { + +class Dedupe { + public: + Dedupe() {} + + PartialEdge AllocateEdge(Arity arity) { + return behind_.AllocateEdge(arity); + } + + void AddEdge(PartialEdge edge) { + edge.MutableFlags() = 0; + + uint64_t hash = 0; + const PartialVertex *v = edge.NT(); + const PartialVertex *v_end = v + edge.GetArity(); + for (; v != v_end; ++v) { + const void *ptr = v->Identify(); + hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash); + } + + const lm::ngram::ChartState *c = edge.Between(); + const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1; + for (; c != c_end; ++c) hash = hash_value(*c, hash); + + std::pair ret(table_.insert(std::make_pair(hash, edge))); + if (!ret.second) FoundDupe(ret.first->second, edge); + } + + bool Empty() const { return behind_.Empty(); } + + template void Search(Context &context, Output &output) { + for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) { + behind_.AddEdge(i->second); + } + Unpack unpack(output, *this); + behind_.Search(context, unpack); + } + + private: + void FoundDupe(PartialEdge &table, PartialEdge adding) { + if (table.GetFlags() & kPackedFlag) { + Packed &packed = *static_cast(table.GetNote().mut); + if (table.GetScore() >= adding.GetScore()) { + packed.others.push_back(adding); + return; + } + Note original(packed.original); + packed.original = adding.GetNote(); + adding.SetNote(table.GetNote()); + table.SetNote(original); + packed.others.push_back(table); + packed.starting = adding.GetScore(); + table = adding; + table.MutableFlags() |= kPackedFlag; + return; + } + PartialEdge loser; + if (adding.GetScore() > table.GetScore()) { + loser = table; + table = adding; + } else { + loser = adding; + } + // table is winner, loser is loser... + packed_.construct(table, loser); + } + + struct Packed { + Packed(PartialEdge winner, PartialEdge loser) + : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) { + winner.MutableNote().vp = this; + winner.MutableFlags() |= kPackedFlag; + loser.MutableFlags() &= ~kPackedFlag; + } + Note original; + Score starting; + std::vector others; + }; + + template class Unpack { + public: + explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {} + + void NewHypothesis(PartialEdge edge) { + if (edge.GetFlags() & kPackedFlag) { + Packed &packed = *reinterpret_cast(edge.GetNote().mut); + edge.SetNote(packed.original); + edge.MutableFlags() = 0; + std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState); + for (std::vector::iterator i = packed.others.begin(); i != packed.others.end(); ++i) { + PartialEdge copy(owner_.AllocateEdge(edge.GetArity())); + copy.SetScore(edge.GetScore() - packed.starting + i->GetScore()); + copy.MutableFlags() = 0; + copy.SetNote(i->GetNote()); + memcpy(copy.NT(), edge.NT(), copy_size); + output_.NewHypothesis(copy); + } + } + output_.NewHypothesis(edge); + } + + void FinishedSearch() { + output_.FinishedSearch(); + } + + private: + Output &output_; + Dedupe &owner_; + }; + + EdgeGenerator behind_; + + typedef boost::unordered_map Table; + Table table_; + + boost::object_pool packed_; + + static const uint16_t kPackedFlag = 1; +}; +} // namespace search +#endif // SEARCH_DEDUPE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 260159b1..eacf5de5 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -1,6 +1,7 @@ #include "search/edge_generator.hh" #include "lm/left.hh" +#include "lm/model.hh" #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" @@ -38,7 +39,7 @@ template void FastScore(const Context &context, Arity victi *cover = *(cover + 1); } } - update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); + update.SetScore(update.GetScore() + adjustment * context.LMWeight()); } } // namespace diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 582c78b7..203942c6 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,6 @@ #define SEARCH_EDGE_GENERATOR__ #include "search/edge.hh" -#include "search/note.hh" #include "search/types.hh" #include diff --git a/klm/search/final.hh b/klm/search/final.hh deleted file mode 100644 index 50e62cf2..00000000 --- a/klm/search/final.hh +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SEARCH_FINAL__ -#define SEARCH_FINAL__ - -#include "search/header.hh" -#include "util/pool.hh" - -namespace search { - -// A full hypothesis with pointers to children. -class Final : public Header { - public: - Final() {} - - Final(util::Pool &pool, Score score, Arity arity, Note note) - : Header(pool.Allocate(Size(arity)), arity) { - SetScore(score); - SetNote(note); - } - - // These are arrays of length GetArity(). - Final *Children() { - return reinterpret_cast(After()); - } - const Final *Children() const { - return reinterpret_cast(After()); - } - - private: - static std::size_t Size(Arity arity) { - return kHeaderSize + arity * sizeof(const Final); - } -}; - -} // namespace search - -#endif // SEARCH_FINAL__ diff --git a/klm/search/header.hh b/klm/search/header.hh index 25550dbe..69f0eed0 100644 --- a/klm/search/header.hh +++ b/klm/search/header.hh @@ -3,7 +3,6 @@ // Header consisting of Score, Arity, and Note -#include "search/note.hh" #include "search/types.hh" #include @@ -24,6 +23,9 @@ class Header { bool operator<(const Header &other) const { return GetScore() < other.GetScore(); } + bool operator>(const Header &other) const { + return GetScore() > other.GetScore(); + } Arity GetArity() const { return *reinterpret_cast(base_ + sizeof(Score)); @@ -36,9 +38,14 @@ class Header { *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)) = to; } + uint8_t *Base() { return base_; } + const uint8_t *Base() const { return base_; } + protected: Header() : base_(NULL) {} + explicit Header(void *base) : base_(static_cast(base)) {} + Header(void *base, Arity arity) : base_(static_cast(base)) { *reinterpret_cast(base_ + sizeof(Score)) = arity; } diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc new file mode 100644 index 00000000..ec3322c9 --- /dev/null +++ b/klm/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include +#include +#include + +#include +#include + +namespace search { + +NBestList::NBestList(std::vector &partials, util::Pool &entry_pool, std::size_t keep) { + assert(!partials.empty()); + std::vector::iterator end; + if (partials.size() > keep) { + end = partials.begin() + keep; + std::nth_element(partials.begin(), end, partials.end(), std::greater()); + } else { + end = partials.end(); + } + for (std::vector::const_iterator i(partials.begin()); i != end; ++i) { + queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); + } +} + +Score NBestList::TopAfterConstructor() const { + assert(revealed_.empty()); + return queue_.top().GetScore(); +} + +const std::vector &NBestList::Extract(util::Pool &pool, std::size_t n) { + while (revealed_.size() < n && !queue_.empty()) { + MoveTop(pool); + } + return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { + if (index + 1 < revealed_.size()) + return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); + if (queue_.empty()) + return -INFINITY; + if (index + 1 == revealed_.size()) + return queue_.top().GetScore() - revealed_[index].GetScore(); + assert(index == revealed_.size()); + + MoveTop(pool); + + if (queue_.empty()) return -INFINITY; + return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { + assert(index <= revealed_.size()); + if (index == revealed_.size()) MoveTop(pool); + return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { + assert(!queue_.empty()); + QueueEntry entry(queue_.top()); + queue_.pop(); + RevealedRef *const children_begin = entry.Children(); + RevealedRef *const children_end = children_begin + entry.GetArity(); + Score basis = entry.GetScore(); + for (RevealedRef *child = children_begin; child != children_end; ++child) { + Score change = child->in_->Visit(pool, child->index_); + if (change != -INFINITY) { + assert(change < 0.001); + QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); + std::copy(children_begin, child, new_entry.Children()); + RevealedRef *update = new_entry.Children() + (child - children_begin); + update->in_ = child->in_; + update->index_ = child->index_ + 1; + std::copy(child + 1, children_end, update + 1); + queue_.push(new_entry); + } + // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. + if (child->index_) break; + } + + // Convert QueueEntry to Applied. This leaves some unused memory. + void *overwrite = entry.Children(); + for (unsigned int i = 0; i < entry.GetArity(); ++i) { + RevealedRef from(*(static_cast(overwrite) + i)); + *(static_cast(overwrite) + i) = from.in_->Get(pool, from.index_); + } + revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector &partials) { + assert(!partials.empty()); + NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); + return NBestComplete( + list, + partials.front().CompletedState(), // All partials have the same state + list->TopAfterConstructor()); +} + +const std::vector &NBest::Extract(History history) { + return static_cast(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh new file mode 100644 index 00000000..cb7651bc --- /dev/null +++ b/klm/search/nbest.hh @@ -0,0 +1,81 @@ +#ifndef SEARCH_NBEST__ +#define SEARCH_NBEST__ + +#include "search/applied.hh" +#include "search/config.hh" +#include "search/edge.hh" + +#include + +#include +#include +#include + +#include + +namespace search { + +class NBestList; + +class NBestList { + private: + class RevealedRef { + public: + explicit RevealedRef(History history) + : in_(static_cast(history)), index_(0) {} + + private: + friend class NBestList; + + NBestList *in_; + std::size_t index_; + }; + + typedef GenericApplied QueueEntry; + + public: + NBestList(std::vector &existing, util::Pool &entry_pool, std::size_t keep); + + Score TopAfterConstructor() const; + + const std::vector &Extract(util::Pool &pool, std::size_t n); + + private: + Score Visit(util::Pool &pool, std::size_t index); + + Applied Get(util::Pool &pool, std::size_t index); + + void MoveTop(util::Pool &pool); + + typedef std::vector Revealed; + Revealed revealed_; + + typedef std::priority_queue Queue; + Queue queue_; +}; + +class NBest { + public: + typedef std::vector Combine; + + explicit NBest(const NBestConfig &config) : config_(config) {} + + void Add(std::vector &existing, PartialEdge addition) const { + existing.push_back(addition); + } + + NBestComplete Complete(std::vector &partials); + + const std::vector &Extract(History root); + + private: + const NBestConfig config_; + + boost::object_pool list_pool_; + + util::Pool entry_pool_; +}; + +} // namespace search + +#endif // SEARCH_NBEST__ diff --git a/klm/search/note.hh b/klm/search/note.hh deleted file mode 100644 index 50bed06e..00000000 --- a/klm/search/note.hh +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef SEARCH_NOTE__ -#define SEARCH_NOTE__ - -namespace search { - -union Note { - const void *vp; -}; - -} // namespace search - -#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 5b00207e..0244a09f 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -1,7 +1,7 @@ #include "search/rule.hh" +#include "lm/model.hh" #include "search/context.hh" -#include "search/final.hh" #include @@ -9,35 +9,35 @@ namespace search { -template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing) { - unsigned int oov_count = 0; - float prob = 0.0; - const Model &model = context.LanguageModel(); - const lm::WordIndex oov = model.GetVocabulary().NotFound(); - for (std::vector::const_iterator word = words.begin(); ; ++word) { - lm::ngram::RuleScore scorer(model, *(writing++)); - // TODO: optimize - if (prepend_bos && (word == words.begin())) { - scorer.BeginSentence(); - } - for (; ; ++word) { - if (word == words.end()) { - prob += scorer.Finish(); - return static_cast(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); - } - if (*word == kNonTerminal) break; - if (*word == oov) ++oov_count; +template ScoreRuleRet ScoreRule(const Model &model, const std::vector &words, lm::ngram::ChartState *writing) { + ScoreRuleRet ret; + ret.prob = 0.0; + ret.oov = 0; + const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence(); + lm::ngram::RuleScore scorer(model, *(writing++)); + std::vector::const_iterator word = words.begin(); + if (word != words.end() && *word == bos) { + scorer.BeginSentence(); + ++word; + } + for (; word != words.end(); ++word) { + if (*word == kNonTerminal) { + ret.prob += scorer.Finish(); + scorer.Reset(*(writing++)); + } else { + if (*word == oov) ++ret.oov; scorer.Terminal(*word); } - prob += scorer.Finish(); } + ret.prob += scorer.Finish(); + return ret; } -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 0ce2794d..43ca6162 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -9,11 +9,16 @@ namespace search { -template class Context; - const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; -template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *state_out); +struct ScoreRuleRet { + Score prob; + unsigned int oov; +}; + +// Pass and normally. +// Indicate non-terminals with kNonTerminal. +template ScoreRuleRet ScoreRule(const Model &model, const std::vector &words, lm::ngram::ChartState *state_out); } // namespace search diff --git a/klm/search/types.hh b/klm/search/types.hh index 06eb5bfa..f9c849b3 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -3,12 +3,29 @@ #include +namespace lm { namespace ngram { class ChartState; } } + namespace search { typedef float Score; typedef uint32_t Arity; +union Note { + const void *vp; +}; + +typedef void *History; + +struct NBestComplete { + NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score) + : history(in_history), state(&in_state), score(in_score) {} + + History history; + const lm::ngram::ChartState *state; + Score score; +}; + } // namespace search #endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 11f4631f..45842982 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_functionSortAndSet(context, parent_ptr); + if (extend_.size() == 1) { + parent_ptr = extend_[0]; + extend_[0]->RecursiveSortAndSet(context, parent_ptr); context.DeleteVertexNode(this); return; } for (std::vector::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->SortAndSet(context, &*i); + (*i)->RecursiveSortAndSet(context, *i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +void VertexNode::SortAndSet(ContextBase &context) { + // This is the root. The root might be empty. + if (extend_.empty()) { + bound_ = -INFINITY; + return; + } + // The root cannot be replaced. There's always one transition. + for (std::vector::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->RecursiveSortAndSet(context, *i); } std::sort(extend_.begin(), extend_.end(), GreaterByBound()); bound_ = extend_.front()->Bound(); diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 52bc1dfe..10b3339b 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -2,7 +2,6 @@ #define SEARCH_VERTEX__ #include "lm/left.hh" -#include "search/final.hh" #include "search/types.hh" #include @@ -10,6 +9,7 @@ #include #include +#include #include namespace search { @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() {} + VertexNode() : end_() {} void InitRoot() { extend_.clear(); @@ -26,7 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - end_ = Final(); + end_ = History(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -36,20 +36,21 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final end) { - assert(!end_.Valid()); + void SetEnd(History end, Score score) { + assert(!end_); end_ = end; + bound_ = score; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + void SortAndSet(ContextBase &context); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_.Valid() && extend_.empty(); + return !end_ && extend_.empty(); } bool Complete() const { - return end_.Valid(); + return end_; } const lm::ngram::ChartState &State() const { return state_; } @@ -64,7 +65,7 @@ class VertexNode { } // Will be invalid unless this is a leaf. - const Final End() const { return end_; } + const History End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -75,13 +76,15 @@ class VertexNode { } private: + void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + std::vector extend_; lm::ngram::ChartState state_; bool right_full_; Score bound_; - Final end_; + History end_; }; class PartialVertex { @@ -97,7 +100,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -121,7 +124,7 @@ class PartialVertex { return ret; } - const Final End() const { + const History End() const { return back_->End(); } @@ -130,16 +133,18 @@ class PartialVertex { unsigned int index_; }; +template class VertexGenerator; + class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final BestChild() const { + const History BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return Final(); + return History(); } else { PartialVertex continuation; while (!top.Complete()) { @@ -150,8 +155,8 @@ class Vertex { } private: - friend class VertexGenerator; - + template friend class VertexGenerator; + template friend class RootVertexGenerator; VertexNode root_; }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0945fe55..73139ffc 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -4,26 +4,18 @@ #include "search/context.hh" #include "search/edge.hh" +#include +#include + #include namespace search { -VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { - gen.root_.InitRoot(); -} - +#if BOOST_VERSION > 104200 namespace { const uint64_t kCompleteAdd = static_cast(-1); -// Parallel structure to VertexNode. -struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map extend; -}; - Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { Trie &next = node.extend[added]; if (!next.under) { @@ -39,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n return next; } -void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { - Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); - Final *child_out = final.Children(); - const PartialVertex *part = partial.NT(); - const PartialVertex *const part_end_loop = part + partial.GetArity(); - for (; part != part_end_loop; ++part, ++child_out) - *child_out = part->End(); - - starter.under->SetEnd(final); -} +} // namespace -void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) { + const lm::ngram::ChartState &state = *end.state; unsigned char left = 0, right = 0; Trie *node = &root; @@ -77,18 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { } node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - CompleteTransition(context, *node, partial); + node->under->SetEnd(end.history, end.score); } -} // namespace - -void VertexGenerator::FinishedSearch() { - Trie root; - root.under = &gen_.root_; - for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { - AddHypothesis(context_, root, i->second); - } - root.under->SortAndSet(context_, NULL); -} +#endif // BOOST_VERSION } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 60e86112..da563c2d 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -2,9 +2,11 @@ #define SEARCH_VERTEX_GENERATOR__ #include "search/edge.hh" +#include "search/types.hh" #include "search/vertex.hh" #include +#include namespace lm { namespace ngram { @@ -15,21 +17,44 @@ class ChartState; namespace search { class ContextBase; -class Final; -class VertexGenerator { +#if BOOST_VERSION > 104200 +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; +}; + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); + +#endif // BOOST_VERSION + +// Output makes the single-best or n-best list. +template class VertexGenerator { public: - VertexGenerator(ContextBase &context, Vertex &gen); + VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { + gen.root_.InitRoot(); + } void NewHypothesis(PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair ret(existing_.insert(std::make_pair(hash_value(state), partial))); - if (!ret.second && ret.first->second < partial) { - ret.first->second = partial; - } + nbest_.Add(existing_[hash_value(partial.CompletedState())], partial); } - void FinishedSearch(); + void FinishedSearch() { +#if BOOST_VERSION > 104200 + Trie root; + root.under = &gen_.root_; + for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, nbest_.Complete(i->second)); + } + existing_.clear(); + root.under->SortAndSet(context_); +#else + UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); +#endif + } const Vertex &Generating() const { return gen_; } @@ -38,8 +63,35 @@ class VertexGenerator { Vertex &gen_; - typedef boost::unordered_map Existing; + typedef boost::unordered_map Existing; Existing existing_; + + Output &nbest_; +}; + +// Special case for root vertex: everything should come together into the root +// node. In theory, this should happen naturally due to state collapsing with +// and . If that's the case, VertexGenerator is fine, though it will +// make one connection. +template class RootVertexGenerator { + public: + RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {} + + void NewHypothesis(PartialEdge partial) { + out_.Add(combine_, partial); + } + + void FinishedSearch() { + gen_.root_.InitRoot(); + NBestComplete completed(out_.Complete(combine_)); + gen_.root_.SetEnd(completed.history, completed.score); + } + + private: + Vertex &gen_; + + typename Output::Combine combine_; + Output &out_; }; } // namespace search diff --git a/klm/search/weights.cc b/klm/search/weights.cc deleted file mode 100644 index d65471ad..00000000 --- a/klm/search/weights.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "search/weights.hh" -#include "util/tokenize_piece.hh" - -#include - -namespace search { - -namespace { -struct Insert { - void operator()(boost::unordered_map &map, StringPiece name, search::Score score) const { - std::string copy(name.data(), name.size()); - map[copy] = score; - } -}; - -struct DotProduct { - search::Score total; - DotProduct() : total(0.0) {} - - void operator()(const boost::unordered_map &map, StringPiece name, search::Score score) { - boost::unordered_map::const_iterator i(FindStringPiece(map, name)); - if (i != map.end()) - total += score * i->second; - } -}; - -template void Parse(StringPiece text, Map &map, Op &op) { - for (util::TokenIter spaces(text, ' '); spaces; ++spaces) { - util::TokenIter equals(*spaces, '='); - UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); - StringPiece name(*equals); - UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); - char *end; - // Assumes proper termination. - double value = std::strtod(equals->data(), &end); - UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); - UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); - op(map, name, value); - } -} - -} // namespace - -Weights::Weights(StringPiece text) { - Insert op; - Parse(text, map_, op); - lm_ = Steal("LanguageModel"); - oov_ = Steal("OOV"); - word_penalty_ = Steal("WordPenalty"); -} - -Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} - -search::Score Weights::DotNoLM(StringPiece text) const { - DotProduct dot; - Parse(text, map_, dot); - return dot.total; -} - -float Weights::Steal(const std::string &str) { - Map::iterator i(map_.find(str)); - if (i == map_.end()) { - return 0.0; - } else { - float ret = i->second; - map_.erase(i); - return ret; - } -} - -} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh deleted file mode 100644 index df1c419f..00000000 --- a/klm/search/weights.hh +++ /dev/null @@ -1,52 +0,0 @@ -// For now, the individual features are not kept. -#ifndef SEARCH_WEIGHTS__ -#define SEARCH_WEIGHTS__ - -#include "search/types.hh" -#include "util/exception.hh" -#include "util/string_piece.hh" - -#include - -#include - -namespace search { - -class WeightParseException : public util::Exception { - public: - WeightParseException() {} - ~WeightParseException() throw() {} -}; - -class Weights { - public: - // Parses weights, sets lm_weight_, removes it from map_. - explicit Weights(StringPiece text); - - // Just the three scores we care about adding. - Weights(Score lm, Score oov, Score word_penalty); - - Score DotNoLM(StringPiece text) const; - - Score LM() const { return lm_; } - - Score OOV() const { return oov_; } - - Score WordPenalty() const { return word_penalty_; } - - // Mostly for testing. - const boost::unordered_map &GetMap() const { return map_; } - - private: - float Steal(const std::string &str); - - typedef boost::unordered_map Map; - - Map map_; - - Score lm_, oov_, word_penalty_; -}; - -} // namespace search - -#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc deleted file mode 100644 index 4811ff06..00000000 --- a/klm/search/weights_test.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include "search/weights.hh" - -#define BOOST_TEST_MODULE WeightTest -#include -#include - -namespace search { -namespace { - -#define CHECK_WEIGHT(value, string) \ - i = parsed.find(string); \ - BOOST_REQUIRE(i != parsed.end()); \ - BOOST_CHECK_CLOSE((value), i->second, 0.001); - -BOOST_AUTO_TEST_CASE(parse) { - // These are not real feature weights. - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - const boost::unordered_map &parsed = w.GetMap(); - boost::unordered_map::const_iterator i; - CHECK_WEIGHT(0.0, "rarity"); - CHECK_WEIGHT(0.0, "phrase-SGT"); - CHECK_WEIGHT(9.45117, "phrase-TGS"); - CHECK_WEIGHT(2.33833, "lexical-SGT"); - BOOST_CHECK(parsed.end() == parsed.find("lm")); - BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); - CHECK_WEIGHT(-28.3317, "lexical-TGS"); - CHECK_WEIGHT(5.0, "glue?"); -} - -BOOST_AUTO_TEST_CASE(dot) { - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); -} - -} // namespace -} // namespace search -- cgit v1.2.3 From 59737f22fccb9c2ab8744a719f4dbb95eedf7943 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 14 Dec 2012 12:48:26 -0800 Subject: Updated kenlm --- configure.ac | 2 +- klm/lm/binary_format.cc | 21 +- klm/lm/config.cc | 1 + klm/lm/config.hh | 59 +++--- klm/lm/max_order.hh | 2 - klm/lm/model.cc | 30 +-- klm/lm/search_trie.cc | 47 ++--- klm/util/Makefile.am | 1 + klm/util/exception.hh | 8 +- klm/util/file.cc | 38 ++-- klm/util/file.hh | 8 +- klm/util/file_piece.cc | 66 ++----- klm/util/file_piece.hh | 41 ++-- klm/util/file_piece_test.cc | 4 +- klm/util/have.hh | 12 +- klm/util/joint_sort.hh | 4 +- klm/util/read_compressed.cc | 403 +++++++++++++++++++++++++++++++++++++++ klm/util/read_compressed.hh | 74 +++++++ klm/util/read_compressed_test.cc | 94 +++++++++ klm/util/scoped.hh | 65 ++++--- klm/util/string_piece.hh | 19 +- klm/util/tokenize_piece.hh | 14 +- 22 files changed, 781 insertions(+), 232 deletions(-) create mode 100644 klm/util/read_compressed.cc create mode 100644 klm/util/read_compressed.hh create mode 100644 klm/util/read_compressed_test.cc (limited to 'klm/lm/max_order.hh') diff --git a/configure.ac b/configure.ac index f1b9d132..f4650ca4 100644 --- a/configure.ac +++ b/configure.ac @@ -87,7 +87,7 @@ AC_CHECK_HEADER(lzma.h, AC_PROG_INSTALL -CPPFLAGS="-DPIC -fPIC $CPPFLAGS -DHAVE_CONFIG_H" +CPPFLAGS="-DPIC -fPIC $CPPFLAGS -DHAVE_CONFIG_H -DKENLM_MAX_ORDER=6" # core cdec stuff AC_CONFIG_FILES([Makefile]) diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index efa67056..39c4a9b6 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -16,11 +16,11 @@ namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; -// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). +// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; const long int kMagicVersion = 5; -// Old binary files built on 32-bit machines have this header. +// Old binary files built on 32-bit machines have this header. // TODO: eliminate with next binary release. struct OldSanity { char magic[sizeof(kMagicBytes)]; @@ -39,7 +39,7 @@ struct OldSanity { }; -// Test values aligned to 8 bytes. +// Test values aligned to 8 bytes. struct Sanity { char magic[ALIGN8(sizeof(kMagicBytes))]; float zero_f, one_f, minus_half_f; @@ -101,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { - // Grow the file to accomodate the search, using zeros. + // Grow the file to accomodate the search, using zeros. try { util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); } catch (util::ErrnoException &e) { @@ -114,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t return reinterpret_cast(backing.search.get()); } // mmap it now. - // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. + // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. std::size_t page_size = util::SizePage(); std::size_t alignment_cruft = adjusted_vocab % page_size; backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); @@ -122,7 +122,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t } else { util::MapAnonymous(memory_size, backing.search); return reinterpret_cast(backing.search.get()); - } + } } void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, std::size_t vocab_pad, Backing &backing) { @@ -140,7 +140,7 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_ util::FSyncOrThrow(backing.file.get()); break; } - // header and vocab share the same mmap. The header is written here because we know the counts. + // header and vocab share the same mmap. The header is written here because we know the counts. Parameters params = Parameters(); params.counts = counts; params.fixed.order = counts.size(); @@ -160,7 +160,7 @@ namespace detail { bool IsBinaryFormat(int fd) { const uint64_t size = util::SizeFile(fd); if (size == util::kBadSize || (size <= static_cast(sizeof(Sanity)))) return false; - // Try reading the header. + // Try reading the header. util::scoped_memory memory; try { util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory); @@ -214,7 +214,7 @@ void SeekPastHeader(int fd, const Parameters ¶ms) { uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { const uint64_t file_size = util::SizeFile(backing.file.get()); - // The header is smaller than a page, so we have to map the whole header as well. + // The header is smaller than a page, so we have to map the whole header as well. std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); if (file_size != util::kBadSize && static_cast(file_size) < total_map) UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); @@ -233,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) { if (config.write_mmap || !config.messages) return; if (config.arpa_complain == Config::ALL) { *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; - } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { + } else if (config.arpa_complain == Config::EXPENSIVE && + (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; } } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index f9d988ca..9520c41c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -6,6 +6,7 @@ namespace lm { namespace ngram { Config::Config() : + show_progress(true), messages(&std::cerr), enumerate_vocab(NULL), unknown_missing(COMPLAIN), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 739cee9c..0de7b7c6 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -11,46 +11,52 @@ /* Configuration for ngram model. Separate header to reduce pollution. */ namespace lm { - + class EnumerateVocab; namespace ngram { struct Config { - // EFFECTIVE FOR BOTH ARPA AND BINARY READS + // EFFECTIVE FOR BOTH ARPA AND BINARY READS + + // (default true) print progress bar to messages + bool show_progress; // Where to log messages including the progress bar. Set to NULL for // silence. std::ostream *messages; + std::ostream *ProgressMessages() const { + return show_progress ? messages : 0; + } + // This will be called with every string in the vocabulary. See // enumerate_vocab.hh for more detail. Config does not take ownership; you - // are still responsible for deleting it (or stack allocating). + // are still responsible for deleting it (or stack allocating). EnumerateVocab *enumerate_vocab; - // ONLY EFFECTIVE WHEN READING ARPA - // What to do when isn't in the provided model. + // What to do when isn't in the provided model. WarningAction unknown_missing; - // What to do when or is missing from the model. - // If THROW_UP, the exception will be of type util::SpecialWordMissingException. + // What to do when or is missing from the model. + // If THROW_UP, the exception will be of type util::SpecialWordMissingException. WarningAction sentence_marker_missing; // What to do with a positive log probability. For COMPLAIN and SILENT, map - // to 0. + // to 0. WarningAction positive_log_probability; - // The probability to substitute for if it's missing from the model. + // The probability to substitute for if it's missing from the model. // No effect if the model has or unknown_missing == THROW_UP. float unknown_missing_logprob; // Size multiplier for probing hash table. Must be > 1. Space is linear in // this. Time is probing_multiplier / (probing_multiplier - 1). No effect - // for sorted variant. + // for sorted variant. // If you find yourself setting this to a low number, consider using the - // TrieModel which has lower memory consumption. + // TrieModel which has lower memory consumption. float probing_multiplier; // Amount of memory to use for building. The actual memory usage will be @@ -58,10 +64,10 @@ struct Config { // models. std::size_t building_memory; - // Template for temporary directory appropriate for passing to mkdtemp. + // Template for temporary directory appropriate for passing to mkdtemp. // The characters XXXXXX are appended before passing to mkdtemp. Only // applies to trie. If NULL, defaults to write_mmap. If that's NULL, - // defaults to input file name. + // defaults to input file name. const char *temporary_directory_prefix; // Level of complaining to do when loading from ARPA instead of binary format. @@ -69,49 +75,46 @@ struct Config { ARPALoadComplain arpa_complain; // While loading an ARPA file, also write out this binary format file. Set - // to NULL to disable. + // to NULL to disable. const char *write_mmap; enum WriteMethod { - WRITE_MMAP, // Map the file directly. - WRITE_AFTER // Write after we're done. + WRITE_MMAP, // Map the file directly. + WRITE_AFTER // Write after we're done. }; WriteMethod write_method; - // Include the vocab in the binary file? Only effective if write_mmap != NULL. + // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; - // Left rest options. Only used when the model includes rest costs. + // Left rest options. Only used when the model includes rest costs. enum RestFunction { REST_MAX, // Maximum of any score to the left - REST_LOWER, // Use lower-order files given below. + REST_LOWER, // Use lower-order files given below. }; RestFunction rest_function; - // Only used for REST_LOWER. + // Only used for REST_LOWER. std::vector rest_lower_files; - // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used - // to quantize (and one of the remaining backoffs will be 0). + // to quantize (and one of the remaining backoffs will be 0). uint8_t prob_bits, backoff_bits; // Bhiksha compression (simple form). Only works with trie. uint8_t pointer_bhiksha_bits; - - + // ONLY EFFECTIVE WHEN READING BINARY - + // How to get the giant array into memory: lazy mmap, populate, read etc. - // See util/mmap.hh for details of MapMethod. + // See util/mmap.hh for details of MapMethod. util::LoadMethod load_method; - - // Set defaults. + // Set defaults. Config(); }; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index ea0dea46..3eb97ccd 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -7,5 +7,3 @@ #ifndef KENLM_ORDER_MESSAGE #define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif - -#define KENLM_MAX_ORDER 5 diff --git a/klm/lm/model.cc b/klm/lm/model.cc index fc61efee..a40fd2fb 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -37,7 +37,7 @@ template void GenericModel GenericModel::GenericModel(const char *file, const Config &config) { LoadLM(file, config, *this); - // g++ prints warnings unless these are fully initialized. + // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); begin_sentence.length = 1; begin_sentence.words[0] = vocab_.BeginSentence(); @@ -69,8 +69,8 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, const Config &config) { - // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. - util::FilePiece f(backing_.file.release(), file, config.messages); + // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. + util::FilePiece f(backing_.file.release(), file, config.ProgressMessages()); try { std::vector counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. @@ -80,7 +80,7 @@ template void GenericModel 1.0"); std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); if (config.write_mmap) { @@ -95,7 +95,7 @@ template void GenericModel FullScoreReturn GenericModel void GenericModel::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { - // Generate a state from context. + // Generate a state from context. context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); if (context_rend == context_rbegin) { out_state.length = 0; @@ -191,7 +191,7 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel FullScoreReturn GenericModel(out_state.length) - 1; @@ -217,10 +217,10 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) { } } // namespace -/* Ugly optimized function. Produce a score excluding backoff. - * The search goes in increasing order of ngram length. +/* Ugly optimized function. Produce a score excluding backoff. + * The search goes in increasing order of ngram length. * Context goes backward, so context_begin is the word immediately preceeding - * new_word. + * new_word. */ template FullScoreReturn GenericModel::ScoreExceptBackoff( const WordIndex *const context_rbegin, @@ -229,7 +229,7 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel(current_); const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex); for (reader.Rewind(); reader && (current_ != allocated_); ) { @@ -109,7 +109,7 @@ class BackoffMessages { ++reader; break; case 1: - // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. + // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. for (const WordIndex *w = reinterpret_cast(current_); w != reinterpret_cast(current_) + order; ++w, ++extend_out) *extend_out = *w; current_ += entry_size_; break; @@ -126,7 +126,7 @@ class BackoffMessages { break; } } - // Now this is a list of blanks that extend right. + // Now this is a list of blanks that extend right. entry_size_ = sizeof(WordIndex) * order; Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get())); current_ = (uint8_t*)backing_.get(); @@ -153,7 +153,7 @@ class BackoffMessages { private: void FinishedAdding() { Resize(current_ - (uint8_t*)backing_.get()); - // Sort requests in same order as files. + // Sort requests in same order as files. std::sort( util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)), util::SizedIterator(util::SizedProxy(current_, entry_size_)), @@ -220,7 +220,7 @@ class SRISucks { } private: - // This used to be one array. Then I needed to separate it by order for quantization to work. + // This used to be one array. Then I needed to separate it by order for quantization to work. std::vector values_[KENLM_MAX_ORDER - 1]; BackoffMessages messages_[KENLM_MAX_ORDER - 1]; @@ -253,7 +253,7 @@ class FindBlanks { ++counts_.back(); } - // Unigrams wrote one past. + // Unigrams wrote one past. void Cleanup() { --counts_[0]; } @@ -270,15 +270,15 @@ class FindBlanks { SRISucks &sri_; }; -// Phase to actually write n-grams to the trie. +// Phase to actually write n-grams to the trie. template class WriteEntries { public: - WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : + WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), quant_(quant), unigrams_(unigrams), middle_(middle), - longest_(longest), + longest_(longest), bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)), order_(order), sri_(sri) {} @@ -328,7 +328,7 @@ struct Gram { const WordIndex *begin, *end; - // For queue, this is the direction we want. + // For queue, this is the direction we want. bool operator<(const Gram &other) const { return std::lexicographical_compare(other.begin, other.end, begin, end); } @@ -353,7 +353,7 @@ template class BlankManager { been_length_ = length; return; } - // There are blanks to insert starting with order blank. + // There are blanks to insert starting with order blank. unsigned char blank = cur - to + 1; UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); const float *lower_basis; @@ -363,7 +363,7 @@ template class BlankManager { assert(*lower_basis != kBadProb); doing_.MiddleBlank(blank, to, based_on, *lower_basis); *pre = *cur; - // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. + // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. basis_[blank - 1] = kBadProb; } *pre = *cur; @@ -377,7 +377,7 @@ template class BlankManager { unsigned char been_length_; float basis_[KENLM_MAX_ORDER]; - + Doing &doing_; }; @@ -451,7 +451,7 @@ template void TrainProbQuantizer(uint8_t order, uint64_t count, Re } void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { - // Fill unigram probabilities. + // Fill unigram probabilities. try { rewind(file); for (WordIndex i = 0; i < unigram_count; ++i) { @@ -486,7 +486,7 @@ template void BuildTrie(SortedFiles &files, std::ve util::scoped_memory unigrams; MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); FindBlanks finder(counts.size(), reinterpret_cast(unigrams.get()), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder); fixed_counts = finder.Counts(); } unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); @@ -504,7 +504,8 @@ template void BuildTrie(SortedFiles &files, std::ve inputs[i-2].Rewind(); } if (Quant::kTrain) { - util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); + util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), + config.ProgressMessages(), "Quantizing"); for (unsigned char i = 2; i < counts.size(); ++i) { TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } @@ -519,13 +520,13 @@ template void BuildTrie(SortedFiles &files, std::ve for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); } - // Fill entries except unigram probabilities. + // Fill entries except unigram probabilities. { WriteEntries writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); } - // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. + // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. for (unsigned char order = 2; order <= counts.size(); ++order) { const RecordReader &context = contexts[order - 2]; if (context) { @@ -541,13 +542,13 @@ template void BuildTrie(SortedFiles &files, std::ve } /* Set ending offsets so the last entry will be sized properly */ - // Last entry for unigrams was already set. + // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { i->FinishedLoading((i+1)->InsertIndex(), config); } (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); - } + } } template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { @@ -595,7 +596,7 @@ template void TrieSearch::Initializ } else { temporary_prefix = file; } - // At least 1MB sorting memory. + // At least 1MB sorting memory. SortedFiles sorted(config, f, counts, std::max(config.building_memory, 1048576), temporary_prefix, vocab); BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5306850f..a676bdb3 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -27,6 +27,7 @@ libklm_util_a_SOURCES = \ mmap.cc \ murmur_hash.cc \ pool.cc \ + read_compressed.cc \ string_piece.cc \ usage.cc diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 053a850b..0165a7a3 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -87,8 +87,14 @@ template typename Except::template ExceptionTag= 3 +#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0) +#else +#define UTIL_UNLIKELY(x) (x) +#endif + #define UTIL_THROW_IF(Condition, Exception, Modify) do { \ - if (Condition) { \ + if (UTIL_UNLIKELY(Condition)) { \ Exception UTIL_e; \ UTIL_SET_LOCATION(UTIL_e, #Exception, #Condition); \ UTIL_e << Modify; \ diff --git a/klm/util/file.cc b/klm/util/file.cc index 6bf879ac..b9a77cf9 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -15,6 +15,8 @@ #if defined(_WIN32) || defined(_WIN64) #include #include +#include +#include #else #include #endif @@ -48,7 +50,7 @@ int OpenReadOrThrow(const char *name) { int CreateOrThrow(const char *name) { int ret; #if defined(_WIN32) || defined(_WIN64) - UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); + UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR | _O_BINARY, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); #else UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name); #endif @@ -74,16 +76,22 @@ void ResizeOrThrow(int fd, uint64_t to) { #endif } -#ifdef WIN32 -typedef int ssize_t; +std::size_t PartialRead(int fd, void *to, std::size_t amount) { +#if defined(_WIN32) || defined(_WIN64) + amount = min(static_cast(INT_MAX), amount); + int ret = _read(fd, to, amount); +#else + ssize_t ret = read(fd, to, amount); #endif + UTIL_THROW_IF(ret < 0, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); + return static_cast(ret); +} void ReadOrThrow(int fd, void *to_void, std::size_t amount) { uint8_t *to = static_cast(to_void); while (amount) { - ssize_t ret = read(fd, to, amount); - UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); - UTIL_THROW_IF(ret == 0, EndOfFileException, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); + std::size_t ret = PartialRead(fd, to, amount); + UTIL_THROW_IF(ret == 0, EndOfFileException, " in fd " << fd << " but there should be " << amount << " more bytes to read."); amount -= ret; to += ret; } @@ -93,8 +101,7 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { uint8_t *to = static_cast(to_void); std::size_t remaining = amount; while (remaining) { - ssize_t ret = read(fd, to, remaining); - UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << remaining << " from fd " << fd << " failed."); + std::size_t ret = PartialRead(fd, to, remaining); if (!ret) return amount - remaining; remaining -= ret; to += ret; @@ -105,7 +112,11 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { void WriteOrThrow(int fd, const void *data_void, std::size_t size) { const uint8_t *data = static_cast(data_void); while (size) { +#if defined(_WIN32) || defined(_WIN64) + int ret = write(fd, data, min(static_cast(INT_MAX), size)); +#else ssize_t ret = write(fd, data, size); +#endif if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); data += ret; size -= ret; @@ -114,7 +125,7 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) { void WriteOrThrow(FILE *to, const void *data, std::size_t size) { assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); + UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), util::ErrnoException, "Short write; requested size " << size); } void FSyncOrThrow(int fd) { @@ -149,14 +160,15 @@ void SeekEnd(int fd) { std::FILE *FDOpenOrThrow(scoped_fd &file) { std::FILE *ret = fdopen(file.get(), "r+b"); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen"); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get()); file.release(); return ret; } -std::FILE *FOpenOrThrow(const char *path, const char *mode) { - std::FILE *ret; - UTIL_THROW_IF(!(ret = fopen(path, mode)), util::ErrnoException, "Could not fopen " << path << " for " << mode); +std::FILE *FDOpenReadOrThrow(scoped_fd &file) { + std::FILE *ret = fdopen(file.get(), "rb"); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get()); + file.release(); return ret; } diff --git a/klm/util/file.hh b/klm/util/file.hh index 185cb1f3..c24580d6 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -32,8 +32,6 @@ class scoped_fd { return ret; } - operator bool() { return fd_ != -1; } - private: int fd_; @@ -76,8 +74,9 @@ uint64_t SizeFile(int fd); void ResizeOrThrow(int fd, uint64_t to); +std::size_t PartialRead(int fd, void *to, std::size_t size); void ReadOrThrow(int fd, void *to, std::size_t size); -std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount); +std::size_t ReadOrEOF(int fd, void *to_void, std::size_t size); void WriteOrThrow(int fd, const void *data_void, std::size_t size); void WriteOrThrow(FILE *to, const void *data, std::size_t size); @@ -90,8 +89,7 @@ void AdvanceOrThrow(int fd, int64_t off); void SeekEnd(int fd); std::FILE *FDOpenOrThrow(scoped_fd &file); - -std::FILE *FOpenOrThrow(const char *path, const char *mode); +std::FILE *FDOpenReadOrThrow(scoped_fd &file); class TempMaker { public: diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 280f438c..5a208eff 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -14,7 +14,6 @@ #include #include -#include #include #include #include @@ -26,13 +25,6 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() { *this << "Could not parse \"" << value << "\" into a number"; } -#ifdef HAVE_ZLIB -GZException::GZException(gzFile file) { - int num; - *this << gzerror(file, &num) << " from zlib"; -} -#endif // HAVE_ZLIB - // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; @@ -48,19 +40,7 @@ FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std: Initialize(name, show_progress, min_buffer); } -FilePiece::~FilePiece() { -#ifdef HAVE_ZLIB - if (gz_file_) { - // zlib took ownership - file_.release(); - int ret; - if (Z_OK != (ret = gzclose(gz_file_))) { - std::cerr << "could not close file " << file_name_ << " using zlib" << std::endl; - abort(); - } - } -#endif -} +FilePiece::~FilePiece() {} StringPiece FilePiece::ReadLine(char delim) { std::size_t skip = 0; @@ -95,9 +75,6 @@ unsigned long int FilePiece::ReadULong() { } void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer) { -#ifdef HAVE_ZLIB - gz_file_ = NULL; -#endif file_name_ = name; default_map_size_ = page_ * std::max((min_buffer / page_ + 1), 2); @@ -117,10 +94,7 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s } Shift(); // gzip detect. - if ((position_end_ - position_) > 2 && *position_ == 0x1f && static_cast(*(position_ + 1)) == 0x8b) { -#ifndef HAVE_ZLIB - UTIL_THROW(GZException, "Looks like a gzip file but support was not compiled in."); -#endif + if ((position_end_ - position_) >= ReadCompressed::kMagicSize && ReadCompressed::DetectCompressedMagic(position_)) { if (!fallback_to_read_) { at_end_ = false; TransitionToRead(); @@ -197,7 +171,7 @@ void FilePiece::Shift() { if (fallback_to_read_) ReadShift(); for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) { - if (isspace(*last_space_)) break; + if (kSpaces[static_cast(*last_space_)]) break; } } @@ -248,17 +222,14 @@ void FilePiece::TransitionToRead() { position_ = data_.begin(); position_end_ = position_; -#ifdef HAVE_ZLIB - assert(!gz_file_); - gz_file_ = gzdopen(file_.get(), "r"); - UTIL_THROW_IF(!gz_file_, GZException, "zlib failed to open " << file_name_); -#endif + try { + fell_back_.Reset(file_.release()); + } catch (util::Exception &e) { + e << " in file " << file_name_; + throw; + } } -#ifdef WIN32 -typedef int ssize_t; -#endif - void FilePiece::ReadShift() { assert(fallback_to_read_); // Bytes [data_.begin(), position_) have been consumed. @@ -283,7 +254,7 @@ void FilePiece::ReadShift() { position_ = data_.begin(); position_end_ = position_ + valid_length; } else { - size_t moving = position_end_ - position_; + std::size_t moving = position_end_ - position_; memmove(data_.get(), position_, moving); position_ = data_.begin(); position_end_ = position_ + moving; @@ -291,20 +262,9 @@ void FilePiece::ReadShift() { } } - ssize_t read_return; -#ifdef HAVE_ZLIB - read_return = gzread(gz_file_, static_cast(data_.get()) + already_read, default_map_size_ - already_read); - if (read_return == -1) throw GZException(gz_file_); - if (total_size_ != kBadSize) { - // Just get the position, don't actually seek. Apparently this is how you do it. . . - off_t ret = lseek(file_.get(), 0, SEEK_CUR); - if (ret != -1) progress_.Set(ret); - } -#else - read_return = read(file_.get(), static_cast(data_.get()) + already_read, default_map_size_ - already_read); - UTIL_THROW_IF(read_return == -1, ErrnoException, "read failed"); - progress_.Set(mapped_offset_); -#endif + std::size_t read_return = fell_back_.Read(static_cast(data_.get()) + already_read, default_map_size_ - already_read); + progress_.Set(fell_back_.RawAmount()); + if (read_return == 0) { at_end_ = true; } diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index af93d8aa..39bd1581 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -4,8 +4,8 @@ #include "util/ersatz_progress.hh" #include "util/exception.hh" #include "util/file.hh" -#include "util/have.hh" #include "util/mmap.hh" +#include "util/read_compressed.hh" #include "util/string_piece.hh" #include @@ -13,10 +13,6 @@ #include -#ifdef HAVE_ZLIB -#include -#endif - namespace util { class ParseNumberException : public Exception { @@ -25,28 +21,19 @@ class ParseNumberException : public Exception { ~ParseNumberException() throw() {} }; -class GZException : public Exception { - public: -#ifdef HAVE_ZLIB - explicit GZException(gzFile file); -#endif - GZException() throw() {} - ~GZException() throw() {} -}; - extern const bool kSpaces[256]; -// Memory backing the returned StringPiece may vanish on the next call. +// Memory backing the returned StringPiece may vanish on the next call. class FilePiece { public: - // 32 MB default. - explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); - // Takes ownership of fd. name is used for messages. - explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); + // 1 MB default. + explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); + // Takes ownership of fd. name is used for messages. + explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); ~FilePiece(); - - char get() { + + char get() { if (position_ == position_end_) { Shift(); if (at_end_) throw EndOfFileException(); @@ -54,14 +41,14 @@ class FilePiece { return *(position_++); } - // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace(). + // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace(). StringPiece ReadDelimited(const bool *delim = kSpaces) { SkipSpaces(delim); return Consume(FindDelimiterOrEOF(delim)); } // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. - // It is similar to getline in that way. + // It is similar to getline in that way. StringPiece ReadLine(char delim = '\n'); float ReadFloat(); @@ -69,7 +56,7 @@ class FilePiece { long int ReadLong(); unsigned long int ReadULong(); - // Skip spaces defined by isspace. + // Skip spaces defined by isspace. void SkipSpaces(const bool *delim = kSpaces) { for (; ; ++position_) { if (position_ == position_end_) Shift(); @@ -82,7 +69,7 @@ class FilePiece { } const std::string &FileName() const { return file_name_; } - + private: void Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer); @@ -122,9 +109,7 @@ class FilePiece { std::string file_name_; -#ifdef HAVE_ZLIB - gzFile gz_file_; -#endif // HAVE_ZLIB + ReadCompressed fell_back_; }; } // namespace util diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index f912e18a..e79ece7a 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -38,7 +38,7 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) { BOOST_CHECK_THROW(test.get(), EndOfFileException); } -#ifndef __APPLE__ +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__) /* Apple isn't happy with the popen, fileno, dup. And I don't want to * reimplement popen. This is an issue with the test. */ @@ -65,7 +65,7 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) { BOOST_CHECK_THROW(test.get(), EndOfFileException); BOOST_REQUIRE(!pclose(catter)); } -#endif // __APPLE__ +#endif #ifdef HAVE_ZLIB diff --git a/klm/util/have.hh b/klm/util/have.hh index b8181e99..1523c0c5 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -2,22 +2,12 @@ #ifndef UTIL_HAVE__ #define UTIL_HAVE__ -#ifndef HAVE_ZLIB -#if !defined(_WIN32) && !defined(_WIN64) -#define HAVE_ZLIB -#endif -#endif - #ifndef HAVE_ICU //#define HAVE_ICU #endif #ifndef HAVE_BOOST -#define HAVE_BOOST -#endif - -#ifndef HAVE_THREADS -//#define HAVE_THREADS +//#define HAVE_BOOST #endif #endif // UTIL_HAVE__ diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh index cf3d8432..1b43ddcf 100644 --- a/klm/util/joint_sort.hh +++ b/klm/util/joint_sort.hh @@ -60,7 +60,7 @@ template class JointProxy { JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {} JointProxy(const JointProxy &other) : inner_(other.inner_) {} - operator const value_type() const { + operator value_type() const { value_type ret; ret.key = *inner_.key_; ret.value = *inner_.value_; @@ -121,7 +121,7 @@ template class LessWrapper : public std::binary_functi template class PairedIterator : public ProxyIterator > { public: - PairedIterator(const KeyIter &key, const ValueIter &value) : + PairedIterator(const KeyIter &key, const ValueIter &value) : ProxyIterator >(detail::JointProxy(key, value)) {} }; diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc new file mode 100644 index 00000000..4ec94c4e --- /dev/null +++ b/klm/util/read_compressed.cc @@ -0,0 +1,403 @@ +#include "util/read_compressed.hh" + +#include "util/file.hh" +#include "util/have.hh" +#include "util/scoped.hh" + +#include +#include + +#include +#include +#include +#include + +#ifdef HAVE_ZLIB +#include +#endif + +#ifdef HAVE_BZLIB +#include +#endif + +#ifdef HAVE_XZLIB +#include +#endif + +namespace util { + +CompressedException::CompressedException() throw() {} +CompressedException::~CompressedException() throw() {} + +GZException::GZException() throw() {} +GZException::~GZException() throw() {} + +BZException::BZException() throw() {} +BZException::~BZException() throw() {} + +XZException::XZException() throw() {} +XZException::~XZException() throw() {} + +class ReadBase { + public: + virtual ~ReadBase() {} + + virtual std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) = 0; + + protected: + static void ReplaceThis(ReadBase *with, ReadCompressed &thunk) { + thunk.internal_.reset(with); + } + + static uint64_t &ReadCount(ReadCompressed &thunk) { + return thunk.raw_amount_; + } +}; + +namespace { + +// Completed file that other classes can thunk to. +class Complete : public ReadBase { + public: + std::size_t Read(void *, std::size_t, ReadCompressed &) { + return 0; + } +}; + +class Uncompressed : public ReadBase { + public: + explicit Uncompressed(int fd) : fd_(fd) {} + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + std::size_t got = PartialRead(fd_.get(), to, amount); + ReadCount(thunk) += got; + return got; + } + + private: + scoped_fd fd_; +}; + +class UncompressedWithHeader : public ReadBase { + public: + UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) { + assert(already_size); + buf_.reset(malloc(already_size)); + if (!buf_.get()) throw std::bad_alloc(); + memcpy(buf_.get(), already_data, already_size); + remain_ = static_cast(buf_.get()); + end_ = remain_ + already_size; + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + assert(buf_.get()); + std::size_t sending = std::min(amount, end_ - remain_); + memcpy(to, remain_, sending); + remain_ += sending; + if (remain_ == end_) { + ReplaceThis(new Uncompressed(fd_.release()), thunk); + } + return sending; + } + + private: + scoped_malloc buf_; + uint8_t *remain_; + uint8_t *end_; + + scoped_fd fd_; +}; + +#ifdef HAVE_ZLIB +class GZip : public ReadBase { + private: + static const std::size_t kInputBuffer = 16384; + public: + GZip(int fd, void *already_data, std::size_t already_size) + : file_(fd), in_buffer_(malloc(kInputBuffer)) { + if (!in_buffer_.get()) throw std::bad_alloc(); + assert(already_size < kInputBuffer); + if (already_size) { + memcpy(in_buffer_.get(), already_data, already_size); + stream_.next_in = static_cast(in_buffer_.get()); + stream_.avail_in = already_size; + stream_.avail_in += ReadOrEOF(file_.get(), static_cast(in_buffer_.get()) + already_size, kInputBuffer - already_size); + } else { + stream_.avail_in = 0; + } + stream_.zalloc = Z_NULL; + stream_.zfree = Z_NULL; + stream_.opaque = Z_NULL; + stream_.msg = NULL; + // 32 for zlib and gzip decoding with automatic header detection. + // 15 for maximum window size. + UTIL_THROW_IF(Z_OK != inflateInit2(&stream_, 32 + 15), GZException, "Failed to initialize zlib."); + } + + ~GZip() { + if (Z_OK != inflateEnd(&stream_)) { + std::cerr << "zlib could not close properly." << std::endl; + abort(); + } + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (amount == 0) return 0; + stream_.next_out = static_cast(to); + stream_.avail_out = std::min(std::numeric_limits::max(), amount); + do { + if (!stream_.avail_in) ReadInput(thunk); + int result = inflate(&stream_, 0); + switch (result) { + case Z_OK: + break; + case Z_STREAM_END: + { + std::size_t ret = static_cast(stream_.next_out) - static_cast(to); + ReplaceThis(new Complete(), thunk); + return ret; + } + case Z_ERRNO: + UTIL_THROW(ErrnoException, "zlib error"); + default: + UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result); + } + } while (stream_.next_out == to); + return static_cast(stream_.next_out) - static_cast(to); + } + + private: + void ReadInput(ReadCompressed &thunk) { + assert(!stream_.avail_in); + stream_.next_in = static_cast(in_buffer_.get()); + stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); + ReadCount(thunk) += stream_.avail_in; + } + + scoped_fd file_; + scoped_malloc in_buffer_; + z_stream stream_; +}; +#endif // HAVE_ZLIB + +#ifdef HAVE_BZLIB +class BZip : public ReadBase { + public: + explicit BZip(int fd, void *already_data, std::size_t already_size) { + scoped_fd hold(fd); + closer_.reset(FDOpenReadOrThrow(hold)); + int bzerror = BZ_OK; + file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size); + switch (bzerror) { + case BZ_OK: + return; + case BZ_CONFIG_ERROR: + UTIL_THROW(BZException, "Looks like bzip2 was miscompiled."); + case BZ_PARAM_ERROR: + UTIL_THROW(BZException, "Parameter error"); + case BZ_IO_ERROR: + UTIL_THROW(BZException, "IO error reading file"); + case BZ_MEM_ERROR: + throw std::bad_alloc(); + } + } + + ~BZip() { + int bzerror = BZ_OK; + BZ2_bzReadClose(&bzerror, file_); + if (bzerror != BZ_OK) { + std::cerr << "bz2 readclose error" << std::endl; + abort(); + } + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + int bzerror = BZ_OK; + int ret = BZ2_bzRead(&bzerror, file_, to, std::min(static_cast(INT_MAX), amount)); + long pos; + switch (bzerror) { + case BZ_STREAM_END: + pos = ftell(closer_.get()); + if (pos != -1) ReadCount(thunk) = pos; + ReplaceThis(new Complete(), thunk); + return ret; + case BZ_OK: + pos = ftell(closer_.get()); + if (pos != -1) ReadCount(thunk) = pos; + return ret; + default: + UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror); + } + } + + private: + scoped_FILE closer_; + BZFILE *file_; +}; +#endif // HAVE_BZLIB + +#ifdef HAVE_XZLIB +class XZip : public ReadBase { + private: + static const std::size_t kInputBuffer = 16384; + public: + XZip(int fd, void *already_data, std::size_t already_size) + : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) { + if (!in_buffer_.get()) throw std::bad_alloc(); + assert(already_size < kInputBuffer); + if (already_size) { + memcpy(in_buffer_.get(), already_data, already_size); + stream_.next_in = static_cast(in_buffer_.get()); + stream_.avail_in = already_size; + stream_.avail_in += ReadOrEOF(file_.get(), static_cast(in_buffer_.get()) + already_size, kInputBuffer - already_size); + } else { + stream_.avail_in = 0; + } + stream_.allocator = NULL; + lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED); + switch (ret) { + case LZMA_OK: + break; + case LZMA_MEM_ERROR: + UTIL_THROW(ErrnoException, "xz open error"); + default: + UTIL_THROW(XZException, "xz error code " << ret); + } + } + + ~XZip() { + lzma_end(&stream_); + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (amount == 0) return 0; + stream_.next_out = static_cast(to); + stream_.avail_out = amount; + do { + if (!stream_.avail_in) ReadInput(thunk); + lzma_ret status = lzma_code(&stream_, action_); + switch (status) { + case LZMA_OK: + break; + case LZMA_STREAM_END: + UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet."); + { + std::size_t ret = static_cast(stream_.next_out) - static_cast(to); + ReplaceThis(new Complete(), thunk); + return ret; + } + case LZMA_MEM_ERROR: + throw std::bad_alloc(); + case LZMA_FORMAT_ERROR: + UTIL_THROW(XZException, "xzlib says file format not recognized"); + case LZMA_OPTIONS_ERROR: + UTIL_THROW(XZException, "xzlib says unsupported compression options"); + case LZMA_DATA_ERROR: + UTIL_THROW(XZException, "xzlib says this file is corrupt"); + case LZMA_BUF_ERROR: + UTIL_THROW(XZException, "xzlib says unexpected end of input"); + default: + UTIL_THROW(XZException, "unrecognized xzlib error " << status); + } + } while (stream_.next_out == to); + return static_cast(stream_.next_out) - static_cast(to); + } + + private: + void ReadInput(ReadCompressed &thunk) { + assert(!stream_.avail_in); + stream_.next_in = static_cast(in_buffer_.get()); + stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); + if (!stream_.avail_in) action_ = LZMA_FINISH; + ReadCount(thunk) += stream_.avail_in; + } + + scoped_fd file_; + scoped_malloc in_buffer_; + lzma_stream stream_; + + lzma_action action_; +}; +#endif // HAVE_XZLIB + +enum MagicResult { + UNKNOWN, GZIP, BZIP, XZIP +}; + +MagicResult DetectMagic(const void *from_void) { + const uint8_t *header = static_cast(from_void); + if (header[0] == 0x1f && header[1] == 0x8b) { + return GZIP; + } + if (header[0] == 'B' && header[1] == 'Z') { + return BZIP; + } + const uint8_t xzmagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 }; + if (!memcmp(header, xzmagic, 6)) { + return XZIP; + } + return UNKNOWN; +} + +ReadBase *ReadFactory(int fd, uint64_t &raw_amount) { + scoped_fd hold(fd); + unsigned char header[ReadCompressed::kMagicSize]; + raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize); + if (!raw_amount) + return new Uncompressed(hold.release()); + if (raw_amount != ReadCompressed::kMagicSize) + return new UncompressedWithHeader(hold.release(), header, raw_amount); + switch (DetectMagic(header)) { + case GZIP: +#ifdef HAVE_ZLIB + return new GZip(hold.release(), header, ReadCompressed::kMagicSize); +#else + UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in."); +#endif + case BZIP: +#ifdef HAVE_BZLIB + return new BZip(hold.release(), header, ReadCompressed::kMagicSize); +#else + UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in."); +#endif + case XZIP: +#ifdef HAVE_XZLIB + return new XZip(hold.release(), header, ReadCompressed::kMagicSize); +#else + UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in."); +#endif + case UNKNOWN: + break; + } + try { + AdvanceOrThrow(fd, -ReadCompressed::kMagicSize); + } catch (const util::ErrnoException &e) { + return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize); + } + return new Uncompressed(hold.release()); +} + +} // namespace + +bool ReadCompressed::DetectCompressedMagic(const void *from_void) { + return DetectMagic(from_void) != UNKNOWN; +} + +ReadCompressed::ReadCompressed(int fd) { + Reset(fd); +} + +ReadCompressed::ReadCompressed() {} + +ReadCompressed::~ReadCompressed() {} + +void ReadCompressed::Reset(int fd) { + internal_.reset(); + internal_.reset(ReadFactory(fd, raw_amount_)); +} + +std::size_t ReadCompressed::Read(void *to, std::size_t amount) { + return internal_->Read(to, amount, *this); +} + +} // namespace util diff --git a/klm/util/read_compressed.hh b/klm/util/read_compressed.hh new file mode 100644 index 00000000..83ca9fb2 --- /dev/null +++ b/klm/util/read_compressed.hh @@ -0,0 +1,74 @@ +#ifndef UTIL_READ_COMPRESSED__ +#define UTIL_READ_COMPRESSED__ + +#include "util/exception.hh" +#include "util/scoped.hh" + +#include + +#include + +namespace util { + +class CompressedException : public Exception { + public: + CompressedException() throw(); + virtual ~CompressedException() throw(); +}; + +class GZException : public CompressedException { + public: + GZException() throw(); + ~GZException() throw(); +}; + +class BZException : public CompressedException { + public: + BZException() throw(); + ~BZException() throw(); +}; + +class XZException : public CompressedException { + public: + XZException() throw(); + ~XZException() throw(); +}; + +class ReadBase; + +class ReadCompressed { + public: + static const std::size_t kMagicSize = 6; + // Must have at least kMagicSize bytes. + static bool DetectCompressedMagic(const void *from); + + // Takes ownership of fd. + explicit ReadCompressed(int fd); + + // Must call Reset later. + ReadCompressed(); + + ~ReadCompressed(); + + // Takes ownership of fd. + void Reset(int fd); + + std::size_t Read(void *to, std::size_t amount); + + uint64_t RawAmount() const { return raw_amount_; } + + private: + friend class ReadBase; + + scoped_ptr internal_; + + uint64_t raw_amount_; + + // No copying. + ReadCompressed(const ReadCompressed &); + void operator=(const ReadCompressed &); +}; + +} // namespace util + +#endif // UTIL_READ_COMPRESSED__ diff --git a/klm/util/read_compressed_test.cc b/klm/util/read_compressed_test.cc new file mode 100644 index 00000000..6fd97e5e --- /dev/null +++ b/klm/util/read_compressed_test.cc @@ -0,0 +1,94 @@ +#include "util/read_compressed.hh" + +#include "util/file.hh" +#include "util/have.hh" + +#define BOOST_TEST_MODULE ReadCompressedTest +#include +#include + +#include +#include + +#include + +namespace util { +namespace { + +void ReadLoop(ReadCompressed &reader, void *to_void, std::size_t amount) { + uint8_t *to = static_cast(to_void); + while (amount) { + std::size_t ret = reader.Read(to, amount); + BOOST_REQUIRE(ret); + to += ret; + amount -= ret; + } +} + +void TestRandom(const char *compressor) { + const uint32_t kSize4 = 100000 / 4; + char name[] = "tempXXXXXX"; + + // Write test file. + { + scoped_fd original(mkstemp(name)); + BOOST_REQUIRE(original.get() > 0); + for (uint32_t i = 0; i < kSize4; ++i) { + WriteOrThrow(original.get(), &i, sizeof(uint32_t)); + } + } + + char gzname[] = "tempXXXXXX"; + scoped_fd gzipped(mkstemp(gzname)); + + std::string command(compressor); +#ifdef __CYGWIN__ + command += ".exe"; +#endif + command += " <\""; + command += name; + command += "\" >\""; + command += gzname; + command += "\""; + BOOST_REQUIRE_EQUAL(0, system(command.c_str())); + + BOOST_CHECK_EQUAL(0, unlink(name)); + BOOST_CHECK_EQUAL(0, unlink(gzname)); + + ReadCompressed reader(gzipped.release()); + for (uint32_t i = 0; i < kSize4; ++i) { + uint32_t got; + ReadLoop(reader, &got, sizeof(uint32_t)); + BOOST_CHECK_EQUAL(i, got); + } + + char ignored; + BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); + // Test double EOF call. + BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); +} + +BOOST_AUTO_TEST_CASE(Uncompressed) { + TestRandom("cat"); +} + +#ifdef HAVE_ZLIB +BOOST_AUTO_TEST_CASE(ReadGZ) { + TestRandom("gzip"); +} +#endif // HAVE_ZLIB + +#ifdef HAVE_BZLIB +BOOST_AUTO_TEST_CASE(ReadBZ) { + TestRandom("bzip2"); +} +#endif // HAVE_BZLIB + +#ifdef HAVE_XZLIB +BOOST_AUTO_TEST_CASE(ReadXZ) { + TestRandom("xz"); +} +#endif + +} // namespace +} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index 93e2e817..d62c6df1 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -1,40 +1,13 @@ #ifndef UTIL_SCOPED__ #define UTIL_SCOPED__ +/* Other scoped objects in the style of scoped_ptr. */ #include "util/exception.hh" - -/* Other scoped objects in the style of scoped_ptr. */ #include #include namespace util { -template class scoped_thing { - public: - explicit scoped_thing(T *c = static_cast(0)) : c_(c) {} - - ~scoped_thing() { if (c_) Free(c_); } - - void reset(T *c) { - if (c_) Free(c_); - c_ = c; - } - - T &operator*() { return *c_; } - const T&operator*() const { return *c_; } - T &operator->() { return *c_; } - const T&operator->() const { return *c_; } - - T *get() { return c_; } - const T *get() const { return c_; } - - private: - T *c_; - - scoped_thing(const scoped_thing &); - scoped_thing &operator=(const scoped_thing &); -}; - class scoped_malloc { public: scoped_malloc() : p_(NULL) {} @@ -77,9 +50,6 @@ template class scoped_array { T &operator*() { return *c_; } const T&operator*() const { return *c_; } - T &operator->() { return *c_; } - const T&operator->() const { return *c_; } - T &operator[](std::size_t idx) { return c_[idx]; } const T &operator[](std::size_t idx) const { return c_[idx]; } @@ -90,6 +60,39 @@ template class scoped_array { private: T *c_; + + scoped_array(const scoped_array &); + void operator=(const scoped_array &); +}; + +template class scoped_ptr { + public: + explicit scoped_ptr(T *content = NULL) : c_(content) {} + + ~scoped_ptr() { delete c_; } + + T *get() { return c_; } + const T* get() const { return c_; } + + T &operator*() { return *c_; } + const T&operator*() const { return *c_; } + + T *operator->() { return c_; } + const T*operator->() const { return c_; } + + T &operator[](std::size_t idx) { return c_[idx]; } + const T &operator[](std::size_t idx) const { return c_[idx]; } + + void reset(T *to = NULL) { + scoped_ptr other(c_); + c_ = to; + } + + private: + T *c_; + + scoped_ptr(const scoped_ptr &); + void operator=(const scoped_ptr &); }; } // namespace util diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index be6a643d..51481646 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -1,6 +1,6 @@ /* If you use ICU in your program, then compile with -DHAVE_ICU -licui18n. If * you don't use ICU, then this will use the Google implementation from Chrome. - * This has been modified from the original version to let you choose. + * This has been modified from the original version to let you choose. */ // Copyright 2008, Google Inc. @@ -62,9 +62,9 @@ #include #include -// Old versions of ICU don't define operator== and operator!=. +// Old versions of ICU don't define operator== and operator!=. #if (U_ICU_VERSION_MAJOR_NUM < 4) || ((U_ICU_VERSION_MAJOR_NUM == 4) && (U_ICU_VERSION_MINOR_NUM < 4)) -#warning You are using an old version of ICU. Consider upgrading to ICU >= 4.6. +#warning You are using an old version of ICU. Consider upgrading to ICU >= 4.6. inline bool operator==(const StringPiece& x, const StringPiece& y) { if (x.size() != y.size()) return false; @@ -274,15 +274,28 @@ struct StringPieceCompatibleEquals : public std::binary_function typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 + std::string temp(key.data(), key.size()); + return t.find(temp); +#else return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif } + template typename T::iterator FindStringPiece(T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 + std::string temp(key.data(), key.size()); + return t.find(temp); +#else return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif } #endif #ifdef HAVE_ICU U_NAMESPACE_END +using U_NAMESPACE_QUALIFIER StringPiece; #endif + #endif // BASE_STRING_PIECE_H__ diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index 4a7f5460..a588c3fc 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -20,6 +20,7 @@ class OutOfTokens : public Exception { class SingleCharacter { public: + SingleCharacter() {} explicit SingleCharacter(char delim) : delim_(delim) {} StringPiece Find(const StringPiece &in) const { @@ -32,6 +33,8 @@ class SingleCharacter { class MultiCharacter { public: + MultiCharacter() {} + explicit MultiCharacter(const StringPiece &delimiter) : delimiter_(delimiter) {} StringPiece Find(const StringPiece &in) const { @@ -44,6 +47,7 @@ class MultiCharacter { class AnyCharacter { public: + AnyCharacter() {} explicit AnyCharacter(const StringPiece &chars) : chars_(chars) {} StringPiece Find(const StringPiece &in) const { @@ -56,6 +60,8 @@ class AnyCharacter { class AnyCharacterLast { public: + AnyCharacterLast() {} + explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {} StringPiece Find(const StringPiece &in) const { @@ -81,8 +87,8 @@ template class TokenIter : public boost::it return current_.data() != 0; } - static TokenIter end() { - return TokenIter(); + static TokenIter end() { + return TokenIter(); } private: @@ -100,8 +106,8 @@ template class TokenIter : public boost::it } while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false. } - bool equal(const TokenIter &other) const { - return after_.data() == other.after_.data(); + bool equal(const TokenIter &other) const { + return current_.data() == other.current_.data(); } const StringPiece &dereference() const { -- cgit v1.2.3