diff options
-rw-r--r-- | rst_parser/dep_training.cc | 70 | ||||
-rw-r--r-- | rst_parser/dep_training.h | 4 | ||||
-rw-r--r-- | rst_parser/mst_train.cc | 2 | ||||
-rw-r--r-- | rst_parser/rst_train.cc | 2 |
4 files changed, 48 insertions, 30 deletions
diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc index de431ebc..e26505ec 100644 --- a/rst_parser/dep_training.cc +++ b/rst_parser/dep_training.cc @@ -10,11 +10,51 @@ using namespace std; -void TrainingInstance::ReadTraining(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) { +static void ParseInstance(const string& line, int start, TrainingInstance* out, int lc = 0) { + picojson::value obj; + string err; + picojson::parse(obj, line.begin() + start, line.end(), &err); + if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } + TrainingInstance& cur = *out; + TaggedSentence& ts = cur.ts; + EdgeSubset& tree = cur.tree; + assert(obj.is<picojson::object>()); + const picojson::object& d = obj.get<picojson::object>(); + const picojson::array& ta = d.find("tokens")->second.get<picojson::array>(); + for (unsigned i = 0; i < ta.size(); ++i) { + ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>())); + ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>())); + } + if (d.find("deps") != d.end()) { + const picojson::array& da = d.find("deps")->second.get<picojson::array>(); + for (unsigned i = 0; i < da.size(); ++i) { + const picojson::array& thm = da[i].get<picojson::array>(); + // get dep type here + short h = thm[2].get<double>(); + short m = thm[1].get<double>(); + if (h < 0) + tree.roots.push_back(m); + else + tree.h_m_pairs.push_back(make_pair(h,m)); + } + } + //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; +} + +bool TrainingInstance::ReadInstance(std::istream* in, TrainingInstance* instance) { + string line; + if (!getline(*in, line)) return false; + size_t pos = line.rfind('\t'); + assert(pos != string::npos); + static int lc = 0; ++lc; + ParseInstance(line, pos + 1, instance, lc); + return true; +} + +void TrainingInstance::ReadTrainingCorpus(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) { ReadFile rf(fname); istream& in = *rf.stream(); string line; - string err; int lc = 0; bool flag = false; while(getline(in, line)) { @@ -24,32 +64,8 @@ void TrainingInstance::ReadTraining(const string& fname, vector<TrainingInstance if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; } size_t pos = line.rfind('\t'); assert(pos != string::npos); - picojson::value obj; - picojson::parse(obj, line.begin() + pos, line.end(), &err); - if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } corpus->push_back(TrainingInstance()); - TrainingInstance& cur = corpus->back(); - TaggedSentence& ts = cur.ts; - EdgeSubset& tree = cur.tree; - assert(obj.is<picojson::object>()); - const picojson::object& d = obj.get<picojson::object>(); - const picojson::array& ta = d.find("tokens")->second.get<picojson::array>(); - for (unsigned i = 0; i < ta.size(); ++i) { - ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>())); - ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>())); - } - const picojson::array& da = d.find("deps")->second.get<picojson::array>(); - for (unsigned i = 0; i < da.size(); ++i) { - const picojson::array& thm = da[i].get<picojson::array>(); - // get dep type here - short h = thm[2].get<double>(); - short m = thm[1].get<double>(); - if (h < 0) - tree.roots.push_back(m); - else - tree.h_m_pairs.push_back(make_pair(h,m)); - } - //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; + ParseInstance(line, pos + 1, &corpus->back(), lc); } if (flag) cerr << "\nRead " << lc << " training instances\n"; } diff --git a/rst_parser/dep_training.h b/rst_parser/dep_training.h index 73ffd298..3eeee22e 100644 --- a/rst_parser/dep_training.h +++ b/rst_parser/dep_training.h @@ -1,6 +1,7 @@ #ifndef _DEP_TRAINING_H_ #define _DEP_TRAINING_H_ +#include <iostream> #include <string> #include <vector> #include "arc_factored.h" @@ -11,7 +12,8 @@ struct TrainingInstance { EdgeSubset tree; SparseVector<weight_t> features; // reads a "Jsent" formatted dependency file - static void ReadTraining(const std::string& fname, std::vector<TrainingInstance>* corpus, int rank = 0, int size = 1); + static bool ReadInstance(std::istream* in, TrainingInstance* instance); // returns false at EOF + static void ReadTrainingCorpus(const std::string& fname, std::vector<TrainingInstance>* corpus, int rank = 0, int size = 1); }; #endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index 0709e7c9..e414f450 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -74,7 +74,7 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); ArcFeatureFunctions ffs; vector<TrainingInstance> corpus; - TrainingInstance::ReadTraining(conf["training_data"].as<string>(), &corpus, rank, size); + TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus, rank, size); vector<ArcFactoredForest> forests(corpus.size()); SparseVector<double> empirical; bool flag = false; diff --git a/rst_parser/rst_train.cc b/rst_parser/rst_train.cc index 16673cdc..9b730f3d 100644 --- a/rst_parser/rst_train.cc +++ b/rst_parser/rst_train.cc @@ -52,7 +52,7 @@ int main(int argc, char** argv) { vector<TrainingInstance> corpus; ArcFeatureFunctions ffs; GlobalFeatureFunctions gff; - TrainingInstance::ReadTraining(conf["training_data"].as<string>(), &corpus); + TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus); vector<ArcFactoredForest> forests(corpus.size()); vector<prob_t> zs(corpus.size()); SparseVector<double> empirical; |