summaryrefslogtreecommitdiff
path: root/klm/search/edge_generator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/search/edge_generator.cc')
-rw-r--r--klm/search/edge_generator.cc53
1 files changed, 22 insertions, 31 deletions
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
index d135899a..56239dfb 100644
--- a/klm/search/edge_generator.cc
+++ b/klm/search/edge_generator.cc
@@ -10,28 +10,15 @@
namespace search {
-bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) {
- from_ = &edge;
- for (unsigned int i = 0; i < GetRule().Arity(); ++i) {
- if (edge.GetVertex(i).RootPartial().Empty()) return false;
- }
- PartialEdge &root = *parent.MallocPartialEdge();
- root.score = GetRule().Bound();
- for (unsigned int i = 0; i < GetRule().Arity(); ++i) {
+EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) {
+/* for (unsigned char i = 0; i < edge.Arity(); ++i) {
root.nt[i] = edge.GetVertex(i).RootPartial();
- root.score += root.nt[i].Bound();
}
- for (unsigned int i = GetRule().Arity(); i < 2; ++i) {
+ for (unsigned char i = edge.Arity(); i < 2; ++i) {
root.nt[i] = kBlankPartialVertex;
- }
- for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) {
- root.between[i] = GetRule().Lexical(i);
- }
- // wtf no clear method?
- generate_ = Generate();
+ }*/
generate_.push(&root);
- top_ = root.score;
- return true;
+ top_score_ = root.score;
}
namespace {
@@ -78,13 +65,13 @@ template <class Model> float FastScore(const Context<Model> &context, unsigned c
} // namespace
-template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGenerator &parent) {
+template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) {
assert(!generate_.empty());
PartialEdge &top = *generate_.top();
generate_.pop();
unsigned int victim = 0;
unsigned char lowest_length = 255;
- for (unsigned int i = 0; i != GetRule().Arity(); ++i) {
+ for (unsigned char i = 0; i != arity_; ++i) {
if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) {
lowest_length = top.nt[i].Length();
victim = i;
@@ -92,21 +79,21 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe
}
if (lowest_length == 255) {
// All states report complete.
- top.between[0].right = top.between[GetRule().Arity()].right;
- parent.NewHypothesis(top.between[0], *from_, top);
- top_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
- return !generate_.empty();
+ top.between[0].right = top.between[arity_].right;
+ // Now top.between[0] is the full edge state.
+ top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
+ return &top;
}
unsigned int stay = !victim;
- PartialEdge &continuation = *parent.MallocPartialEdge();
+ PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc());
float old_bound = top.nt[victim].Bound();
// The alternate's score will change because alternate.nt[victim] changes.
bool split = top.nt[victim].Split(continuation.nt[victim]);
// top is now the alternate.
continuation.nt[stay] = top.nt[stay];
- continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation);
+ continuation.score = FastScore(context, victim, arity_, top, continuation);
// TODO: dedupe?
generate_.push(&continuation);
@@ -116,14 +103,18 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe
// TODO: dedupe?
generate_.push(&top);
} else {
- parent.FreePartialEdge(&top);
+ partial_edge_pool.free(&top);
}
- top_ = generate_.top()->score;
- return true;
+ top_score_ = generate_.top()->score;
+ return NULL;
}
-template bool EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, VertexGenerator &parent);
-template bool EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, VertexGenerator &parent);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
} // namespace search