summaryrefslogtreecommitdiff
path: root/klm/search/vertex.hh
blob: e1a9ad1138151a4174f80b24d176e141398f3ce0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#ifndef SEARCH_VERTEX__
#define SEARCH_VERTEX__

#include "lm/left.hh"
#include "search/final.hh"
#include "search/types.hh"

#include <boost/unordered_set.hpp>

#include <queue>
#include <vector>

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

namespace search {

class ContextBase;

class VertexNode {
  public:
    VertexNode() : end_(NULL) {}

    void InitRoot() {
      extend_.clear();
      state_.left.full = false;
      state_.left.length = 0;
      state_.right.length = 0;
      right_full_ = false;
      bound_ = -kScoreInf;
      end_ = NULL;
    }

    lm::ngram::ChartState &MutableState() { return state_; }
    bool &MutableRightFull() { return right_full_; }

    void AddExtend(VertexNode *next) {
      extend_.push_back(next);
    }

    void SetEnd(Final *end) { end_ = end; }
    
    Final &MutableEnd() { return *end_; }

    void SortAndSet(ContextBase &context, VertexNode **parent_pointer);

    // Should only happen to a root node when the entire vertex is empty.   
    bool Empty() const {
      return !end_ && extend_.empty();
    }

    bool Complete() const {
      return end_;
    }

    const lm::ngram::ChartState &State() const { return state_; }
    bool RightFull() const { return right_full_; }

    Score Bound() const {
      return bound_;
    }

    unsigned char Length() const {
      return state_.left.length + state_.right.length;
    }

    // May be NULL.
    const Final *End() const { return end_; }

    const VertexNode &operator[](size_t index) const {
      return *extend_[index];
    }

    size_t Size() const {
      return extend_.size();
    }

  private:
    std::vector<VertexNode*> extend_;

    lm::ngram::ChartState state_;
    bool right_full_;

    Score bound_;
    Final *end_;
};

// A cursor into a VertexNode: back_ is the node being examined and index_
// selects one of its children.  Split() hands the selected child to a new
// cursor and, when another child exists, advances this cursor to it.
class PartialVertex {
  public:
    // A default-constructed cursor points at nothing.  It must be assigned,
    // or filled in by Split(), before any other member is called.  The fields
    // are nulled so misuse fails deterministically instead of reading garbage.
    PartialVertex() : back_(NULL), index_(0) {}

    // Cursor on `back`, positioned at its first child.  The node must outlive
    // the cursor, since only its address is kept.
    explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}

    bool Empty() const { return back_->Empty(); }

    bool Complete() const { return back_->Complete(); }

    const lm::ngram::ChartState &State() const { return back_->State(); }
    bool RightFull() const { return back_->RightFull(); }

    // The Final's bound when complete, otherwise the bound of the currently
    // selected child.
    Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); }

    unsigned char Length() const { return back_->Length(); }

    // True when a child after the currently selected one exists.
    bool HasAlternative() const {
      return index_ + 1 < back_->Size();
    }

    // Split into continuation and alternative, rendering this the alternative.
    // `continuation` becomes a cursor on the currently selected child.  Returns
    // true if this cursor advanced to the next child; false if there was none
    // (this cursor is then unchanged).
    bool Split(PartialVertex &continuation) {
      assert(!Complete());
      continuation.back_ = &((*back_)[index_]);
      continuation.index_ = 0;
      if (HasAlternative()) {
        ++index_;
        return true;
      }
      return false;
    }

    // Only valid when Complete().
    const Final &End() const {
      assert(Complete());
      return *back_->End();
    }

  private:
    // Non-owning; NULL only for a default-constructed cursor.
    const VertexNode *back_;
    unsigned int index_;
};

extern PartialVertex kBlankPartialVertex;

// Owns the root VertexNode of one vertex.  VertexGenerator is a friend so it
// can build the structure under root_ directly.
class Vertex {
  public:
    Vertex() {}

    // Cursor on the root node, positioned at its first child.
    PartialVertex RootPartial() const {
      return PartialVertex(root_);
    }

    // Descends from the root through the first child at every level until a
    // node carrying a Final is reached, and returns that Final.  Returns NULL
    // when the root is empty.  NOTE(review): "best" presumably relies on the
    // children having been ordered by SortAndSet -- confirm.
    const Final *BestChild() const {
      PartialVertex cursor(RootPartial());
      if (cursor.Empty()) return NULL;

      while (!cursor.Complete()) {
        PartialVertex child;
        cursor.Split(child);
        cursor = child;
      }
      return &cursor.End();
    }

  private:
    friend class VertexGenerator;

    VertexNode root_;
};

} // namespace search
#endif // SEARCH_VERTEX__