summaryrefslogtreecommitdiff
path: root/klm/search/vertex.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/search/vertex.hh')
-rw-r--r--klm/search/vertex.hh37
1 files changed, 21 insertions, 16 deletions
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
index 52bc1dfe..10b3339b 100644
--- a/klm/search/vertex.hh
+++ b/klm/search/vertex.hh
@@ -2,7 +2,6 @@
#define SEARCH_VERTEX__
#include "lm/left.hh"
-#include "search/final.hh"
#include "search/types.hh"
#include <boost/unordered_set.hpp>
@@ -10,6 +9,7 @@
#include <queue>
#include <vector>
+#include <math.h>
#include <stdint.h>
namespace search {
@@ -18,7 +18,7 @@ class ContextBase;
class VertexNode {
public:
- VertexNode() {}
+ VertexNode() : end_() {}
void InitRoot() {
extend_.clear();
@@ -26,7 +26,7 @@ class VertexNode {
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
- end_ = Final();
+ end_ = History();
}
lm::ngram::ChartState &MutableState() { return state_; }
@@ -36,20 +36,21 @@ class VertexNode {
extend_.push_back(next);
}
- void SetEnd(Final end) {
- assert(!end_.Valid());
+ void SetEnd(History end, Score score) {
+ assert(!end_);
end_ = end;
+ bound_ = score;
}
- void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
+ void SortAndSet(ContextBase &context);
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_.Valid() && extend_.empty();
+ return !end_ && extend_.empty();
}
bool Complete() const {
- return end_.Valid();
+ return end_;
}
const lm::ngram::ChartState &State() const { return state_; }
@@ -64,7 +65,7 @@ class VertexNode {
}
// Will be invalid unless this is a leaf.
- const Final End() const { return end_; }
+ const History End() const { return end_; }
const VertexNode &operator[](size_t index) const {
return *extend_[index];
@@ -75,13 +76,15 @@ class VertexNode {
}
private:
+ void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);
+
std::vector<VertexNode*> extend_;
lm::ngram::ChartState state_;
bool right_full_;
Score bound_;
- Final end_;
+ History end_;
};
class PartialVertex {
@@ -97,7 +100,7 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); }
+ Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }
unsigned char Length() const { return back_->Length(); }
@@ -121,7 +124,7 @@ class PartialVertex {
return ret;
}
- const Final End() const {
+ const History End() const {
return back_->End();
}
@@ -130,16 +133,18 @@ class PartialVertex {
unsigned int index_;
};
+template <class Output> class VertexGenerator;
+
class Vertex {
public:
Vertex() {}
PartialVertex RootPartial() const { return PartialVertex(root_); }
- const Final BestChild() const {
+ const History BestChild() const {
PartialVertex top(RootPartial());
if (top.Empty()) {
- return Final();
+ return History();
} else {
PartialVertex continuation;
while (!top.Complete()) {
@@ -150,8 +155,8 @@ class Vertex {
}
private:
- friend class VertexGenerator;
-
+ template <class Output> friend class VertexGenerator;
+ template <class Output> friend class RootVertexGenerator;
VertexNode root_;
};