summaryrefslogtreecommitdiff
path: root/klm/search/vertex_generator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/search/vertex_generator.cc')
-rw-r--r--klm/search/vertex_generator.cc97
1 files changed, 54 insertions, 43 deletions
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
index d94e6e06..0945fe55 100644
--- a/klm/search/vertex_generator.cc
+++ b/klm/search/vertex_generator.cc
@@ -10,74 +10,85 @@ namespace search {
VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {
gen.root_.InitRoot();
- root_.under = &gen.root_;
}
namespace {
+
const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
-} // namespace
-void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) {
- const lm::ngram::ChartState &state = partial.CompletedState();
- std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL)));
- if (!got.second) {
- // Found it already.
- Final &exists = *got.first->second;
- if (exists.Bound() < partial.score) {
- exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
- }
- return;
+// Parallel structure to VertexNode.
+struct Trie {
+ Trie() : under(NULL) {}
+
+ VertexNode *under;
+ boost::unordered_map<uint64_t, Trie> extend;
+};
+
+Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
+ Trie &next = node.extend[added];
+ if (!next.under) {
+ next.under = context.NewVertexNode();
+ lm::ngram::ChartState &writing = next.under->MutableState();
+ writing = state;
+ writing.left.full &= left_full && state.left.full;
+ next.under->MutableRightFull() = right_full && state.left.full;
+ writing.left.length = left;
+ writing.right.length = right;
+ node.under->AddExtend(next.under);
}
+ return next;
+}
+
+void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) {
+ Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote());
+ Final *child_out = final.Children();
+ const PartialVertex *part = partial.NT();
+ const PartialVertex *const part_end_loop = part + partial.GetArity();
+ for (; part != part_end_loop; ++part, ++child_out)
+ *child_out = part->End();
+
+ starter.under->SetEnd(final);
+}
+
+void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
+ const lm::ngram::ChartState &state = partial.CompletedState();
+
unsigned char left = 0, right = 0;
- Trie *node = &root_;
+ Trie *node = &root;
while (true) {
if (left == state.left.length) {
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false);
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false);
for (; right < state.right.length; ++right) {
- node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false);
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false);
}
break;
}
- node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false);
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false);
left++;
if (right == state.right.length) {
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true);
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true);
for (; left < state.left.length; ++left) {
- node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true);
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true);
}
break;
}
- node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false);
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false);
right++;
}
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
- got.first->second = CompleteTransition(*node, state, note, partial);
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
+ CompleteTransition(context, *node, partial);
}
-VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
- VertexGenerator::Trie &next = node.extend[added];
- if (!next.under) {
- next.under = context_.NewVertexNode();
- lm::ngram::ChartState &writing = next.under->MutableState();
- writing = state;
- writing.left.full &= left_full && state.left.full;
- next.under->MutableRightFull() = right_full && state.left.full;
- writing.left.length = left;
- writing.right.length = right;
- node.under->AddExtend(next.under);
- }
- return next;
-}
+} // namespace
-Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) {
- VertexNode &node = *starter.under;
- assert(node.State().left.full == state.left.full);
- assert(!node.End());
- Final *final = context_.NewFinal();
- final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
- node.SetEnd(final);
- return final;
+void VertexGenerator::FinishedSearch() {
+ Trie root;
+ root.under = &gen_.root_;
+ for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) {
+ AddHypothesis(context_, root, i->second);
+ }
+ root.under->SortAndSet(context_, NULL);
}
} // namespace search