// Definitions of VertexNode::FinishRoot, FinishedAppending and BuildExtend, plus the file-local helpers they use.
#include "search/vertex.hh"
#include "search/context.hh"
#include <boost/unordered_map.hpp>
#include <algorithm>
#include <functional>
#include <assert.h>
namespace search {
namespace {
// All-ones 64-bit value.  DivideLeft and DivideRight subtract state.left.full
// from it to build the key for a state that has nothing at the requested index.
const uint64_t kCompleteAdd = ~static_cast<uint64_t>(0);
// Functor mapping a chart state to a 64-bit key taken from its left side at a
// fixed index.  Split() groups hypotheses by the key this returns.
class DivideLeft {
  public:
    // index: position within state.left.pointers to read the key from.
    explicit DivideLeft(unsigned char index) : index_(index) {}

    // Returns the left pointer at index_ when the left state is long enough;
    // otherwise returns kCompleteAdd minus the state's left.full flag.
    uint64_t operator()(const lm::ngram::ChartState &state) const {
      if (index_ >= state.left.length) {
        return kCompleteAdd - state.left.full;
      }
      return state.left.pointers[index_];
    }

  private:
    unsigned char index_;
};
// Functor mapping a chart state to a 64-bit key taken from its right side at a
// fixed index.  Split() groups hypotheses by the key this returns.
class DivideRight {
  public:
    // index: position within state.right.words to read the key from.
    explicit DivideRight(unsigned char index) : index_(index) {}

    // Returns the right word at index_ (widened to 64 bits) when the right
    // state is long enough.  Otherwise returns kCompleteAdd minus the state's
    // left.full flag; note that this fallback reads the left side's flag.
    uint64_t operator()(const lm::ngram::ChartState &state) const {
      if (index_ >= state.right.length) {
        return kCompleteAdd - state.left.full;
      }
      return static_cast<uint64_t>(state.right.words[index_]);
    }

