diff options
Diffstat (limited to 'rst_parser/mst_train.cc')
-rw-r--r-- | rst_parser/mst_train.cc | 58 |
1 files changed, 2 insertions, 56 deletions
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index c5cab6ec..f0403d7e 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -10,10 +10,9 @@ #include "stringlib.h" #include "filelib.h" #include "tdict.h" -#include "picojson.h" +#include "dep_training.h" #include "optimize.h" #include "weights.h" -#include "rst.h" using namespace std; namespace po = boost::program_options; @@ -47,56 +46,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -struct TrainingInstance { - TaggedSentence ts; - EdgeSubset tree; - SparseVector<weight_t> features; -}; - -void ReadTraining(const string& fname, vector<TrainingInstance>* corpus, int rank = 0, int size = 1) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - string err; - int lc = 0; - bool flag = false; - while(getline(in, line)) { - ++lc; - if ((lc-1) % size != rank) continue; - if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; } - if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; } - size_t pos = line.rfind('\t'); - assert(pos != string::npos); - picojson::value obj; - picojson::parse(obj, line.begin() + pos, line.end(), &err); - if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } - corpus->push_back(TrainingInstance()); - TrainingInstance& cur = corpus->back(); - TaggedSentence& ts = cur.ts; - EdgeSubset& tree = cur.tree; - assert(obj.is<picojson::object>()); - const picojson::object& d = obj.get<picojson::object>(); - const picojson::array& ta = d.find("tokens")->second.get<picojson::array>(); - for (unsigned i = 0; i < ta.size(); ++i) { - ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>())); - ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>())); - } - const picojson::array& da = d.find("deps")->second.get<picojson::array>(); - for (unsigned i = 0; i < da.size(); ++i) { - const picojson::array& thm = da[i].get<picojson::array>(); - // get dep type here - short h = thm[2].get<double>(); - short m = thm[1].get<double>(); - if (h < 0) - tree.roots.push_back(m); - else - tree.h_m_pairs.push_back(make_pair(h,m)); - } - //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; - } - if (flag) cerr << "\nRead " << lc << " training instances\n"; -} - void AddFeatures(double prob, const SparseVector<double>& fmap, vector<double>* g) { for (SparseVector<double>::const_iterator it = fmap.begin(); it != fmap.end(); ++it) (*g)[it->first] += it->second * prob; @@ -131,7 +80,7 @@ int main(int argc, char** argv) { vector<TrainingInstance> corpus; vector<boost::shared_ptr<ArcFeatureFunction> > ffs; ffs.push_back(boost::shared_ptr<ArcFeatureFunction>(new DistancePenalty(""))); - ReadTraining(conf["training_data"].as<string>(), &corpus, rank, size); + TrainingInstance::ReadTraining(conf["training_data"].as<string>(), &corpus, rank, size); vector<ArcFactoredForest> forests(corpus.size()); SparseVector<double> empirical; bool flag = false; @@ -224,9 +173,6 @@ int main(int argc, char** argv) { } if (converged) { cerr << "CONVERGED\n"; break; } } - forests[0].Reweight(weights); - TreeSampler ts(forests[0]); - EdgeSubset tt; ts.SampleRandomSpanningTree(&tt); return 0; } |