summaryrefslogtreecommitdiff
path: root/rst_parser/dep_training.cc
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/dep_training.cc')
-rw-r--r--rst_parser/dep_training.cc70
1 files changed, 43 insertions, 27 deletions
diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc
index de431ebc..e26505ec 100644
--- a/rst_parser/dep_training.cc
+++ b/rst_parser/dep_training.cc
@@ -10,11 +10,51 @@
using namespace std;
-void TrainingInstance::ReadTraining(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) {
+static void ParseInstance(const string& line, int start, TrainingInstance* out, int lc = 0) {
+ picojson::value obj;
+ string err;
+ picojson::parse(obj, line.begin() + start, line.end(), &err);
+ if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); }
+ TrainingInstance& cur = *out;
+ TaggedSentence& ts = cur.ts;
+ EdgeSubset& tree = cur.tree;
+ assert(obj.is<picojson::object>());
+ const picojson::object& d = obj.get<picojson::object>();
+ const picojson::array& ta = d.find("tokens")->second.get<picojson::array>();
+ for (unsigned i = 0; i < ta.size(); ++i) {
+ ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>()));
+ ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>()));
+ }
+ if (d.find("deps") != d.end()) {
+ const picojson::array& da = d.find("deps")->second.get<picojson::array>();
+ for (unsigned i = 0; i < da.size(); ++i) {
+ const picojson::array& thm = da[i].get<picojson::array>();
+ // get dep type here
+ short h = thm[2].get<double>();
+ short m = thm[1].get<double>();
+ if (h < 0)
+ tree.roots.push_back(m);
+ else
+ tree.h_m_pairs.push_back(make_pair(h,m));
+ }
+ }
+ //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl;
+}
+
+bool TrainingInstance::ReadInstance(std::istream* in, TrainingInstance* instance) {
+ string line;
+ if (!getline(*in, line)) return false;
+ size_t pos = line.rfind('\t');
+ assert(pos != string::npos);
+ static int lc = 0; ++lc;
+ ParseInstance(line, pos + 1, instance, lc);
+ return true;
+}
+
+void TrainingInstance::ReadTrainingCorpus(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) {
ReadFile rf(fname);
istream& in = *rf.stream();
string line;
- string err;
int lc = 0;
bool flag = false;
while(getline(in, line)) {
@@ -24,32 +64,8 @@ void TrainingInstance::ReadTraining(const string& fname, vector<TrainingInstance
if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
size_t pos = line.rfind('\t');
assert(pos != string::npos);
- picojson::value obj;
- picojson::parse(obj, line.begin() + pos, line.end(), &err);
- if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); }
corpus->push_back(TrainingInstance());
- TrainingInstance& cur = corpus->back();
- TaggedSentence& ts = cur.ts;
- EdgeSubset& tree = cur.tree;
- assert(obj.is<picojson::object>());
- const picojson::object& d = obj.get<picojson::object>();
- const picojson::array& ta = d.find("tokens")->second.get<picojson::array>();
- for (unsigned i = 0; i < ta.size(); ++i) {
- ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>()));
- ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>()));
- }
- const picojson::array& da = d.find("deps")->second.get<picojson::array>();
- for (unsigned i = 0; i < da.size(); ++i) {
- const picojson::array& thm = da[i].get<picojson::array>();
- // get dep type here
- short h = thm[2].get<double>();
- short m = thm[1].get<double>();
- if (h < 0)
- tree.roots.push_back(m);
- else
- tree.h_m_pairs.push_back(make_pair(h,m));
- }
- //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl;
+ ParseInstance(line, pos + 1, &corpus->back(), lc);
}
if (flag) cerr << "\nRead " << lc << " training instances\n";
}