From 9b99cb844e3e379b557ff8578df27893ce147f1a Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 14 Oct 2012 10:46:34 +0100 Subject: Update to faster but less cute search --- klm/search/Jamfile | 2 +- klm/search/edge.hh | 31 +++--------------- klm/search/edge_generator.cc | 53 +++++++++++++----------------- klm/search/edge_generator.hh | 24 ++++++++------ klm/search/edge_queue.cc | 25 +++++++++++++++ klm/search/edge_queue.hh | 73 ++++++++++++++++++++++++++++++++++++++++++ klm/search/final.hh | 11 +++---- klm/search/note.hh | 12 +++++++ klm/search/rule.cc | 32 +++++++++--------- klm/search/rule.hh | 31 ++---------------- klm/search/vertex.hh | 45 +++++++++++--------------- klm/search/vertex_generator.cc | 32 +++++------------- klm/search/vertex_generator.hh | 35 +++++++------------- 13 files changed, 213 insertions(+), 193 deletions(-) create mode 100644 klm/search/edge_queue.cc create mode 100644 klm/search/edge_queue.hh create mode 100644 klm/search/note.hh (limited to 'klm/search') diff --git a/klm/search/Jamfile b/klm/search/Jamfile index ac47c249..e8b14363 100644 --- a/klm/search/Jamfile +++ b/klm/search/Jamfile @@ -1,4 +1,4 @@ -lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil : : : .. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. 
; import testing ; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 4d2a5cbf..77ab0ade 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -11,33 +11,6 @@ namespace search { -class Edge { - public: - Edge() { - end_to_ = to_; - } - - Rule &InitRule() { return rule_; } - - void Add(Vertex &vertex) { - assert(end_to_ - to_ < kMaxArity); - *(end_to_++) = &vertex; - } - - const Vertex &GetVertex(std::size_t index) const { - return *to_[index]; - } - - const Rule &GetRule() const { return rule_; } - - private: - // Rule and pointers to rule arguments. - Rule rule_; - - Vertex *to_[kMaxArity]; - Vertex **end_to_; -}; - struct PartialEdge { Score score; // Terminals @@ -45,6 +18,10 @@ struct PartialEdge { // Non-terminals PartialVertex nt[kMaxArity]; + const lm::ngram::ChartState &CompletedState() const { + return between[0]; + } + bool operator<(const PartialEdge &other) const { return score < other.score; } diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index d135899a..56239dfb 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -10,28 +10,15 @@ namespace search { -bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) { - from_ = &edge; - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { - if (edge.GetVertex(i).RootPartial().Empty()) return false; - } - PartialEdge &root = *parent.MallocPartialEdge(); - root.score = GetRule().Bound(); - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { +EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { +/* for (unsigned char i = 0; i < edge.Arity(); ++i) { root.nt[i] = edge.GetVertex(i).RootPartial(); - root.score += root.nt[i].Bound(); } - for (unsigned int i = GetRule().Arity(); i < 2; ++i) { + for (unsigned char i = edge.Arity(); i < 2; ++i) { root.nt[i] = kBlankPartialVertex; - } - for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) { - root.between[i] = GetRule().Lexical(i); - } 
- // wtf no clear method? - generate_ = Generate(); + }*/ generate_.push(&root); - top_ = root.score; - return true; + top_score_ = root.score; } namespace { @@ -78,13 +65,13 @@ template float FastScore(const Context &context, unsigned c } // namespace -template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent) { +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool) { assert(!generate_.empty()); PartialEdge &top = *generate_.top(); generate_.pop(); unsigned int victim = 0; unsigned char lowest_length = 255; - for (unsigned int i = 0; i != GetRule().Arity(); ++i) { + for (unsigned char i = 0; i != arity_; ++i) { if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { lowest_length = top.nt[i].Length(); victim = i; @@ -92,21 +79,21 @@ template bool EdgeGenerator::Pop(Context &context, VertexGe } if (lowest_length == 255) { // All states report complete. - top.between[0].right = top.between[GetRule().Arity()].right; - parent.NewHypothesis(top.between[0], *from_, top); - top_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return !generate_.empty(); + top.between[0].right = top.between[arity_].right; + // Now top.between[0] is the full edge state. + top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; + return &top; } unsigned int stay = !victim; - PartialEdge &continuation = *parent.MallocPartialEdge(); + PartialEdge &continuation = *static_cast(partial_edge_pool.malloc()); float old_bound = top.nt[victim].Bound(); // The alternate's score will change because alternate.nt[victim] changes. bool split = top.nt[victim].Split(continuation.nt[victim]); // top is now the alternate. continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation); + continuation.score = FastScore(context, victim, arity_, top, continuation); // TODO: dedupe? 
generate_.push(&continuation); @@ -116,14 +103,18 @@ template bool EdgeGenerator::Pop(Context &context, VertexGe // TODO: dedupe? generate_.push(&top); } else { - parent.FreePartialEdge(&top); + partial_edge_pool.free(&top); } - top_ = generate_.top()->score; - return true; + top_score_ = generate_.top()->score; + return NULL; } -template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent); -template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index e306dc61..875ccc5e 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,9 @@ #define SEARCH_EDGE_GENERATOR__ #include "search/edge.hh" +#include "search/note.hh" +#include #include #include @@ -28,26 +30,28 @@ struct PartialEdgePointerLess : std::binary_function bool Pop(Context &context, VertexGenerator &parent); + Note GetNote() const { + return note_; + } + + // Pop. If there's a complete hypothesis, return it. Otherwise return NULL. 
+ template PartialEdge *Pop(Context &context, boost::pool<> &partial_edge_pool); private: - const Rule &GetRule() const { - return from_->GetRule(); - } + Score top_score_; - Score top_; + unsigned char arity_; typedef std::priority_queue, PartialEdgePointerLess> Generate; Generate generate_; - Edge *from_; + Note note_; }; } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc new file mode 100644 index 00000000..e3ae6ebf --- /dev/null +++ b/klm/search/edge_queue.cc @@ -0,0 +1,25 @@ +#include "search/edge_queue.hh" + +#include "lm/left.hh" +#include "search/context.hh" + +#include + +namespace search { + +EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { + take_ = static_cast(partial_edge_pool_.malloc()); +} + +/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { + // Ignore empty edges. + for (unsigned char i = 0; i < edge.Arity(); ++i) { + PartialVertex root(edge.GetVertex(i).RootPartial()); + if (root.Empty()) return; + total_score += root.Bound(); + } + PartialEdge &allocated = *static_cast(partial_edge_pool_.malloc()); + allocated.score = total_score; +}*/ + +} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh new file mode 100644 index 00000000..187eaed7 --- /dev/null +++ b/klm/search/edge_queue.hh @@ -0,0 +1,73 @@ +#ifndef SEARCH_EDGE_QUEUE__ +#define SEARCH_EDGE_QUEUE__ + +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/note.hh" + +#include +#include + +#include + +namespace search { + +template class Context; + +class EdgeQueue { + public: + explicit EdgeQueue(unsigned int pop_limit_hint); + + PartialEdge &InitializeEdge() { + return *take_; + } + + void AddEdge(unsigned char arity, Note note) { + generate_.push(edge_pool_.construct(*take_, arity, note)); + take_ = static_cast(partial_edge_pool_.malloc()); + } + + bool Empty() const { return generate_.empty(); } + + /* 
Generate hypotheses and send them to output. Normally, output is a + * VertexGenerator, but the decoder may want to route edges to different + * vertices i.e. if they have different LHS non-terminal labels. + */ + template void Search(Context &context, Output &output) { + int to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + EdgeGenerator *top = generate_.top(); + generate_.pop(); + PartialEdge *ret = top->Pop(context, partial_edge_pool_); + if (ret) { + output.NewHypothesis(*ret, top->GetNote()); + --to_pop; + if (top->TopScore() != -kScoreInf) { + generate_.push(top); + } + } else { + generate_.push(top); + } + } + output.FinishedSearch(); + } + + private: + boost::object_pool edge_pool_; + + struct LessByTopScore : public std::binary_function { + bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { + return first->TopScore() < second->TopScore(); + } + }; + + typedef std::priority_queue, LessByTopScore> Generate; + Generate generate_; + + boost::pool<> partial_edge_pool_; + + PartialEdge *take_; +}; + +} // namespace search +#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 823b8c1a..1b3092ac 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -2,35 +2,34 @@ #define SEARCH_FINAL__ #include "search/arity.hh" +#include "search/note.hh" #include "search/types.hh" #include namespace search { -class Edge; - class Final { public: typedef boost::array ChildArray; - void Reset(Score bound, const Edge &from, const Final &left, const Final &right) { + void Reset(Score bound, Note note, const Final &left, const Final &right) { bound_ = bound; - from_ = &from; + note_ = note; children_[0] = &left; children_[1] = &right; } const ChildArray &Children() const { return children_; } - const Edge &From() const { return *from_; } + Note GetNote() const { return note_; } Score Bound() const { return bound_; } private: Score bound_; - const Edge *from_; + Note note_; ChildArray 
children_; }; diff --git a/klm/search/note.hh b/klm/search/note.hh new file mode 100644 index 00000000..50bed06e --- /dev/null +++ b/klm/search/note.hh @@ -0,0 +1,12 @@ +#ifndef SEARCH_NOTE__ +#define SEARCH_NOTE__ + +namespace search { + +union Note { + const void *vp; +}; + +} // namespace search + +#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 0a941527..5b00207e 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,35 +9,35 @@ namespace search { -template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos) { - additive_ = additive; - Score lm_score = 0.0; - lexical_.clear(); - const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - +template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing) { + unsigned int oov_count = 0; + float prob = 0.0; + const Model &model = context.LanguageModel(); + const lm::WordIndex oov = model.GetVocabulary().NotFound(); for (std::vector::const_iterator word = words.begin(); ; ++word) { - lexical_.resize(lexical_.size() + 1); - lm::ngram::RuleScore scorer(context.LanguageModel(), lexical_.back()); + lm::ngram::RuleScore scorer(model, *(writing++)); // TODO: optimize if (prepend_bos && (word == words.begin())) { scorer.BeginSentence(); } for (; ; ++word) { if (word == words.end()) { - lm_score += scorer.Finish(); - bound_ = additive_ + context.GetWeights().LM() * lm_score; - arity_ = lexical_.size() - 1; - return; + prob += scorer.Finish(); + return static_cast(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); } if (*word == kNonTerminal) break; - if (*word == oov) additive_ += context.GetWeights().OOV(); + if (*word == oov) ++oov_count; scorer.Terminal(*word); } - lm_score += scorer.Finish(); + prob += scorer.Finish(); } } -template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool 
prepend_bos); -template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 920c64a7..0ce2794d 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -3,44 +3,17 @@ #include "lm/left.hh" #include "lm/word_index.hh" -#include "search/arity.hh" #include "search/types.hh" -#include - -#include #include namespace search { template class Context; -class Rule { - public: - Rule() : arity_(0) {} - - static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - - // Use kNonTerminal for non-terminals. 
- template void Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); - - Score Bound() const { return bound_; } - - Score Additive() const { return additive_; } - - unsigned int Arity() const { return arity_; } - - const lm::ngram::ChartState &Lexical(unsigned int index) const { - return lexical_[index]; - } - - private: - Score bound_, additive_; - - unsigned int arity_; +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - std::vector lexical_; -}; +template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *state_out); } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 7ef29efc..e1a9ad11 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,8 +16,6 @@ namespace search { class ContextBase; -class Edge; - class VertexNode { public: VertexNode() : end_(NULL) {} @@ -103,6 +101,10 @@ class PartialVertex { unsigned char Length() const { return back_->Length(); } + bool HasAlternative() const { + return index_ + 1 < back_->Size(); + } + // Split into continuation and alternative, rendering this the alternative. 
bool Split(PartialVertex &continuation) { assert(!Complete()); @@ -128,35 +130,26 @@ extern PartialVertex kBlankPartialVertex; class Vertex { public: - Vertex() -#ifdef DEBUG - : finished_adding_(false) -#endif - {} - - void Add(Edge &edge) { -#ifdef DEBUG - assert(!finished_adding_); -#endif - edges_.push_back(&edge); - } - - void FinishedAdding() { -#ifdef DEBUG - assert(!finished_adding_); - finished_adding_ = true; -#endif - } + Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } + const Final *BestChild() const { + PartialVertex top(RootPartial()); + if (top.Empty()) { + return NULL; + } else { + PartialVertex continuation; + while (!top.Complete()) { + top.Split(continuation); + top = continuation; + } + return &top.End(); + } + } + private: friend class VertexGenerator; - std::vector edges_; - -#ifdef DEBUG - bool finished_adding_; -#endif VertexNode root_; }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 78948c97..d94e6e06 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -2,45 +2,30 @@ #include "lm/left.hh" #include "search/context.hh" +#include "search/edge.hh" #include namespace search { -template VertexGenerator::VertexGenerator(Context &context, Vertex &gen) : context_(context), edges_(gen.edges_.size()), partial_edge_pool_(sizeof(PartialEdge), context.PopLimit() * 2) { - for (std::size_t i = 0; i < gen.edges_.size(); ++i) { - if (edges_[i].Init(*gen.edges_[i], *this)) - generate_.push(&edges_[i]); - } +VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); root_.under = &gen.root_; - to_pop_ = context.PopLimit(); - while (to_pop_ > 0 && !generate_.empty()) { - EdgeGenerator *top = generate_.top(); - generate_.pop(); - if (top->Pop(context, *this)) { - generate_.push(top); - } - } - gen.root_.SortAndSet(context, NULL); } -template VertexGenerator::VertexGenerator(Context &context, Vertex 
&gen); -template VertexGenerator::VertexGenerator(Context &context, Vertex &gen); - namespace { const uint64_t kCompleteAdd = static_cast(-1); } // namespace -void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { + const lm::ngram::ChartState &state = partial.CompletedState(); std::pair got(existing_.insert(std::pair(hash_value(state), NULL))); if (!got.second) { // Found it already. Final &exists = *got.first->second; if (exists.Bound() < partial.score) { - exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); + exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); } - --to_pop_; return; } unsigned char left = 0, right = 0; @@ -67,8 +52,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed } node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, from, partial); - --to_pop_; + got.first->second = CompleteTransition(*node, state, note, partial); } VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { @@ -86,12 +70,12 @@ VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node return next; } -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { VertexNode &node = *starter.under; assert(node.State().left.full == state.left.full); assert(!node.End()); Final *final = context_.NewFinal(); - 
final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); + final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); node.SetEnd(final); return final; } diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 8cdf1420..6b98da3e 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,10 +1,9 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/edge.hh" -#include "search/edge_generator.hh" +#include "search/note.hh" +#include "search/vertex.hh" -#include #include #include @@ -17,18 +16,21 @@ class ChartState; namespace search { -template class Context; class ContextBase; class Final; +struct PartialEdge; class VertexGenerator { public: - template VertexGenerator(Context &context, Vertex &gen); + VertexGenerator(ContextBase &context, Vertex &gen); - PartialEdge *MallocPartialEdge() { return static_cast(partial_edge_pool_.malloc()); } - void FreePartialEdge(PartialEdge *value) { partial_edge_pool_.free(value); } + void NewHypothesis(const PartialEdge &partial, Note note); - void NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + void FinishedSearch() { + root_.under->SortAndSet(context_, NULL); + } + + const Vertex &Generating() const { return gen_; } private: // Parallel structure to VertexNode. 
@@ -41,29 +43,16 @@ class VertexGenerator { Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); ContextBase &context_; - std::vector edges_; - - struct LessByTop : public std::binary_function { - bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { - return first->Top() < second->Top(); - } - }; - - typedef std::priority_queue, LessByTop> Generate; - Generate generate_; + Vertex &gen_; Trie root_; typedef boost::unordered_map Existing; Existing existing_; - - int to_pop_; - - boost::pool<> partial_edge_pool_; }; } // namespace search -- cgit v1.2.3