summaryrefslogtreecommitdiff
path: root/rst_parser/rst_parse.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2012-04-23 21:44:02 +0200
committerPatrick Simianer <p@simianer.de>2012-04-23 21:44:02 +0200
commit2f427278616cbe3fa6f56d6b97c40b3894dbd950 (patch)
tree6998435e4677437c474cf0f835ce9f72d70d3945 /rst_parser/rst_parse.cc
parent6d0d0eb6bbfaee6b6998659a55e2195977ccd217 (diff)
parent217c4aaeba1c9f19b3420b526235bffd86c7a92b (diff)
Merge remote-tracking branch 'upstream/master'
Conflicts: Makefile.am configure.ac
Diffstat (limited to 'rst_parser/rst_parse.cc')
-rw-r--r--rst_parser/rst_parse.cc111
1 files changed, 111 insertions, 0 deletions
diff --git a/rst_parser/rst_parse.cc b/rst_parser/rst_parse.cc
new file mode 100644
index 00000000..9c42a8f4
--- /dev/null
+++ b/rst_parser/rst_parse.cc
@@ -0,0 +1,111 @@
+#include "arc_factored.h"
+
+#include <vector>
+#include <iostream>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "timing_stats.h"
+#include "arc_ff.h"
+#include "dep_training.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "tdict.h"
+#include "weights.h"
+#include "rst.h"
+#include "global_ff.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ string cfg_file;
+ opts.add_options()
+ ("input,i",po::value<string>()->default_value("-"), "File containing test data (jsent format)")
+ ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution (mandatory)")
+ ("p_weights,p",po::value<string>(), "Weights for target distribution (optional)")
+ ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config,c", po::value<string>(&cfg_file), "Configuration file")
+ ("help,?", "Print this help message and exit");
+
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(dconfig_options).add(clo);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (cfg_file.size() > 0) {
+ ReadFile rf(cfg_file);
+ po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf);
+ }
+ if (conf->count("help") || conf->count("q_weights") == 0) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ vector<weight_t> qweights, pweights;
+ Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights);
+ if (conf.count("p_weights"))
+ Weights::InitFromFile(conf["p_weights"].as<string>(), &pweights);
+ const bool global = pweights.size() > 0;
+ ArcFeatureFunctions ffs;
+ GlobalFeatureFunctions gff;
+ ReadFile rf(conf["input"].as<string>());
+ istream* in = rf.stream();
+ TrainingInstance sent;
+ MT19937 rng;
+ int samples = conf["samples"].as<unsigned>();
+ int totroot = 0, root_right = 0, tot = 0, cor = 0;
+ while(TrainingInstance::ReadInstance(in, &sent)) {
+ ffs.PrepareForInput(sent.ts);
+ if (global) gff.PrepareForInput(sent.ts);
+ ArcFactoredForest forest(sent.ts.pos.size());
+ forest.ExtractFeatures(sent.ts, ffs);
+ forest.Reweight(qweights);
+ TreeSampler ts(forest);
+ double best_score = -numeric_limits<double>::infinity();
+ EdgeSubset best_tree;
+ for (int n = 0; n < samples; ++n) {
+ EdgeSubset tree;
+ ts.SampleRandomSpanningTree(&tree, &rng);
+ SparseVector<double> qfeats, gfeats;
+ tree.ExtractFeatures(sent.ts, ffs, &qfeats);
+ double score = 0;
+ if (global) {
+ gff.Features(sent.ts, tree, &gfeats);
+ score = (qfeats + gfeats).dot(pweights);
+ } else {
+ score = qfeats.dot(qweights);
+ }
+ if (score > best_score) {
+ best_tree = tree;
+ best_score = score;
+ }
+ }
+ cerr << "BEST SCORE: " << best_score << endl;
+ cout << best_tree << endl;
+ const bool sent_has_ref = sent.tree.h_m_pairs.size() > 0;
+ if (sent_has_ref) {
+ map<pair<short,short>, bool> ref;
+ for (int i = 0; i < sent.tree.h_m_pairs.size(); ++i)
+ ref[sent.tree.h_m_pairs[i]] = true;
+ int ref_root = sent.tree.roots.front();
+ if (ref_root == best_tree.roots.front()) { ++root_right; }
+ ++totroot;
+ for (int i = 0; i < best_tree.h_m_pairs.size(); ++i) {
+ if (ref[best_tree.h_m_pairs[i]]) {
+ ++cor;
+ }
+ ++tot;
+ }
+ }
+ }
+ cerr << "F = " << (double(cor + root_right) / (tot + totroot)) << endl;
+ return 0;
+}
+