summaryrefslogtreecommitdiff
path: root/extractor/grammar_extractor.cc
blob: 3014c2e96ea72ee9aabd0e8057c0b9df7a5463e1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include "grammar_extractor.h"

#include <iterator>
#include <sstream>
#include <vector>

using namespace std;

// Splits `sentence` on whitespace and returns the resulting tokens wrapped
// in sentence-boundary markers: "<s>" first, "</s>" last.
//
// Runs of whitespace are treated as a single separator and leading/trailing
// whitespace is ignored, so an empty or all-blank sentence yields exactly
// {"<s>", "</s>"}.
vector<string> Tokenize(const string& sentence) {
  vector<string> tokens;
  tokens.push_back("<s>");

  // Formatted extraction skips whitespace and stops at end of input.
  istringstream input(sentence);
  string token;
  while (input >> token) {
    tokens.push_back(token);
  }

  tokens.push_back("</s>");
  return tokens;
}

// Builds a grammar extractor.
//
// A fresh Vocabulary is allocated here and stored in `vocabulary`; every
// constructor argument is then forwarded, unchanged and together with that
// vocabulary, to the `rule_factory` member. The constructor body itself is
// empty.
//
// NOTE(review): `rule_factory` is initialized from `vocabulary`. C++
// initializes members in declaration order, not in the order written below,
// so `vocabulary` must be declared before `rule_factory` in
// grammar_extractor.h -- confirm against the header.
GrammarExtractor::GrammarExtractor(
    shared_ptr<SuffixArray> source_suffix_array,
    shared_ptr<DataArray> target_data_array,
    const Alignment& alignment, const Precomputation& precomputation,
    int min_gap_size, int max_rule_span, int max_nonterminals,
    int max_rule_symbols, bool use_baeza_yates) :
    // Vocabulary created here and also handed to the rule factory below.
    vocabulary(make_shared<Vocabulary>()),
    // All extraction parameters are passed straight through to the factory.
    rule_factory(source_suffix_array, target_data_array, alignment,
        vocabulary, precomputation, min_gap_size, max_rule_span,
        max_nonterminals, max_rule_symbols, use_baeza_yates) {}

// Runs grammar extraction for a single sentence: the sentence is tokenized
// (with boundary markers added by Tokenize), each token is converted to its
// vocabulary id, and the id sequence is handed to the rule factory.
void GrammarExtractor::GetGrammar(const string& sentence) {
  vector<int> sentence_ids = AnnotateWords(Tokenize(sentence));
  rule_factory.GetGrammar(sentence_ids);
}

// Converts a token sequence into the corresponding sequence of vocabulary
// ids.
//
// For each entry of `words`, in order, the id returned by
// vocabulary->GetTerminalIndex(word) is appended to the result, so the
// returned vector has exactly words.size() elements (empty in, empty out).
vector<int> GrammarExtractor::AnnotateWords(const vector<string>& words) {
  vector<int> result;
  // Final size is known up front; avoid repeated reallocation.
  result.reserve(words.size());
  // Bind by const reference: iterating by value copied every token string.
  for (const string& word : words) {
    result.push_back(vocabulary->GetTerminalIndex(word));
  }
  return result;
}