summaryrefslogtreecommitdiff
path: root/rst_parser/rst_train.cc
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/rst_train.cc')
-rw-r--r--rst_parser/rst_train.cc144
1 files changed, 144 insertions, 0 deletions
diff --git a/rst_parser/rst_train.cc b/rst_parser/rst_train.cc
new file mode 100644
index 00000000..16673cdc
--- /dev/null
+++ b/rst_parser/rst_train.cc
@@ -0,0 +1,144 @@
+#include "arc_factored.h"
+
+#include <vector>
+#include <iostream>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "timing_stats.h"
+#include "arc_ff.h"
+#include "dep_training.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "tdict.h"
+#include "weights.h"
+#include "rst.h"
+#include "global_ff.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ string cfg_file;
+ opts.add_options()
+ ("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)")
+ ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution")
+ ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config,c", po::value<string>(&cfg_file), "Configuration file")
+ ("help,?", "Print this help message and exit");
+
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(dconfig_options).add(clo);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (cfg_file.size() > 0) {
+ ReadFile rf(cfg_file);
+ po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf);
+ }
+ if (conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ vector<weight_t> qweights(FD::NumFeats(), 0.0);
+ Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights);
+ vector<TrainingInstance> corpus;
+ ArcFeatureFunctions ffs;
+ GlobalFeatureFunctions gff;
+ TrainingInstance::ReadTraining(conf["training_data"].as<string>(), &corpus);
+ vector<ArcFactoredForest> forests(corpus.size());
+ vector<prob_t> zs(corpus.size());
+ SparseVector<double> empirical;
+ bool flag = false;
+ for (int i = 0; i < corpus.size(); ++i) {
+ TrainingInstance& cur = corpus[i];
+ if ((i+1) % 10 == 0) { cerr << '.' << flush; flag = true; }
+ if ((i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; }
+ SparseVector<weight_t> efmap;
+ ffs.PrepareForInput(cur.ts);
+ gff.PrepareForInput(cur.ts);
+ for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first,
+ cur.tree.h_m_pairs[j].second,
+ &efmap);
+ cur.features += efmap;
+ }
+ for (int j = 0; j < cur.tree.roots.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap);
+ cur.features += efmap;
+ }
+ efmap.clear();
+ gff.Features(cur.ts, cur.tree, &efmap);
+ cur.features += efmap;
+ empirical += cur.features;
+ forests[i].resize(cur.ts.words.size());
+ forests[i].ExtractFeatures(cur.ts, ffs);
+ forests[i].Reweight(qweights);
+ forests[i].EdgeMarginals(&zs[i]);
+ zs[i] = prob_t::One() / zs[i];
+ // cerr << zs[i] << endl;
+ forests[i].Reweight(qweights); // EdgeMarginals overwrites edge_prob
+ }
+ if (flag) cerr << endl;
+ MT19937 rng;
+ SparseVector<double> model_exp;
+ SparseVector<double> weights;
+ Weights::InitSparseVector(qweights, &weights);
+ int samples = conf["samples"].as<unsigned>();
+ for (int i = 0; i < corpus.size(); ++i) {
+#if 0
+ forests[i].EdgeMarginals();
+ model_exp.clear();
+ for (int h = -1; h < num_words; ++h) {
+ for (int m = 0; m < num_words; ++m) {
+ if (h == m) continue;
+ const ArcFactoredForest::Edge& edge = forests[i](h,m);
+ const SparseVector<weight_t>& fmap = edge.features;
+ double prob = edge.edge_prob.as_float();
+ model_exp += fmap * prob;
+ }
+ }
+ cerr << "TRUE EXP: " << model_exp << endl;
+ forests[i].Reweight(weights);
+#endif
+
+ TreeSampler ts(forests[i]);
+ prob_t zhat = prob_t::Zero();
+ SparseVector<prob_t> sampled_exp;
+ for (int n = 0; n < samples; ++n) {
+ EdgeSubset tree;
+ ts.SampleRandomSpanningTree(&tree, &rng);
+ SparseVector<double> qfeats, gfeats;
+ tree.ExtractFeatures(corpus[i].ts, ffs, &qfeats);
+ prob_t u; u.logeq(qfeats.dot(qweights));
+ const prob_t q = u / zs[i]; // proposal mass
+ gff.Features(corpus[i].ts, tree, &gfeats);
+ SparseVector<double> tot_feats = qfeats + gfeats;
+ u.logeq(tot_feats.dot(weights));
+ prob_t w = u / q;
+ zhat += w;
+ for (SparseVector<double>::const_iterator it = tot_feats.begin(); it != tot_feats.end(); ++it)
+ sampled_exp.add_value(it->first, w * prob_t(it->second));
+ }
+ sampled_exp /= zhat;
+ SparseVector<double> tot_m;
+ for (SparseVector<prob_t>::const_iterator it = sampled_exp.begin(); it != sampled_exp.end(); ++it)
+ tot_m.add_value(it->first, it->second.as_float());
+ //cerr << "DIFF: " << (tot_m - corpus[i].features) << endl;
+ const double eta = 0.03;
+ weights -= (tot_m - corpus[i].features) * eta;
+ }
+ cerr << "WEIGHTS.\n";
+ cerr << weights << endl;
+ return 0;
+}
+