summaryrefslogtreecommitdiff
path: root/klm/search/vertex.hh
blob: 10b3339b9992910488811eea5b8a3ff2ad55501b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#ifndef SEARCH_VERTEX__
#define SEARCH_VERTEX__

#include "lm/left.hh"
#include "search/types.hh"

#include <boost/unordered_set.hpp>

#include <queue>
#include <vector>

#include <assert.h>
#include <math.h>
#include <stdint.h>

namespace search {

class ContextBase;

class VertexNode {
  public:
    VertexNode() : end_() {}

    void InitRoot() {
      extend_.clear();
      state_.left.full = false;
      state_.left.length = 0;
      state_.right.length = 0;
      right_full_ = false;
      end_ = History();
    }

    lm::ngram::ChartState &MutableState() { return state_; }
    bool &MutableRightFull() { return right_full_; }

    void AddExtend(VertexNode *next) {
      extend_.push_back(next);
    }

    void SetEnd(History end, Score score) {
      assert(!end_);
      end_ = end;
      bound_ = score;
    }
    
    void SortAndSet(ContextBase &context);

    // Should only happen to a root node when the entire vertex is empty.   
    bool Empty() const {
      return !end_ && extend_.empty();
    }

    bool Complete() const {
      return end_;
    }

    const lm::ngram::ChartState &State() const { return state_; }
    bool RightFull() const { return right_full_; }

    Score Bound() const {
      return bound_;
    }

    unsigned char Length() const {
      return state_.left.length + state_.right.length;
    }

    // Will be invalid unless this is a leaf.   
    const History End() const { return end_; }

    const VertexNode &operator[](size_t index) const {
      return *extend_[index];
    }

    size_t Size() const {
      return extend_.size();
    }

  private:
    void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);

    std::vector<VertexNode*> extend_;

    lm::ngram::ChartState state_;
    bool right_full_;

    Score bound_;
    History end_;
};

// A cursor into a VertexNode tree: back_ is the current node and index_
// selects which of its children is currently under consideration.
class PartialVertex {
  public:
    // A default-constructed cursor is null.  It must be assigned (for example
    // as the |alternative| target of Split) before any accessor is called.
    PartialVertex() : back_(0), index_(0) {}

    // Cursor on |back| with its first child selected.
    explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}

    bool Empty() const { return back_->Empty(); }

    bool Complete() const { return back_->Complete(); }

    const lm::ngram::ChartState &State() const { return back_->State(); }
    bool RightFull() const { return back_->RightFull(); }

    // Bound of the node itself once complete, otherwise of the selected child.
    Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }

    unsigned char Length() const { return back_->Length(); }

    // True when a sibling follows the selected child.
    bool HasAlternative() const {
      return index_ + 1 < back_->Size();
    }

    // Split into continuation and alternative, rendering this the continuation.
    // If a later sibling exists, |alternative| is pointed at it and true is
    // returned; otherwise |alternative| is left untouched and false is
    // returned.  Either way this cursor descends into the selected child,
    // selecting that child's first child.
    bool Split(PartialVertex &alternative) {
      assert(!Complete());
      // An Empty() node is not Complete() yet has no children; descending
      // below would index past the end.
      assert(index_ < back_->Size());
      const bool has_alternative = HasAlternative();
      if (has_alternative) {
        alternative.index_ = index_ + 1;
        alternative.back_ = back_;
      }
      back_ = &((*back_)[index_]);
      index_ = 0;
      return has_alternative;
    }

    // Only meaningful when Complete(); otherwise the underlying node's end is
    // invalid.
    const History End() const {
      return back_->End();
    }

  private:
    const VertexNode *back_;
    unsigned int index_;
};

template <class Output> class VertexGenerator;

// Holds the root of a VertexNode tree by value and exposes read-only traversal.
class Vertex {
  public:
    Vertex() {}

    // Cursor positioned at the root with its first child selected.
    PartialVertex RootPartial() const { return PartialVertex(root_); }

    // Descends from the root, always taking child 0, until a Complete() node is
    // reached, and returns that node's End().  An empty vertex yields a
    // default-constructed History.
    // NOTE(review): child 0 is the best one only if children are kept in
    // best-first order (presumably by VertexNode::SortAndSet) -- confirm.
    const History BestChild() const {
      PartialVertex cursor = RootPartial();
      if (cursor.Empty()) return History();

      // Split() requires a target for the sibling; it is never read here.
      PartialVertex unused_sibling;
      for (; !cursor.Complete(); cursor.Split(unused_sibling)) {}
      return cursor.End();
    }

  private:
    template <class Output> friend class VertexGenerator;
    template <class Output> friend class RootVertexGenerator;
    VertexNode root_;
};

} // namespace search
#endif // SEARCH_VERTEX__