diff options
Diffstat (limited to 'klm/search/vertex.hh')
-rw-r--r-- | klm/search/vertex.hh | 121 |
1 files changed, 77 insertions, 44 deletions
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index ca9a4fcd..81c3cfed 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,59 +16,74 @@ namespace search { class ContextBase; +struct HypoState { + History history; + lm::ngram::ChartState state; + Score score; +}; + class VertexNode { public: - VertexNode() : end_() {} - - void InitRoot() { - extend_.clear(); - state_.left.full = false; - state_.left.length = 0; - state_.right.length = 0; - right_full_ = false; - end_ = History(); + VertexNode() {} + + void InitRoot() { hypos_.clear(); } + + /* The steps of building a VertexNode: + * 1. Default construct. + * 2. AppendHypothesis at least once, possibly multiple times. + * 3. FinishAppending with the number of words on left and right guaranteed + * to be common. + * 4. If !Complete(), call BuildExtend to construct the extensions + */ + // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending. + void AppendHypothesis(const NBestComplete &best) { + assert(hypos_.empty() || !(hypos_.front().state == *best.state)); + HypoState hypo; + hypo.history = best.history; + hypo.state = *best.state; + hypo.score = best.score; + hypos_.push_back(hypo); + } + void AppendHypothesis(const HypoState &hypo) { + hypos_.push_back(hypo); } - lm::ngram::ChartState &MutableState() { return state_; } - bool &MutableRightFull() { return right_full_; } + // Sort hypotheses for the root. + void FinishRoot(); - void AddExtend(VertexNode *next) { - extend_.push_back(next); - } + void FinishedAppending(const unsigned char common_left, const unsigned char common_right); - void SetEnd(History end, Score score) { - assert(!end_); - end_ = end; - bound_ = score; - } - - void SortAndSet(ContextBase &context); + void BuildExtend(); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_ && extend_.empty(); + return hypos_.empty() && extend_.empty(); } bool Complete() const { - return end_; + // HACK: prevent root from being complete. TODO: allow root to be complete. + return hypos_.size() == 1 && extend_.empty(); } const lm::ngram::ChartState &State() const { return state_; } bool RightFull() const { return right_full_; } + // Priority relative to other non-terminals. 0 is highest. + unsigned char Niceness() const { return niceness_; } + Score Bound() const { return bound_; } - unsigned char Length() const { - return state_.left.length + state_.right.length; - } - // Will be invalid unless this is a leaf. - History End() const { return end_; } + History End() const { + assert(hypos_.size() == 1); + return hypos_.front().history; + } - const VertexNode &operator[](size_t index) const { - return *extend_[index]; + VertexNode &operator[](size_t index) { + assert(!extend_.empty()); + return extend_[index]; } size_t Size() const { @@ -76,22 +91,26 @@ class VertexNode { } private: - void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + // Hypotheses to be split. + std::vector<HypoState> hypos_; - std::vector<VertexNode*> extend_; + std::vector<VertexNode> extend_; lm::ngram::ChartState state_; bool right_full_; + unsigned char niceness_; + + unsigned char policy_; + Score bound_; - History end_; }; class PartialVertex { public: PartialVertex() {} - explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {} bool Empty() const { return back_->Empty(); } @@ -100,17 +119,14 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } - - unsigned char Length() const { return back_->Length(); } + Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); } - bool HasAlternative() const { - return index_ + 1 < back_->Size(); - } + unsigned char Niceness() const { return back_->Niceness(); } // Split into continuation and alternative, rendering this the continuation. bool Split(PartialVertex &alternative) { assert(!Complete()); + back_->BuildExtend(); bool ret; if (index_ + 1 < back_->Size()) { alternative.index_ = index_ + 1; @@ -129,7 +145,7 @@ class PartialVertex { } private: - const VertexNode *back_; + VertexNode *back_; unsigned int index_; }; @@ -139,10 +155,21 @@ class Vertex { public: Vertex() {} - PartialVertex RootPartial() const { return PartialVertex(root_); } + //PartialVertex RootFirst() const { return PartialVertex(right_); } + PartialVertex RootAlternate() { return PartialVertex(root_); } + //PartialVertex RootLast() const { return PartialVertex(left_); } + + bool Empty() const { + return root_.Empty(); + } + + Score Bound() const { + return root_.Bound(); + } - History BestChild() const { - PartialVertex top(RootPartial()); + History BestChild() { + // left_ and right_ are not set at the root. + PartialVertex top(RootAlternate()); if (top.Empty()) { return History(); } else { @@ -158,6 +185,12 @@ class Vertex { template <class Output> friend class VertexGenerator; template <class Output> friend class RootVertexGenerator; VertexNode root_; + + // These will not be set for the root vertex. + // Branches only on left state. + //VertexNode left_; + // Branches only on right state. + //VertexNode right_; }; } // namespace search |