summaryrefslogtreecommitdiff
path: root/utils/corpus_tools.cc
blob: 191153a2b5e4de36da41a39b2f412d48c9109b2c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include "corpus_tools.h"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>

#include "tdict.h"
#include "filelib.h"
#include "verbose.h"

using namespace std;

void CorpusTools::ReadLine(const string& line,
                           vector<WordID>* src,
                           vector<WordID>* trg) {
  static const WordID kDIV = TD::Convert("|||");
  static vector<WordID> tmp;
  src->clear();
  trg->clear();
  TD::ConvertSentence(line, &tmp);
  unsigned i = 0;
  while(i < tmp.size() && tmp[i] != kDIV) {
    src->push_back(tmp[i]);
    ++i;
  }
  if (i < tmp.size() && tmp[i] == kDIV) {
    ++i;
    for (; i < tmp.size() ; ++i)
      trg->push_back(tmp[i]);
  }
}

// Reads a corpus file with one sentence (or "source ||| target" pair) per line.
//
// Lines are sharded round-robin among `size` workers: only lines whose 0-based
// index i satisfies i % size == rank are appended to *src / *trg, while the
// vocabulary sets are filled from the words of every line in the file.
//
// Parameters:
//   filename  - file to read (opened through ReadFile).
//   src       - if non-NULL, cleared and filled with the source side of each
//               line assigned to this rank.
//   src_vocab - if non-NULL, cleared and filled with every source-side word.
//   trg       - if non-NULL, every line must contain exactly one "|||"
//               divider; cleared and filled with the target side of each line
//               assigned to this rank. If NULL, no line may contain "|||".
//   trg_vocab - if non-NULL, cleared and filled with every target-side word.
//   rank,size - shard selector; requires 0 <= rank < size.
//
// Aborts (after printing the offending line to stderr) if a line contains
// more than one divider or its field count does not match whether trg is set.
void CorpusTools::ReadFromFile(const string& filename,
                           vector<vector<WordID> >* src,
                           set<WordID>* src_vocab,
                           vector<vector<WordID> >* trg,
                           set<WordID>* trg_vocab,
                           int rank,
                           int size) {
  assert(rank >= 0);
  assert(size > 0);
  assert(rank < size);
  if (src) src->clear();
  if (src_vocab) src_vocab->clear();
  if (trg) trg->clear();
  if (trg_vocab) trg_vocab->clear();
  const int expected_fields = 1 + (trg == NULL ? 0 : 1);
  if (!SILENT) cerr << "Reading from " << filename << " ...\n";
  ReadFile rf(filename);
  istream& in = *rf.stream();
  string line;
  int lc = 0;
  static const WordID kDIV = TD::Convert("|||");
  vector<WordID> tmp;
  while(getline(in, line)) {
    const bool skip = (lc % size != rank);
    ++lc;  // lc is now the 1-based number of the current line (for messages).
    TD::ConvertSentence(line, &tmp);
    // d: destination for words of the current field (NULL when this line is
    // not kept by this rank or the caller did not request that side);
    // v: vocabulary set for the current field (NULL if not requested).
    vector<WordID>* d = NULL;
    if (!skip && src) {  // src may legitimately be NULL (vocab-only reads).
      src->push_back(vector<WordID>());
      d = &src->back();
    }
    set<WordID>* v = src_vocab;
    int s = 0;
    for (unsigned i = 0; i < tmp.size(); ++i) {
      if (tmp[i] == kDIV) {
        ++s;
        if (s > 1) { cerr << "Unexpected format in line " << lc << ": " << line << endl; abort(); }
        // A divider when no target side was requested is a field-count error.
        // Report it explicitly instead of relying on assert(), which compiles
        // away under NDEBUG and would leave a NULL dereference below.
        if (!trg) {
          cerr << "Wrong number of fields in line " << lc << ": " << line << endl; abort();
        }
        if (!skip) {
          trg->push_back(vector<WordID>());
          d = &trg->back();
        }
        v = trg_vocab;
      } else {
        if (d) d->push_back(tmp[i]);
        if (v) v->insert(tmp[i]);
      }
    }
    ++s;  // number of fields = number of dividers + 1
    if (expected_fields != s) {
      cerr << "Wrong number of fields in line " << lc << ": " << line << endl; abort();
    }
  }
}