summaryrefslogtreecommitdiff
path: root/extractor/phrase_builder.cc
blob: 4325390c2a88bac52862ee985427452063d15dd0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include "phrase_builder.h"

#include <utility>

#include "phrase.h"
#include "vocabulary.h"

// Creates a builder that classifies symbols (terminal vs. non-terminal) and
// resolves terminal values / non-terminal indices through `vocabulary`.
// The builder shares ownership of the vocabulary. The parameter is a by-value
// sink, so it is moved into the member rather than copied; this avoids an
// extra atomic reference-count increment/decrement pair.
PhraseBuilder::PhraseBuilder(shared_ptr<Vocabulary> vocabulary) :
    vocabulary(std::move(vocabulary)) {}

// Builds a Phrase from a raw symbol sequence.
//   - phrase.symbols receives a copy of `symbols`.
//   - For each terminal symbol, its vocabulary value is appended to
//     phrase.words, in order.
//   - For each non-terminal symbol, its 0-based position in `symbols` is
//     appended to phrase.var_pos.
Phrase PhraseBuilder::Build(const vector<int>& symbols) {
  Phrase result;
  result.symbols = symbols;

  size_t position = 0;
  for (int symbol : symbols) {
    if (!vocabulary->IsTerminal(symbol)) {
      result.var_pos.push_back(position);
    } else {
      result.words.push_back(vocabulary->GetTerminalValue(symbol));
    }
    ++position;
  }

  return result;
}

// Returns a new Phrase derived from `phrase` whose non-terminals are
// renumbered consecutively starting at 1 (via GetNonterminalIndex).
//   start_x: if true, a new non-terminal (#1) is prepended before
//            renumbering the existing ones.
//   end_x:   if true, a new non-terminal with the next free number is
//            appended after renumbering.
// The final symbol sequence is passed through Build(), which recomputes
// words and var_pos.
Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) {
  vector<int> extended = phrase.Get();
  int next_index = 0;

  // The prepended non-terminal is already numbered, so the renumbering
  // loop below starts after it.
  size_t first = 0;
  if (start_x) {
    next_index = 1;
    extended.insert(extended.begin(),
        vocabulary->GetNonterminalIndex(next_index));
    first = 1;
  }

  // Renumber every existing non-terminal in left-to-right order.
  for (size_t pos = first; pos < extended.size(); ++pos) {
    if (vocabulary->IsTerminal(extended[pos])) {
      continue;
    }
    extended[pos] = vocabulary->GetNonterminalIndex(++next_index);
  }

  if (end_x) {
    extended.push_back(vocabulary->GetNonterminalIndex(++next_index));
  }

  return Build(extended);
}