summaryrefslogtreecommitdiff
path: root/rst_parser/dep_training.cc
blob: ef97798b675d78649c419a57da4ac247e9616e1e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include "dep_training.h"

#include <vector>
#include <iostream>

#include "stringlib.h"
#include "filelib.h"
#include "tdict.h"
#include "picojson.h"

using namespace std;

static void ParseInstance(const string& line, int start, TrainingInstance* out, int lc = 0) {
  picojson::value obj;
  string err;
  picojson::parse(obj, line.begin() + start, line.end(), &err);
  if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); }
  TrainingInstance& cur = *out;
  TaggedSentence& ts = cur.ts;
  EdgeSubset& tree = cur.tree;
  ts.pos.clear();
  ts.words.clear();
  tree.roots.clear();
  tree.h_m_pairs.clear();
  assert(obj.is<picojson::object>());
  const picojson::object& d = obj.get<picojson::object>();
  const picojson::array& ta = d.find("tokens")->second.get<picojson::array>();
  for (unsigned i = 0; i < ta.size(); ++i) {
    ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>()));
    ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>()));
  }
  if (d.find("deps") != d.end()) {
    const picojson::array& da = d.find("deps")->second.get<picojson::array>();
    for (unsigned i = 0; i < da.size(); ++i) {
      const picojson::array& thm = da[i].get<picojson::array>();
      // get dep type here
      short h = thm[2].get<double>();
      short m = thm[1].get<double>();
      if (h < 0)
        tree.roots.push_back(m);
      else
        tree.h_m_pairs.push_back(make_pair(h,m));
    }
  }
  //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl;
}

bool TrainingInstance::ReadInstance(std::istream* in, TrainingInstance* instance) {
  string line;
  if (!getline(*in, line)) return false;
  size_t pos = line.rfind('\t');
  assert(pos != string::npos);
  static int lc = 0; ++lc;
  ParseInstance(line, pos + 1, instance, lc);
  return true;
}

void TrainingInstance::ReadTrainingCorpus(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) {
  ReadFile rf(fname);
  istream& in = *rf.stream();
  string line;
  int lc = 0;
  bool flag = false;
  while(getline(in, line)) {
    ++lc;
    if ((lc-1) % size != rank) continue;
    if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; }
    if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
    size_t pos = line.rfind('\t');
    assert(pos != string::npos);
    corpus->push_back(TrainingInstance());
    ParseInstance(line, pos + 1, &corpus->back(), lc);
  }
  if (flag) cerr << "\nRead " << lc << " training instances\n";
}