blob: 4325390c2a88bac52862ee985427452063d15dd0 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
|
#include "phrase_builder.h"

#include <utility>

#include "phrase.h"
#include "vocabulary.h"
// Creates a builder that shares ownership of `vocabulary`. Build() uses it to
// tell terminals from nonterminals and to look up terminal words.
//
// The shared_ptr is taken by value and moved into the member. Callers that
// pass an rvalue therefore incur no reference-count traffic, and callers that
// pass an lvalue pay for exactly one copy instead of two.
PhraseBuilder::PhraseBuilder(shared_ptr<Vocabulary> vocabulary) :
    vocabulary(std::move(vocabulary)) {}
// Constructs a Phrase from a raw symbol sequence.
//
// The full sequence is stored in `symbols`. Each symbol is then classified
// through the vocabulary:
//   - terminals contribute their value (GetTerminalValue) to `words`, in order;
//   - every other symbol has its position in the sequence recorded in
//     `var_pos`.
Phrase PhraseBuilder::Build(const vector<int>& symbols) {
  Phrase result;
  result.symbols = symbols;

  size_t position = 0;
  for (int symbol : symbols) {
    if (!vocabulary->IsTerminal(symbol)) {
      result.var_pos.push_back(position);
    } else {
      result.words.push_back(vocabulary->GetTerminalValue(symbol));
    }
    ++position;
  }

  return result;
}
// Returns a new Phrase built from `phrase`'s symbols. Optionally a nonterminal
// is prepended (`start_x`) and/or appended (`end_x`).
//
// Every nonterminal in the result is relabelled from a running counter that
// goes 1, 2, 3, ... from left to right:
//   - the prepended one, if any, gets 1;
//   - the ones already in `phrase` follow;
//   - the appended one, if any, gets the last number.
// Each label is converted to a symbol with
// vocabulary->GetNonterminalIndex(count). Presumably that yields the symbol id
// of the count-th nonterminal (TODO confirm against Vocabulary).
Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) {
  vector<int> extended = phrase.Get();
  int nonterminal_count = 0;

  // Position of the first symbol that came from `phrase` (shifted by one when
  // a leading nonterminal is inserted, which is already labelled).
  size_t first_original = 0;
  if (start_x) {
    nonterminal_count = 1;
    extended.insert(extended.begin(),
                    vocabulary->GetNonterminalIndex(nonterminal_count));
    first_original = 1;
  }

  // Relabel the pre-existing nonterminals so their numbering continues after
  // any leading one.
  for (size_t pos = first_original; pos < extended.size(); ++pos) {
    int& symbol = extended[pos];
    if (vocabulary->IsTerminal(symbol)) {
      continue;
    }
    ++nonterminal_count;
    symbol = vocabulary->GetNonterminalIndex(nonterminal_count);
  }

  if (end_x) {
    ++nonterminal_count;
    extended.push_back(vocabulary->GetNonterminalIndex(nonterminal_count));
  }

  return Build(extended);
}
|