diff options
Diffstat (limited to 'klm/search/vertex_generator.cc')
-rw-r--r-- | klm/search/vertex_generator.cc | 97 |
1 files changed, 54 insertions, 43 deletions
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index d94e6e06..0945fe55 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -10,74 +10,85 @@ namespace search { VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); - root_.under = &gen.root_; } namespace { + const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); -} // namespace -void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL))); - if (!got.second) { - // Found it already. - Final &exists = *got.first->second; - if (exists.Bound() < partial.score) { - exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - } - return; +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map<uint64_t, Trie> extend; +}; + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); } + return next; +} + +void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { + Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); + Final *child_out = final.Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = part->End(); + + starter.under->SetEnd(final); +} + +void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + unsigned char left = 0, right = 0; - Trie *node = &root_; + Trie *node = &root; while (true) { if (left == state.left.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); for (; right < state.right.length; ++right) { - node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); } break; } - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); left++; if (right == state.right.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); for (; left < state.left.length; ++left) { - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); } break; } - node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); right++; } - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, note, partial); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + CompleteTransition(context, *node, partial); } -VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { - VertexGenerator::Trie &next = node.extend[added]; - if (!next.under) { - next.under = context_.NewVertexNode(); - lm::ngram::ChartState &writing = next.under->MutableState(); - writing = state; - writing.left.full &= left_full && state.left.full; - next.under->MutableRightFull() = right_full && state.left.full; - writing.left.length = left; - writing.right.length = right; - node.under->AddExtend(next.under); - } - return next; -} +} // namespace -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { - VertexNode &node = *starter.under; - assert(node.State().left.full == state.left.full); - assert(!node.End()); - Final *final = context_.NewFinal(); - final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - node.SetEnd(final); - return final; +void VertexGenerator::FinishedSearch() { + Trie root; + root.under = &gen_.root_; + for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, i->second); + } + root.under->SortAndSet(context_, NULL); } } // namespace search |