summaryrefslogtreecommitdiff
path: root/klm/search/edge_generator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/search/edge_generator.cc')
-rw-r--r--klm/search/edge_generator.cc129
1 files changed, 129 insertions, 0 deletions
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
new file mode 100644
index 00000000..d135899a
--- /dev/null
+++ b/klm/search/edge_generator.cc
@@ -0,0 +1,129 @@
+#include "search/edge_generator.hh"
+
+#include "lm/left.hh"
+#include "lm/partial.hh"
+#include "search/context.hh"
+#include "search/vertex.hh"
+#include "search/vertex_generator.hh"
+
+#include <numeric>
+
+namespace search {
+
+bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) {
+ from_ = &edge;
+ for (unsigned int i = 0; i < GetRule().Arity(); ++i) {
+ if (edge.GetVertex(i).RootPartial().Empty()) return false;
+ }
+ PartialEdge &root = *parent.MallocPartialEdge();
+ root.score = GetRule().Bound();
+ for (unsigned int i = 0; i < GetRule().Arity(); ++i) {
+ root.nt[i] = edge.GetVertex(i).RootPartial();
+ root.score += root.nt[i].Bound();
+ }
+ for (unsigned int i = GetRule().Arity(); i < 2; ++i) {
+ root.nt[i] = kBlankPartialVertex;
+ }
+ for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) {
+ root.between[i] = GetRule().Lexical(i);
+ }
+ // wtf no clear method?
+ generate_ = Generate();
+ generate_.push(&root);
+ top_ = root.score;
+ return true;
+}
+
+namespace {
+
+template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) {
+ memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1));
+
+ float ret = 0.0;
+ lm::ngram::ChartState *before, *after;
+ if (victim == 0) {
+ before = &update.between[0];
+ after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1];
+ } else {
+ assert(victim == 1);
+ assert(arity == 2);
+ before = &update.between[previous.nt[0].Complete() ? 0 : 1];
+ after = &update.between[2];
+ }
+ const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State();
+ const PartialVertex &update_nt = update.nt[victim];
+ const lm::ngram::ChartState &update_reveal = update_nt.State();
+ float just_after = 0.0;
+ if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) {
+ just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
+ }
+ if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) {
+ ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
+ }
+ if (update_nt.Complete()) {
+ if (update_reveal.left.full) {
+ before->left.full = true;
+ } else {
+ assert(update_reveal.left.length == update_reveal.right.length);
+ ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
+ }
+ if (victim == 0) {
+ update.between[0].right = after->right;
+ } else {
+ update.between[2].left = before->left;
+ }
+ }
+ return previous.score + (ret + just_after) * context.GetWeights().LM();
+}
+
+} // namespace
+
+template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGenerator &parent) {
+ assert(!generate_.empty());
+ PartialEdge &top = *generate_.top();
+ generate_.pop();
+ unsigned int victim = 0;
+ unsigned char lowest_length = 255;
+ for (unsigned int i = 0; i != GetRule().Arity(); ++i) {
+ if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) {
+ lowest_length = top.nt[i].Length();
+ victim = i;
+ }
+ }
+ if (lowest_length == 255) {
+ // All states report complete.
+ top.between[0].right = top.between[GetRule().Arity()].right;
+ parent.NewHypothesis(top.between[0], *from_, top);
+ top_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
+ return !generate_.empty();
+ }
+
+ unsigned int stay = !victim;
+ PartialEdge &continuation = *parent.MallocPartialEdge();
+ float old_bound = top.nt[victim].Bound();
+ // The alternate's score will change because alternate.nt[victim] changes.
+ bool split = top.nt[victim].Split(continuation.nt[victim]);
+ // top is now the alternate.
+
+ continuation.nt[stay] = top.nt[stay];
+ continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation);
+ // TODO: dedupe?
+ generate_.push(&continuation);
+
+ if (split) {
+ // We have an alternate.
+ top.score += top.nt[victim].Bound() - old_bound;
+ // TODO: dedupe?
+ generate_.push(&top);
+ } else {
+ parent.FreePartialEdge(&top);
+ }
+
+ top_ = generate_.top()->score;
+ return true;
+}
+
+template bool EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, VertexGenerator &parent);
+template bool EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, VertexGenerator &parent);
+
+} // namespace search