diff options
Diffstat (limited to 'klm/search/edge_generator.cc')
-rw-r--r-- | klm/search/edge_generator.cc | 53 |
1 files changed, 22 insertions, 31 deletions
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index d135899a..56239dfb 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -10,28 +10,15 @@ namespace search { -bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) { - from_ = &edge; - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { - if (edge.GetVertex(i).RootPartial().Empty()) return false; - } - PartialEdge &root = *parent.MallocPartialEdge(); - root.score = GetRule().Bound(); - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { +EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { +/* for (unsigned char i = 0; i < edge.Arity(); ++i) { root.nt[i] = edge.GetVertex(i).RootPartial(); - root.score += root.nt[i].Bound(); } - for (unsigned int i = GetRule().Arity(); i < 2; ++i) { + for (unsigned char i = edge.Arity(); i < 2; ++i) { root.nt[i] = kBlankPartialVertex; - } - for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) { - root.between[i] = GetRule().Lexical(i); - } - // wtf no clear method? - generate_ = Generate(); + }*/ generate_.push(&root); - top_ = root.score; - return true; + top_score_ = root.score; } namespace { @@ -78,13 +65,13 @@ template <class Model> float FastScore(const Context<Model> &context, unsigned c } // namespace -template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGenerator &parent) { +template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) { assert(!generate_.empty()); PartialEdge &top = *generate_.top(); generate_.pop(); unsigned int victim = 0; unsigned char lowest_length = 255; - for (unsigned int i = 0; i != GetRule().Arity(); ++i) { + for (unsigned char i = 0; i != arity_; ++i) { if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { lowest_length = top.nt[i].Length(); victim = i; @@ -92,21 +79,21 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe } if (lowest_length == 255) { // All states report complete. - top.between[0].right = top.between[GetRule().Arity()].right; - parent.NewHypothesis(top.between[0], *from_, top); - top_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return !generate_.empty(); + top.between[0].right = top.between[arity_].right; + // Now top.between[0] is the full edge state. + top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; + return ⊤ } unsigned int stay = !victim; - PartialEdge &continuation = *parent.MallocPartialEdge(); + PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc()); float old_bound = top.nt[victim].Bound(); // The alternate's score will change because alternate.nt[victim] changes. bool split = top.nt[victim].Split(continuation.nt[victim]); // top is now the alternate. continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation); + continuation.score = FastScore(context, victim, arity_, top, continuation); // TODO: dedupe? generate_.push(&continuation); @@ -116,14 +103,18 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe // TODO: dedupe? generate_.push(&top); } else { - parent.FreePartialEdge(&top); + partial_edge_pool.free(&top); } - top_ = generate_.top()->score; - return true; + top_score_ = generate_.top()->score; + return NULL; } -template bool EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, VertexGenerator &parent); -template bool EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, VertexGenerator &parent); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool); } // namespace search |