summaryrefslogtreecommitdiff
path: root/rst_parser/mst_train.cc
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/mst_train.cc')
-rw-r--r--rst_parser/mst_train.cc58
1 files changed, 2 insertions, 56 deletions
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
index c5cab6ec..f0403d7e 100644
--- a/rst_parser/mst_train.cc
+++ b/rst_parser/mst_train.cc
@@ -10,10 +10,9 @@
#include "stringlib.h"
#include "filelib.h"
#include "tdict.h"
-#include "picojson.h"
+#include "dep_training.h"
#include "optimize.h"
#include "weights.h"
-#include "rst.h"
using namespace std;
namespace po = boost::program_options;
@@ -47,56 +46,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
}
-struct TrainingInstance {
- TaggedSentence ts;
- EdgeSubset tree;
- SparseVector<weight_t> features;
-};
-
-void ReadTraining(const string& fname, vector<TrainingInstance>* corpus, int rank = 0, int size = 1) {
- ReadFile rf(fname);
- istream& in = *rf.stream();
- string line;
- string err;
- int lc = 0;
- bool flag = false;
- while(getline(in, line)) {
- ++lc;
- if ((lc-1) % size != rank) continue;
- if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; }
- if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
- size_t pos = line.rfind('\t');
- assert(pos != string::npos);
- picojson::value obj;
- picojson::parse(obj, line.begin() + pos, line.end(), &err);
- if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); }
- corpus->push_back(TrainingInstance());
- TrainingInstance& cur = corpus->back();
- TaggedSentence& ts = cur.ts;
- EdgeSubset& tree = cur.tree;
- assert(obj.is<picojson::object>());
- const picojson::object& d = obj.get<picojson::object>();
- const picojson::array& ta = d.find("tokens")->second.get<picojson::array>();
- for (unsigned i = 0; i < ta.size(); ++i) {
- ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>()));
- ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>()));
- }
- const picojson::array& da = d.find("deps")->second.get<picojson::array>();
- for (unsigned i = 0; i < da.size(); ++i) {
- const picojson::array& thm = da[i].get<picojson::array>();
- // get dep type here
- short h = thm[2].get<double>();
- short m = thm[1].get<double>();
- if (h < 0)
- tree.roots.push_back(m);
- else
- tree.h_m_pairs.push_back(make_pair(h,m));
- }
- //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl;
- }
- if (flag) cerr << "\nRead " << lc << " training instances\n";
-}
-
void AddFeatures(double prob, const SparseVector<double>& fmap, vector<double>* g) {
for (SparseVector<double>::const_iterator it = fmap.begin(); it != fmap.end(); ++it)
(*g)[it->first] += it->second * prob;
@@ -131,7 +80,7 @@ int main(int argc, char** argv) {
vector<TrainingInstance> corpus;
vector<boost::shared_ptr<ArcFeatureFunction> > ffs;
ffs.push_back(boost::shared_ptr<ArcFeatureFunction>(new DistancePenalty("")));
- ReadTraining(conf["training_data"].as<string>(), &corpus, rank, size);
+ TrainingInstance::ReadTraining(conf["training_data"].as<string>(), &corpus, rank, size);
vector<ArcFactoredForest> forests(corpus.size());
SparseVector<double> empirical;
bool flag = false;
@@ -224,9 +173,6 @@ int main(int argc, char** argv) {
}
if (converged) { cerr << "CONVERGED\n"; break; }
}
- forests[0].Reweight(weights);
- TreeSampler ts(forests[0]);
- EdgeSubset tt; ts.SampleRandomSpanningTree(&tt);
return 0;
}