summaryrefslogtreecommitdiff
path: root/klm/search/vertex.hh
blob: 81c3cfed6a838a849f1d0d2365f709784314111b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#ifndef SEARCH_VERTEX__
#define SEARCH_VERTEX__

#include "lm/left.hh"
#include "search/types.hh"

#include <boost/unordered_set.hpp>

#include <queue>
#include <vector>

#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

namespace search {

// Forward declaration only; ContextBase is not referenced elsewhere in this
// header.
class ContextBase;

// One hypothesis stored by a VertexNode: its History, its
// lm::ngram::ChartState and its Score.  VertexNode::AppendHypothesis builds
// one of these by copying the corresponding fields of an NBestComplete.
struct HypoState {
  // Opaque history handle (type from search/types.hh); returned by
  // VertexNode::End() for a leaf.
  History history;
  // Language-model chart state; VertexNode compares these with operator==.
  lm::ngram::ChartState state;
  Score score;
};

class VertexNode {
  public:
    VertexNode() {}

    void InitRoot() { hypos_.clear(); }

    /* The steps of building a VertexNode:
     * 1. Default construct.
     * 2. AppendHypothesis at least once, possibly multiple times.
     * 3. FinishAppending with the number of words on left and right guaranteed
     * to be common.
     * 4. If !Complete(), call BuildExtend to construct the extensions
     */
    // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending.
    void AppendHypothesis(const NBestComplete &best) {
      assert(hypos_.empty() || !(hypos_.front().state == *best.state));
      HypoState hypo;
      hypo.history = best.history;
      hypo.state = *best.state;
      hypo.score = best.score;
      hypos_.push_back(hypo);
    }
    void AppendHypothesis(const HypoState &hypo) {
      hypos_.push_back(hypo);
    }

    // Sort hypotheses for the root.
    void FinishRoot();

    void FinishedAppending(const unsigned char common_left, const unsigned char common_right);

    void BuildExtend();

    // Should only happen to a root node when the entire vertex is empty.   
    bool Empty() const {
      return hypos_.empty() && extend_.empty();
    }

    bool Complete() const {
      // HACK: prevent root from being complete.  TODO: allow root to be complete.
      return hypos_.size() == 1 && extend_.empty();
    }

    const lm::ngram::ChartState &State() const { return state_; }
    bool RightFull() const { return right_full_; }

    // Priority relative to other non-terminals.  0 is highest.
    unsigned char Niceness() const { return niceness_; }

    Score Bound() const {
      return bound_;
    }

    // Will be invalid unless this is a leaf.   
    History End() const {
      assert(hypos_.size() == 1);
      return hypos_.front().history;
    }

    VertexNode &operator[](size_t index) {
      assert(!extend_.empty());
      return extend_[index];
    }

    size_t Size() const {
      return extend_.size();
    }

  private:
    // Hypotheses to be split.
    std::vector<HypoState> hypos_;

    std::vector<VertexNode> extend_;

    lm::ngram::ChartState state_;
    bool right_full_;

    unsigned char niceness_;

    unsigned char policy_;

    Score bound_;
};

class PartialVertex {
  public:
    PartialVertex() {}

    explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {}

    bool Empty() const { return back_->Empty(); }

    bool Complete() const { return back_->Complete(); }

    const lm::ngram::ChartState &State() const { return back_->State(); }
    bool RightFull() const { return back_->RightFull(); }

    Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); }

    unsigned char Niceness() const { return back_->Niceness(); }

    // Split into continuation and alternative, rendering this the continuation.
    bool Split(PartialVertex &alternative) {
      assert(!Complete());
      back_->BuildExtend();
      bool ret;
      if (index_ + 1 < back_->Size()) {
        alternative.index_ = index_ + 1;
        alternative.back_ = back_;
        ret = true;
      } else {
        ret = false;
      }
      back_ = &((*back_)[index_]);
      index_ = 0;
      return ret;
    }

    History End() const {
      return back_->End();
    }

  private:
    VertexNode *back_;
    unsigned int index_;
};

template <class Output> class VertexGenerator;

// A search vertex: owns the root VertexNode of its hypothesis tree and hands
// out PartialVertex views of it.  VertexGenerator and RootVertexGenerator are
// friends so they can access root_ directly.
class Vertex {
  public:
    Vertex() {}

    //PartialVertex RootFirst() const { return PartialVertex(right_); }
    // View of the entire root node.
    PartialVertex RootAlternate() { return PartialVertex(root_); }
    //PartialVertex RootLast() const { return PartialVertex(left_); }

    // True when the root holds neither hypotheses nor extensions.
    bool Empty() const { return root_.Empty(); }

    // Bound stored on the root node.
    Score Bound() const { return root_.Bound(); }

    // Starting at the root, repeatedly splits and descends into extension 0
    // until a complete node is reached, then returns that node's History.
    // An empty vertex yields a default-constructed History.
    History BestChild() {
      PartialVertex node(RootAlternate());
      if (node.Empty()) return History();

      // Split() wants somewhere to put the remaining siblings; they are not
      // needed here, so one scratch object is reused at every level.
      PartialVertex unused_siblings;
      while (!node.Complete()) node.Split(unused_siblings);
      return node.End();
    }

  private:
    template <class Output> friend class VertexGenerator;
    template <class Output> friend class RootVertexGenerator;

    VertexNode root_;

    // Currently disabled; these will not be set for the root vertex.
    // Branches only on left state.
    //VertexNode left_;
    // Branches only on right state.
    //VertexNode right_;
};

} // namespace search
#endif // SEARCH_VERTEX__