From 9b99cb844e3e379b557ff8578df27893ce147f1a Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 14 Oct 2012 10:46:34 +0100 Subject: Update to faster but less cute search --- decoder/lazy.cc | 57 +++++++++++++++++---------------- klm/lm/fragment.cc | 37 +++++++++++++++++++++ klm/search/Jamfile | 2 +- klm/search/edge.hh | 31 +++--------------- klm/search/edge_generator.cc | 53 +++++++++++++----------------- klm/search/edge_generator.hh | 24 ++++++++------ klm/search/edge_queue.cc | 25 +++++++++++++++ klm/search/edge_queue.hh | 73 ++++++++++++++++++++++++++++++++++++++++++ klm/search/final.hh | 11 +++---- klm/search/note.hh | 12 +++++++ klm/search/rule.cc | 32 +++++++++--------- klm/search/rule.hh | 31 ++---------------- klm/search/vertex.hh | 45 +++++++++++--------------- klm/search/vertex_generator.cc | 32 +++++------------- klm/search/vertex_generator.hh | 35 +++++++------------- 15 files changed, 279 insertions(+), 221 deletions(-) create mode 100644 klm/lm/fragment.cc create mode 100644 klm/search/edge_queue.cc create mode 100644 klm/search/edge_queue.hh create mode 100644 klm/search/note.hh diff --git a/decoder/lazy.cc b/decoder/lazy.cc index c4138d7b..9dc657d6 100644 --- a/decoder/lazy.cc +++ b/decoder/lazy.cc @@ -8,6 +8,7 @@ #include "search/config.hh" #include "search/context.hh" #include "search/edge.hh" +#include "search/edge_queue.hh" #include "search/vertex.hh" #include "search/vertex_generator.hh" #include "util/exception.hh" @@ -75,7 +76,7 @@ template class Lazy : public LazyBase { void Search(unsigned int pop_limit, const Hypergraph &hg) const; private: - void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const; + unsigned char ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const; const Model m_; }; @@ -93,73 +94,73 @@ LazyBase *LazyBase::Load(const char *model_file, const std::vector &we } } -void PrintFinal(const Hypergraph &hg, const search::Edge *edge_base, const search::Final &final) { - const std::vector &words = hg.edges_[&final.From() - edge_base].rule_->e(); +void PrintFinal(const Hypergraph &hg, const search::Final &final) { + const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); boost::array::const_iterator child(final.Children().begin()); for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { if (*i > 0) { std::cout << TD::Convert(*i) << ' '; } else { - PrintFinal(hg, edge_base, **child++); + PrintFinal(hg, **child++); } } } template void Lazy::Search(unsigned int pop_limit, const Hypergraph &hg) const { boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); - boost::scoped_array out_edges(new search::Edge[hg.edges_.size()]); search::Config config(weights_, pop_limit); search::Context context(config, m_); for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { - search::Vertex &out_vertex = out_vertices[i]; + search::EdgeQueue queue(context.PopLimit()); const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; for (unsigned int j = 0; j < down_edges.size(); ++j) { unsigned int edge_index = down_edges[j]; - ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], out_edges[edge_index]); - out_vertex.Add(out_edges[edge_index]); + unsigned char arity = ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], queue.InitializeEdge()); + search::Note note; + note.vp = &hg.edges_[edge_index]; + if (arity != 255) queue.AddEdge(arity, note); } - out_vertex.FinishedAdding(); - search::VertexGenerator(context, out_vertex); + search::VertexGenerator vertex_gen(context, out_vertices[i]); + queue.Search(context, vertex_gen); } - search::PartialVertex top = out_vertices[hg.nodes_.size() - 2].RootPartial(); - if (top.Empty()) { - std::cout << "NO PATH FOUND"; + const search::Final *top = out_vertices[hg.nodes_.size() - 2].BestChild(); + if (!top) { + std::cout << "NO PATH FOUND" << std::endl; } else { - search::PartialVertex continuation; - while (!top.Complete()) { - top.Split(continuation); - top = continuation; - } - PrintFinal(hg, out_edges.get(), top.End()); - std::cout << "||| " << top.End().Bound() << std::endl; + PrintFinal(hg, *top); + std::cout << "||| " << top->Bound() << std::endl; } } -// TODO: get weights into here somehow. -template void Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const { +template unsigned char Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const { const std::vector &e = in.rule_->e(); std::vector words; unsigned int terminals = 0; + unsigned char nt = 0; for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { if (*word <= 0) { - out.Add(vertices[in.tail_nodes_[-*word]]); + out.nt[nt] = vertices[in.tail_nodes_[-*word]].RootPartial(); + if (out.nt[nt].Empty()) return 255; + ++nt; words.push_back(lm::kMaxWordIndex); } else { ++terminals; words.push_back(vocab_.FromCDec(*word)); } } + for (unsigned char fill = nt; fill < search::kMaxArity; ++fill) { + out.nt[nt] = search::kBlankPartialVertex; + } if (final) { words.push_back(m_.GetVocabulary().EndSentence()); } - float additive = in.rule_->GetFeatureValues().dot(cdec_weights_); - UTIL_THROW_IF(isnan(additive), util::Exception, "Bad dot product"); - additive -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; - - out.InitRule().Init(context, additive, words, final); + out.score = in.rule_->GetFeatureValues().dot(cdec_weights_); + out.score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; + out.score += search::ScoreRule(context, words, final, out.between); + return nt; } boost::scoped_ptr AwfulGlobalLazy; diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc new file mode 100644 index 00000000..0267cd4e --- /dev/null +++ b/klm/lm/fragment.cc @@ -0,0 +1,37 @@ +#include "lm/binary_format.hh" +#include "lm/model.hh" +#include "lm/left.hh" +#include "util/tokenize_piece.hh" + +template void Query(const char *name) { + Model model(name); + std::string line; + lm::ngram::ChartState ignored; + while (getline(std::cin, line)) { + lm::ngram::RuleScore scorer(model, ignored); + for (util::TokenIter i(line, ' '); i; ++i) { + scorer.Terminal(model.GetVocabulary().Index(*i)); + } + std::cout << scorer.Finish() << '\n'; + } +} + +int main(int argc, char *argv[]) { + if (argc != 2) { + std::cerr << "Expected model file name." << std::endl; + return 1; + } + const char *name = argv[1]; + lm::ngram::ModelType model_type = lm::ngram::PROBING; + lm::ngram::RecognizeBinary(name, model_type); + switch (model_type) { + case lm::ngram::PROBING: + Query(name); + break; + case lm::ngram::REST_PROBING: + Query(name); + break; + default: + std::cerr << "Model type not supported yet." << std::endl; + } +} diff --git a/klm/search/Jamfile b/klm/search/Jamfile index ac47c249..e8b14363 100644 --- a/klm/search/Jamfile +++ b/klm/search/Jamfile @@ -1,4 +1,4 @@ -lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil : : : .. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; import testing ; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 4d2a5cbf..77ab0ade 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -11,33 +11,6 @@ namespace search { -class Edge { - public: - Edge() { - end_to_ = to_; - } - - Rule &InitRule() { return rule_; } - - void Add(Vertex &vertex) { - assert(end_to_ - to_ < kMaxArity); - *(end_to_++) = &vertex; - } - - const Vertex &GetVertex(std::size_t index) const { - return *to_[index]; - } - - const Rule &GetRule() const { return rule_; } - - private: - // Rule and pointers to rule arguments. - Rule rule_; - - Vertex *to_[kMaxArity]; - Vertex **end_to_; -}; - struct PartialEdge { Score score; // Terminals @@ -45,6 +18,10 @@ struct PartialEdge { // Non-terminals PartialVertex nt[kMaxArity]; + const lm::ngram::ChartState &CompletedState() const { + return between[0]; + } + bool operator<(const PartialEdge &other) const { return score < other.score; } diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index d135899a..56239dfb 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -10,28 +10,15 @@ namespace search { -bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) { - from_ = &edge; - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { - if (edge.GetVertex(i).RootPartial().Empty()) return false; - } - PartialEdge &root = *parent.MallocPartialEdge(); - root.score = GetRule().Bound(); - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { +EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { +/* for (unsigned char i = 0; i < edge.Arity(); ++i) { root.nt[i] = edge.GetVertex(i).RootPartial(); - root.score += root.nt[i].Bound(); } - for (unsigned int i = GetRule().Arity(); i < 2; ++i) { + for (unsigned char i = edge.Arity(); i < 2; ++i) { root.nt[i] = kBlankPartialVertex; - } - for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) { - root.between[i] = GetRule().Lexical(i); - } - // wtf no clear method? - generate_ = Generate(); + }*/ generate_.push(&root); - top_ = root.score; - return true; + top_score_ = root.score; } namespace { @@ -78,13 +65,13 @@ template float FastScore(const Context &context, unsigned c } // namespace -template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent) { +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool) { assert(!generate_.empty()); PartialEdge &top = *generate_.top(); generate_.pop(); unsigned int victim = 0; unsigned char lowest_length = 255; - for (unsigned int i = 0; i != GetRule().Arity(); ++i) { + for (unsigned char i = 0; i != arity_; ++i) { if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { lowest_length = top.nt[i].Length(); victim = i; @@ -92,21 +79,21 @@ template bool EdgeGenerator::Pop(Context &context, VertexGe } if (lowest_length == 255) { // All states report complete. - top.between[0].right = top.between[GetRule().Arity()].right; - parent.NewHypothesis(top.between[0], *from_, top); - top_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return !generate_.empty(); + top.between[0].right = top.between[arity_].right; + // Now top.between[0] is the full edge state. + top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; + return ⊤ } unsigned int stay = !victim; - PartialEdge &continuation = *parent.MallocPartialEdge(); + PartialEdge &continuation = *static_cast(partial_edge_pool.malloc()); float old_bound = top.nt[victim].Bound(); // The alternate's score will change because alternate.nt[victim] changes. bool split = top.nt[victim].Split(continuation.nt[victim]); // top is now the alternate. continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation); + continuation.score = FastScore(context, victim, arity_, top, continuation); // TODO: dedupe? generate_.push(&continuation); @@ -116,14 +103,18 @@ template bool EdgeGenerator::Pop(Context &context, VertexGe // TODO: dedupe? generate_.push(&top); } else { - parent.FreePartialEdge(&top); + partial_edge_pool.free(&top); } - top_ = generate_.top()->score; - return true; + top_score_ = generate_.top()->score; + return NULL; } -template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent); -template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index e306dc61..875ccc5e 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,9 @@ #define SEARCH_EDGE_GENERATOR__ #include "search/edge.hh" +#include "search/note.hh" +#include #include #include @@ -28,26 +30,28 @@ struct PartialEdgePointerLess : std::binary_function bool Pop(Context &context, VertexGenerator &parent); + Note GetNote() const { + return note_; + } + + // Pop. If there's a complete hypothesis, return it. Otherwise return NULL. + template PartialEdge *Pop(Context &context, boost::pool<> &partial_edge_pool); private: - const Rule &GetRule() const { - return from_->GetRule(); - } + Score top_score_; - Score top_; + unsigned char arity_; typedef std::priority_queue, PartialEdgePointerLess> Generate; Generate generate_; - Edge *from_; + Note note_; }; } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc new file mode 100644 index 00000000..e3ae6ebf --- /dev/null +++ b/klm/search/edge_queue.cc @@ -0,0 +1,25 @@ +#include "search/edge_queue.hh" + +#include "lm/left.hh" +#include "search/context.hh" + +#include + +namespace search { + +EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { + take_ = static_cast(partial_edge_pool_.malloc()); +} + +/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { + // Ignore empty edges. + for (unsigned char i = 0; i < edge.Arity(); ++i) { + PartialVertex root(edge.GetVertex(i).RootPartial()); + if (root.Empty()) return; + total_score += root.Bound(); + } + PartialEdge &allocated = *static_cast(partial_edge_pool_.malloc()); + allocated.score = total_score; +}*/ + +} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh new file mode 100644 index 00000000..187eaed7 --- /dev/null +++ b/klm/search/edge_queue.hh @@ -0,0 +1,73 @@ +#ifndef SEARCH_EDGE_QUEUE__ +#define SEARCH_EDGE_QUEUE__ + +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/note.hh" + +#include +#include + +#include + +namespace search { + +template class Context; + +class EdgeQueue { + public: + explicit EdgeQueue(unsigned int pop_limit_hint); + + PartialEdge &InitializeEdge() { + return *take_; + } + + void AddEdge(unsigned char arity, Note note) { + generate_.push(edge_pool_.construct(*take_, arity, note)); + take_ = static_cast(partial_edge_pool_.malloc()); + } + + bool Empty() const { return generate_.empty(); } + + /* Generate hypotheses and send them to output. Normally, output is a + * VertexGenerator, but the decoder may want to route edges to different + * vertices i.e. if they have different LHS non-terminal labels. + */ + template void Search(Context &context, Output &output) { + int to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + EdgeGenerator *top = generate_.top(); + generate_.pop(); + PartialEdge *ret = top->Pop(context, partial_edge_pool_); + if (ret) { + output.NewHypothesis(*ret, top->GetNote()); + --to_pop; + if (top->TopScore() != -kScoreInf) { + generate_.push(top); + } + } else { + generate_.push(top); + } + } + output.FinishedSearch(); + } + + private: + boost::object_pool edge_pool_; + + struct LessByTopScore : public std::binary_function { + bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { + return first->TopScore() < second->TopScore(); + } + }; + + typedef std::priority_queue, LessByTopScore> Generate; + Generate generate_; + + boost::pool<> partial_edge_pool_; + + PartialEdge *take_; +}; + +} // namespace search +#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 823b8c1a..1b3092ac 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -2,35 +2,34 @@ #define SEARCH_FINAL__ #include "search/arity.hh" +#include "search/note.hh" #include "search/types.hh" #include namespace search { -class Edge; - class Final { public: typedef boost::array ChildArray; - void Reset(Score bound, const Edge &from, const Final &left, const Final &right) { + void Reset(Score bound, Note note, const Final &left, const Final &right) { bound_ = bound; - from_ = &from; + note_ = note; children_[0] = &left; children_[1] = &right; } const ChildArray &Children() const { return children_; } - const Edge &From() const { return *from_; } + Note GetNote() const { return note_; } Score Bound() const { return bound_; } private: Score bound_; - const Edge *from_; + Note note_; ChildArray children_; }; diff --git a/klm/search/note.hh b/klm/search/note.hh new file mode 100644 index 00000000..50bed06e --- /dev/null +++ b/klm/search/note.hh @@ -0,0 +1,12 @@ +#ifndef SEARCH_NOTE__ +#define SEARCH_NOTE__ + +namespace search { + +union Note { + const void *vp; +}; + +} // namespace search + +#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 0a941527..5b00207e 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,35 +9,35 @@ namespace search { -template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos) { - additive_ = additive; - Score lm_score = 0.0; - lexical_.clear(); - const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - +template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing) { + unsigned int oov_count = 0; + float prob = 0.0; + const Model &model = context.LanguageModel(); + const lm::WordIndex oov = model.GetVocabulary().NotFound(); for (std::vector::const_iterator word = words.begin(); ; ++word) { - lexical_.resize(lexical_.size() + 1); - lm::ngram::RuleScore scorer(context.LanguageModel(), lexical_.back()); + lm::ngram::RuleScore scorer(model, *(writing++)); // TODO: optimize if (prepend_bos && (word == words.begin())) { scorer.BeginSentence(); } for (; ; ++word) { if (word == words.end()) { - lm_score += scorer.Finish(); - bound_ = additive_ + context.GetWeights().LM() * lm_score; - arity_ = lexical_.size() - 1; - return; + prob += scorer.Finish(); + return static_cast(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); } if (*word == kNonTerminal) break; - if (*word == oov) additive_ += context.GetWeights().OOV(); + if (*word == oov) ++oov_count; scorer.Terminal(*word); } - lm_score += scorer.Finish(); + prob += scorer.Finish(); } } -template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); -template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 920c64a7..0ce2794d 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -3,44 +3,17 @@ #include "lm/left.hh" #include "lm/word_index.hh" -#include "search/arity.hh" #include "search/types.hh" -#include - -#include #include namespace search { template class Context; -class Rule { - public: - Rule() : arity_(0) {} - - static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - - // Use kNonTerminal for non-terminals. - template void Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); - - Score Bound() const { return bound_; } - - Score Additive() const { return additive_; } - - unsigned int Arity() const { return arity_; } - - const lm::ngram::ChartState &Lexical(unsigned int index) const { - return lexical_[index]; - } - - private: - Score bound_, additive_; - - unsigned int arity_; +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - std::vector lexical_; -}; +template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *state_out); } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 7ef29efc..e1a9ad11 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,8 +16,6 @@ namespace search { class ContextBase; -class Edge; - class VertexNode { public: VertexNode() : end_(NULL) {} @@ -103,6 +101,10 @@ class PartialVertex { unsigned char Length() const { return back_->Length(); } + bool HasAlternative() const { + return index_ + 1 < back_->Size(); + } + // Split into continuation and alternative, rendering this the alternative. bool Split(PartialVertex &continuation) { assert(!Complete()); @@ -128,35 +130,26 @@ extern PartialVertex kBlankPartialVertex; class Vertex { public: - Vertex() -#ifdef DEBUG - : finished_adding_(false) -#endif - {} - - void Add(Edge &edge) { -#ifdef DEBUG - assert(!finished_adding_); -#endif - edges_.push_back(&edge); - } - - void FinishedAdding() { -#ifdef DEBUG - assert(!finished_adding_); - finished_adding_ = true; -#endif - } + Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } + const Final *BestChild() const { + PartialVertex top(RootPartial()); + if (top.Empty()) { + return NULL; + } else { + PartialVertex continuation; + while (!top.Complete()) { + top.Split(continuation); + top = continuation; + } + return &top.End(); + } + } + private: friend class VertexGenerator; - std::vector edges_; - -#ifdef DEBUG - bool finished_adding_; -#endif VertexNode root_; }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 78948c97..d94e6e06 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -2,45 +2,30 @@ #include "lm/left.hh" #include "search/context.hh" +#include "search/edge.hh" #include namespace search { -template VertexGenerator::VertexGenerator(Context &context, Vertex &gen) : context_(context), edges_(gen.edges_.size()), partial_edge_pool_(sizeof(PartialEdge), context.PopLimit() * 2) { - for (std::size_t i = 0; i < gen.edges_.size(); ++i) { - if (edges_[i].Init(*gen.edges_[i], *this)) - generate_.push(&edges_[i]); - } +VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); root_.under = &gen.root_; - to_pop_ = context.PopLimit(); - while (to_pop_ > 0 && !generate_.empty()) { - EdgeGenerator *top = generate_.top(); - generate_.pop(); - if (top->Pop(context, *this)) { - generate_.push(top); - } - } - gen.root_.SortAndSet(context, NULL); } -template VertexGenerator::VertexGenerator(Context &context, Vertex &gen); -template VertexGenerator::VertexGenerator(Context &context, Vertex &gen); - namespace { const uint64_t kCompleteAdd = static_cast(-1); } // namespace -void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { + const lm::ngram::ChartState &state = partial.CompletedState(); std::pair got(existing_.insert(std::pair(hash_value(state), NULL))); if (!got.second) { // Found it already. Final &exists = *got.first->second; if (exists.Bound() < partial.score) { - exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); + exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); } - --to_pop_; return; } unsigned char left = 0, right = 0; @@ -67,8 +52,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed } node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, from, partial); - --to_pop_; + got.first->second = CompleteTransition(*node, state, note, partial); } VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { @@ -86,12 +70,12 @@ VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node return next; } -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { VertexNode &node = *starter.under; assert(node.State().left.full == state.left.full); assert(!node.End()); Final *final = context_.NewFinal(); - final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); + final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); node.SetEnd(final); return final; } diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 8cdf1420..6b98da3e 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,10 +1,9 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/edge.hh" -#include "search/edge_generator.hh" +#include "search/note.hh" +#include "search/vertex.hh" -#include #include #include @@ -17,18 +16,21 @@ class ChartState; namespace search { -template class Context; class ContextBase; class Final; +struct PartialEdge; class VertexGenerator { public: - template VertexGenerator(Context &context, Vertex &gen); + VertexGenerator(ContextBase &context, Vertex &gen); - PartialEdge *MallocPartialEdge() { return static_cast(partial_edge_pool_.malloc()); } - void FreePartialEdge(PartialEdge *value) { partial_edge_pool_.free(value); } + void NewHypothesis(const PartialEdge &partial, Note note); - void NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + void FinishedSearch() { + root_.under->SortAndSet(context_, NULL); + } + + const Vertex &Generating() const { return gen_; } private: // Parallel structure to VertexNode. @@ -41,29 +43,16 @@ class VertexGenerator { Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); ContextBase &context_; - std::vector edges_; - - struct LessByTop : public std::binary_function { - bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { - return first->Top() < second->Top(); - } - }; - - typedef std::priority_queue, LessByTop> Generate; - Generate generate_; + Vertex &gen_; Trie root_; typedef boost::unordered_map Existing; Existing existing_; - - int to_pop_; - - boost::pool<> partial_edge_pool_; }; } // namespace search -- cgit v1.2.3