  private:
    unsigned char index_;
};
// Partition hypos into child nodes appended to extend, one child per distinct
// key returned by divider.
//
// divider: callable taking a hypothesis' state and returning a uint64_t key.
// hypos:   hypotheses to distribute; they are visited in order, so each child
//          receives its hypotheses in the same relative order as in hypos.
// extend:  output vector.  A new VertexNode is added at the back the first time
//          a key is seen; later hypotheses with that key go to the same node.
template <class Divider> void Split(const Divider &divider, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) {
  // Map from divider key to index in extend.
  typedef boost::unordered_map<uint64_t, std::size_t> Lookup;
  Lookup lookup;
  for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) {
    const uint64_t key = divider(i->state);
    // Tentatively map the key to the slot a new child would occupy.  If the
    // key already exists, insert() leaves the stored index untouched.
    std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size())));
    if (res.second) {
      // First hypothesis with this key: create the child at that slot.
      extend.resize(extend.size() + 1);
    }
    extend[res.first->second].AppendHypothesis(*i);
  }
  //assert((extend.size() != 1) || (hypos.size() == 1));
}
// Returns the element of a right state that DetermineSame compares at
// position index: the word stored there.
lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) {
  const lm::WordIndex word = right.words[index];
  return word;
}
// Returns the element of a left state that DetermineSame compares at
// position index: the pointer stored there.
uint64_t Identify(const lm::ngram::Left &left, unsigned char index) {
  const uint64_t pointer = left.pointers[index];
  return pointer;
}
// Tracks how long a prefix of one side (Left or Right) of a reference state is
// shared by every other state passed to Consider().
//
// Side must expose a `length` member and have an Identify(side, index)
// overload returning the comparable element at that index.
template <class Side> class DetermineSame {
  public:
    // side:       reference state; held by reference, so it must outlive this
    //             object.
    // guaranteed: number of leading positions the caller already knows to be
    //             identical across all states; they are not compared again.
    DetermineSame(const Side &side, unsigned char guaranteed)
      : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}

    // Fold another state into the running result, shrinking the shared length
    // to the common prefix and clearing the complete flag on any difference.
    void Consider(const Side &other) {
      if (other.length != shared_) {
        complete_ = false;
        if (other.length < shared_) {
          shared_ = other.length;
        }
      }
      // Scan for the first position, past the guaranteed prefix, that differs.
      unsigned char position = guaranteed_;
      while (position < shared_ && Identify(side_, position) == Identify(other, position)) {
        ++position;
      }
      if (position < shared_) {
        shared_ = position;
        complete_ = false;
      }
    }

    // Length of the prefix shared by everything considered so far.
    unsigned char Shared() const { return shared_; }

    // True while every considered state matched the running shared length and
    // agreed on all compared positions.
    bool Complete() const { return complete_; }

  private:
    const Side &side_;
    unsigned char guaranteed_, shared_;
    bool complete_;
};
// Custom enum to save memory: these constants are the valid values of
// policy_, which VertexNode::BuildExtend uses to decide whether to split a
// node's hypotheses on left or right state.
// Alternate and there is still alternation to do.  BuildExtend splits on the
// left when the shared left length is <= the shared right length, otherwise on
// the right.
const unsigned char kPolicyAlternate = 0;
// Branch based on left state only, because right ran out or this is a left tree.
const unsigned char kPolicyOneLeft = 1;
// Branch based on right state only.
const unsigned char kPolicyOneRight = 2;
// Reveal everything in the next branch.  Used to terminate the left/right policies.
// (Currently disabled; kept for reference.)
// static const unsigned char kPolicyEverything = 3;
} // namespace
namespace {
struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> {
bool operator()(const HypoState &first, const HypoState &second) const {
return first.score > second.score;
}
};
} // namespace
void VertexNode::FinishRoot() {
std::sort(hypos_.begin(), hypos_.end(), GreaterByScore());
extend_.clear();
// HACK: extend to one hypo so that root can be blank.
state_.left.full = false;
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
niceness_ = 0;
policy_ = kPolicyAlternate;
if (hypos_.size() == 1) {
extend_.resize(1);
extend_.front().AppendHypothesis(hypos_.front());
extend_.front().FinishedAppending(0, 0);
}
if (hypos_.empty()) {
bound_ = -INFINITY;
} else {
bound_ = hypos_.front().score;
}
}
void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) {
assert(!hypos_.empty());
assert(extend_.empty());
bound_ = hypos_.front().score;
state_ = hypos_.front().state;
bool all_full = state_.left.full;
bool all_non_full = !state_.left.full;
DetermineSame<lm::ngram::Left> left(state_.left, common_left);
DetermineSame<lm::ngram::Right> right(state_.right, common_right);
for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) {
all_full &= i->state.left.full;
all_non_full &= !i->state.left.full;
left.Consider(i->state.left);
right.Consider(i->state.right);
}
state_.left.full = all_full && left.Complete();
right_full_ = all_full && right.Complete();
state_.left.length = left.Shared();
state_.right.length = right.Shared();
if (!all_full && !all_non_full) {
policy_ = kPolicyAlternate;
} else if (left.Complete()) {
policy_ = kPolicyOneRight;
} else if (right.Complete()) {
policy_ = kPolicyOneLeft;
} else {
policy_ = kPolicyAlternate;
}
niceness_ = state_.left.length + state_.right.length;
}
// Lazily create this node's children by splitting its hypotheses on the next
// unrevealed position of either the left or the right state, as chosen by
// policy_.  Does nothing if children already exist or the node is a leaf.
void VertexNode::BuildExtend() {
  // Already built, or nothing to build since this is a leaf.
  if (!extend_.empty() || hypos_.size() <= 1) return;

  // Default to the left side; any policy value not handled below keeps it.
  bool split_on_left = true;
  if (policy_ == kPolicyAlternate) {
    split_on_left = (state_.left.length <= state_.right.length);
  } else if (policy_ == kPolicyOneRight) {
    split_on_left = false;
  }

  if (split_on_left) {
    Split(DivideLeft(state_.left.length), hypos_, extend_);
  } else {
    Split(DivideRight(state_.right.length), hypos_, extend_);
  }

  // Every child shares at least this node's common prefix lengths.
  for (std::size_t child = 0; child < extend_.size(); ++child) {
    // TODO: provide more here for branching?
    extend_[child].FinishedAppending(state_.left.length, state_.right.length);
  }
}
} // namespace search