summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rst_parser/dep_training.cc70
-rw-r--r--rst_parser/dep_training.h4
-rw-r--r--rst_parser/mst_train.cc2
-rw-r--r--rst_parser/rst_train.cc2
4 files changed, 48 insertions, 30 deletions
diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc
index de431ebc..e26505ec 100644
--- a/rst_parser/dep_training.cc
+++ b/rst_parser/dep_training.cc
@@ -10,11 +10,51 @@
using namespace std;
-void TrainingInstance::ReadTraining(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) {
+static void ParseInstance(const string& line, int start, TrainingInstance* out, int lc = 0) {
+ picojson::value obj;
+ string err;
+ picojson::parse(obj, line.begin() + start, line.end(), &err);
+ if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); }
+ TrainingInstance& cur = *out;
+ TaggedSentence& ts = cur.ts;
+ EdgeSubset& tree = cur.tree;
+ assert(obj.is<picojson::object>());
+ const picojson::object& d = obj.get<picojson::object>();
+ const picojson::array& ta = d.find("tokens")->second.get<picojson::array>();
+ for (unsigned i = 0; i < ta.size(); ++i) {
+ ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>()));
+ ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>()));
+ }
+ if (d.find("deps") != d.end()) {
+ const picojson::array& da = d.find("deps")->second.get<picojson::array>();
+ for (unsigned i = 0; i < da.size(); ++i) {
+ const picojson::array& thm = da[i].get<picojson::array>();
+ // get dep type here
+ short h = thm[2].get<double>();
+ short m = thm[1].get<double>();
+ if (h < 0)
+ tree.roots.push_back(m);
+ else
+ tree.h_m_pairs.push_back(make_pair(h,m));
+ }
+ }
+ //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl;
+}
+
+bool TrainingInstance::ReadInstance(std::istream* in, TrainingInstance* instance) {
+ string line;
+ if (!getline(*in, line)) return false;
+ size_t pos = line.rfind('\t');
+ assert(pos != string::npos);
+ static int lc = 0; ++lc;
+ ParseInstance(line, pos + 1, instance, lc);
+ return true;
+}
+
+void TrainingInstance::ReadTrainingCorpus(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) {
ReadFile rf(fname);
istream& in = *rf.stream();
string line;
- string err;
int lc = 0;
bool flag = false;
while(getline(in, line)) {
@@ -24,32 +64,8 @@ void TrainingInstance::ReadTraining(const string& fname, vector<TrainingInstance
if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
size_t pos = line.rfind('\t');
assert(pos != string::npos);
- picojson::value obj;
- picojson::parse(obj, line.begin() + pos, line.end(), &err);
- if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); }
corpus->push_back(TrainingInstance());
- TrainingInstance& cur = corpus->back();
- TaggedSentence& ts = cur.ts;
- EdgeSubset& tree = cur.tree;
- assert(obj.is<picojson::object>());
- const picojson::object& d = obj.get<picojson::object>();
- const picojson::array& ta = d.find("tokens")->second.get<picojson::array>();
- for (unsigned i = 0; i < ta.size(); ++i) {
- ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>()));
- ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>()));
- }
- const picojson::array& da = d.find("deps")->second.get<picojson::array>();
- for (unsigned i = 0; i < da.size(); ++i) {
- const picojson::array& thm = da[i].get<picojson::array>();
- // get dep type here
- short h = thm[2].get<double>();
- short m = thm[1].get<double>();
- if (h < 0)
- tree.roots.push_back(m);
- else
- tree.h_m_pairs.push_back(make_pair(h,m));
- }
- //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl;
+ ParseInstance(line, pos + 1, &corpus->back(), lc);
}
if (flag) cerr << "\nRead " << lc << " training instances\n";
}
diff --git a/rst_parser/dep_training.h b/rst_parser/dep_training.h
index 73ffd298..3eeee22e 100644
--- a/rst_parser/dep_training.h
+++ b/rst_parser/dep_training.h
@@ -1,6 +1,7 @@
#ifndef _DEP_TRAINING_H_
#define _DEP_TRAINING_H_
+#include <iostream>
#include <string>
#include <vector>
#include "arc_factored.h"
@@ -11,7 +12,8 @@ struct TrainingInstance {
EdgeSubset tree;
SparseVector<weight_t> features;
// reads a "Jsent" formatted dependency file
- static void ReadTraining(const std::string& fname, std::vector<TrainingInstance>* corpus, int rank = 0, int size = 1);
+ static bool ReadInstance(std::istream* in, TrainingInstance* instance); // returns false at EOF
+ static void ReadTrainingCorpus(const std::string& fname, std::vector<TrainingInstance>* corpus, int rank = 0, int size = 1);
};
#endif
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
index 0709e7c9..e414f450 100644
--- a/rst_parser/mst_train.cc
+++ b/rst_parser/mst_train.cc
@@ -74,7 +74,7 @@ int main(int argc, char** argv) {
InitCommandLine(argc, argv, &conf);
ArcFeatureFunctions ffs;
vector<TrainingInstance> corpus;
- TrainingInstance::ReadTraining(conf["training_data"].as<string>(), &corpus, rank, size);
+ TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus, rank, size);
vector<ArcFactoredForest> forests(corpus.size());
SparseVector<double> empirical;
bool flag = false;
diff --git a/rst_parser/rst_train.cc b/rst_parser/rst_train.cc
index 16673cdc..9b730f3d 100644
--- a/rst_parser/rst_train.cc
+++ b/rst_parser/rst_train.cc
@@ -52,7 +52,7 @@ int main(int argc, char** argv) {
vector<TrainingInstance> corpus;
ArcFeatureFunctions ffs;
GlobalFeatureFunctions gff;
- TrainingInstance::ReadTraining(conf["training_data"].as<string>(), &corpus);
+ TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus);
vector<ArcFactoredForest> forests(corpus.size());
vector<prob_t> zs(corpus.size());
SparseVector<double> empirical;