summaryrefslogtreecommitdiff
path: root/klm/search/vertex.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/search/vertex.hh')
-rw-r--r--klm/search/vertex.hh121
1 files changed, 77 insertions, 44 deletions
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
index ca9a4fcd..81c3cfed 100644
--- a/klm/search/vertex.hh
+++ b/klm/search/vertex.hh
@@ -16,59 +16,74 @@ namespace search {
class ContextBase;
+struct HypoState {
+ History history;
+ lm::ngram::ChartState state;
+ Score score;
+};
+
class VertexNode {
public:
- VertexNode() : end_() {}
-
- void InitRoot() {
- extend_.clear();
- state_.left.full = false;
- state_.left.length = 0;
- state_.right.length = 0;
- right_full_ = false;
- end_ = History();
+ VertexNode() {}
+
+ void InitRoot() { hypos_.clear(); }
+
+ /* The steps of building a VertexNode:
+ * 1. Default construct.
+ * 2. AppendHypothesis at least once, possibly multiple times.
+ * 3. FinishAppending with the number of words on left and right guaranteed
+ * to be common.
+ * 4. If !Complete(), call BuildExtend to construct the extensions
+ */
+ // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending.
+ void AppendHypothesis(const NBestComplete &best) {
+ assert(hypos_.empty() || !(hypos_.front().state == *best.state));
+ HypoState hypo;
+ hypo.history = best.history;
+ hypo.state = *best.state;
+ hypo.score = best.score;
+ hypos_.push_back(hypo);
+ }
+ void AppendHypothesis(const HypoState &hypo) {
+ hypos_.push_back(hypo);
}
- lm::ngram::ChartState &MutableState() { return state_; }
- bool &MutableRightFull() { return right_full_; }
+ // Sort hypotheses for the root.
+ void FinishRoot();
- void AddExtend(VertexNode *next) {
- extend_.push_back(next);
- }
+ void FinishedAppending(const unsigned char common_left, const unsigned char common_right);
- void SetEnd(History end, Score score) {
- assert(!end_);
- end_ = end;
- bound_ = score;
- }
-
- void SortAndSet(ContextBase &context);
+ void BuildExtend();
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_ && extend_.empty();
+ return hypos_.empty() && extend_.empty();
}
bool Complete() const {
- return end_;
+ // HACK: prevent root from being complete. TODO: allow root to be complete.
+ return hypos_.size() == 1 && extend_.empty();
}
const lm::ngram::ChartState &State() const { return state_; }
bool RightFull() const { return right_full_; }
+ // Priority relative to other non-terminals. 0 is highest.
+ unsigned char Niceness() const { return niceness_; }
+
Score Bound() const {
return bound_;
}
- unsigned char Length() const {
- return state_.left.length + state_.right.length;
- }
-
// Will be invalid unless this is a leaf.
- History End() const { return end_; }
+ History End() const {
+ assert(hypos_.size() == 1);
+ return hypos_.front().history;
+ }
- const VertexNode &operator[](size_t index) const {
- return *extend_[index];
+ VertexNode &operator[](size_t index) {
+ assert(!extend_.empty());
+ return extend_[index];
}
size_t Size() const {
@@ -76,22 +91,26 @@ class VertexNode {
}
private:
- void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);
+ // Hypotheses to be split.
+ std::vector<HypoState> hypos_;
- std::vector<VertexNode*> extend_;
+ std::vector<VertexNode> extend_;
lm::ngram::ChartState state_;
bool right_full_;
+ unsigned char niceness_;
+
+ unsigned char policy_;
+
Score bound_;
- History end_;
};
class PartialVertex {
public:
PartialVertex() {}
- explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}
+ explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {}
bool Empty() const { return back_->Empty(); }
@@ -100,17 +119,14 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }
-
- unsigned char Length() const { return back_->Length(); }
+ Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); }
- bool HasAlternative() const {
- return index_ + 1 < back_->Size();
- }
+ unsigned char Niceness() const { return back_->Niceness(); }
// Split into continuation and alternative, rendering this the continuation.
bool Split(PartialVertex &alternative) {
assert(!Complete());
+ back_->BuildExtend();
bool ret;
if (index_ + 1 < back_->Size()) {
alternative.index_ = index_ + 1;
@@ -129,7 +145,7 @@ class PartialVertex {
}
private:
- const VertexNode *back_;
+ VertexNode *back_;
unsigned int index_;
};
@@ -139,10 +155,21 @@ class Vertex {
public:
Vertex() {}
- PartialVertex RootPartial() const { return PartialVertex(root_); }
+ //PartialVertex RootFirst() const { return PartialVertex(right_); }
+ PartialVertex RootAlternate() { return PartialVertex(root_); }
+ //PartialVertex RootLast() const { return PartialVertex(left_); }
+
+ bool Empty() const {
+ return root_.Empty();
+ }
+
+ Score Bound() const {
+ return root_.Bound();
+ }
- History BestChild() const {
- PartialVertex top(RootPartial());
+ History BestChild() {
+ // left_ and right_ are not set at the root.
+ PartialVertex top(RootAlternate());
if (top.Empty()) {
return History();
} else {
@@ -158,6 +185,12 @@ class Vertex {
template <class Output> friend class VertexGenerator;
template <class Output> friend class RootVertexGenerator;
VertexNode root_;
+
+ // These will not be set for the root vertex.
+ // Branches only on left state.
+ //VertexNode left_;
+ // Branches only on right state.
+ //VertexNode right_;
};
} // namespace search