summaryrefslogtreecommitdiff
path: root/rst_parser/dep_training.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-04-16 18:20:33 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-04-16 18:20:33 -0400
commit4b38556c88c739de82b9c298261a262ec620280e (patch)
tree272124012fa38a359a7c0efb1da4499a1b4371b5 /rst_parser/dep_training.cc
parentf2fcf9e8aa0e5dee75fd08ee915488ec1a741975 (diff)
rst sampler
Diffstat (limited to 'rst_parser/dep_training.cc')
-rw-r--r--rst_parser/dep_training.cc56
1 files changed, 56 insertions, 0 deletions
diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc
new file mode 100644
index 00000000..de431ebc
--- /dev/null
+++ b/rst_parser/dep_training.cc
@@ -0,0 +1,56 @@
+#include "dep_training.h"
+
+#include <vector>
+#include <iostream>
+
+#include "stringlib.h"
+#include "filelib.h"
+#include "tdict.h"
+#include "picojson.h"
+
+using namespace std;
+
+void TrainingInstance::ReadTraining(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) {
+ ReadFile rf(fname);
+ istream& in = *rf.stream();
+ string line;
+ string err;
+ int lc = 0;
+ bool flag = false;
+ while(getline(in, line)) {
+ ++lc;
+ if ((lc-1) % size != rank) continue;
+ if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; }
+ if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
+ size_t pos = line.rfind('\t');
+ assert(pos != string::npos);
+ picojson::value obj;
+ picojson::parse(obj, line.begin() + pos, line.end(), &err);
+ if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); }
+ corpus->push_back(TrainingInstance());
+ TrainingInstance& cur = corpus->back();
+ TaggedSentence& ts = cur.ts;
+ EdgeSubset& tree = cur.tree;
+ assert(obj.is<picojson::object>());
+ const picojson::object& d = obj.get<picojson::object>();
+ const picojson::array& ta = d.find("tokens")->second.get<picojson::array>();
+ for (unsigned i = 0; i < ta.size(); ++i) {
+ ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>()));
+ ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>()));
+ }
+ const picojson::array& da = d.find("deps")->second.get<picojson::array>();
+ for (unsigned i = 0; i < da.size(); ++i) {
+ const picojson::array& thm = da[i].get<picojson::array>();
+ // get dep type here
+ short h = thm[2].get<double>();
+ short m = thm[1].get<double>();
+ if (h < 0)
+ tree.roots.push_back(m);
+ else
+ tree.h_m_pairs.push_back(make_pair(h,m));
+ }
+ //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl;
+ }
+ if (flag) cerr << "\nRead " << lc << " training instances\n";
+}
+