From 35f9ffa85fb68e0624111a76e27955524649950a Mon Sep 17 00:00:00 2001
From: Chris Dyer <cdyer@cs.cmu.edu>
Date: Thu, 19 Apr 2012 02:45:27 -0400
Subject: compute f

---
 rst_parser/Makefile.am     |   5 +-
 rst_parser/dep_training.cc |   4 ++
 rst_parser/rst_parse.cc    | 111 +++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 119 insertions(+), 1 deletion(-)
 create mode 100644 rst_parser/rst_parse.cc

(limited to 'rst_parser')

diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am
index 876c2237..4977f584 100644
--- a/rst_parser/Makefile.am
+++ b/rst_parser/Makefile.am
@@ -1,5 +1,5 @@
 bin_PROGRAMS = \
-  mst_train rst_train
+  mst_train rst_train rst_parse
 
 noinst_LIBRARIES = librst.a
 
@@ -11,4 +11,7 @@ mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/
 rst_train_SOURCES = rst_train.cc
 rst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
+rst_parse_SOURCES = rst_parse.cc
+rst_parse_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+
 AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/training -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm
diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc
index e26505ec..ef97798b 100644
--- a/rst_parser/dep_training.cc
+++ b/rst_parser/dep_training.cc
@@ -18,6 +18,10 @@ static void ParseInstance(const string& line, int start, TrainingInstance* out,
   TrainingInstance& cur = *out;
   TaggedSentence& ts = cur.ts;
   EdgeSubset& tree = cur.tree;
+  ts.pos.clear();
+  ts.words.clear();
+  tree.roots.clear();
+  tree.h_m_pairs.clear();
   assert(obj.is<picojson::object>());
   const picojson::object& d = obj.get<picojson::object>();
   const picojson::array& ta = d.find("tokens")->second.get<picojson::array>();
diff --git a/rst_parser/rst_parse.cc b/rst_parser/rst_parse.cc
new file mode 100644
index 00000000..9c42a8f4
--- /dev/null
+++ b/rst_parser/rst_parse.cc
@@ -0,0 +1,111 @@
+#include "arc_factored.h"
+
+#include <vector>
+#include <iostream>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "timing_stats.h"
+#include "arc_ff.h"
+#include "dep_training.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "tdict.h"
+#include "weights.h"
+#include "rst.h"
+#include "global_ff.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+// Fills *conf from argv and, when --config names a file, from that file as well.
+// Prints the usage text and exits with status 1 on --help or if --q_weights is absent.
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+  po::options_description opts("Configuration options");
+  opts.add_options()
+        ("input,i",po::value<string>()->default_value("-"), "File containing test data (jsent format)")
+        ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution (mandatory)")
+        ("p_weights,p",po::value<string>(), "Weights for target distribution (optional)")
+        ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples");
+  po::options_description clo("Command line options");
+  clo.add_options()("config,c", po::value<string>(), "Configuration file")
+                   ("help,?", "Print this help message and exit");
+  po::options_description dconfig_options, dcmdline_options;
+  dconfig_options.add(opts);
+  dcmdline_options.add(dconfig_options).add(clo);
+  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+  if (conf->count("config")) {  // query the map: bound variables stay unset until notify()
+    ReadFile rf((*conf)["config"].as<string>());
+    po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf);
+  }
+  po::notify(*conf);
+  if (conf->count("help") || conf->count("q_weights") == 0) {
+    cerr << dcmdline_options << endl;
+    exit(1);
+  }
+}
+
+// Reads jsent instances, draws spanning-tree samples from the arc-factored proposal
+// (q_weights), prints the best-scoring sample per sentence on stdout, and reports on
+// stderr the fraction of predicted arcs and roots that match the reference trees.
+int main(int argc, char** argv) {
+  po::variables_map conf;
+  InitCommandLine(argc, argv, &conf);
+  vector<weight_t> qweights, pweights;
+  Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights);
+  if (conf.count("p_weights"))
+    Weights::InitFromFile(conf["p_weights"].as<string>(), &pweights);
+  // With target weights loaded, samples are rescored with arc + global features.
+  const bool global = pweights.size() > 0;
+  ArcFeatureFunctions ffs;
+  GlobalFeatureFunctions gff;
+  ReadFile rf(conf["input"].as<string>());
+  istream* in = rf.stream();
+  TrainingInstance sent;
+  MT19937 rng;
+  const unsigned samples = conf["samples"].as<unsigned>();
+  int totroot = 0, root_right = 0, tot = 0, cor = 0;
+  while(TrainingInstance::ReadInstance(in, &sent)) {
+    ffs.PrepareForInput(sent.ts);
+    if (global) gff.PrepareForInput(sent.ts);
+    ArcFactoredForest forest(sent.ts.pos.size());
+    forest.ExtractFeatures(sent.ts, ffs);
+    forest.Reweight(qweights);
+    TreeSampler ts(forest);
+    // Keep whichever sampled tree scores highest (first one wins on ties).
+    double best_score = -numeric_limits<double>::infinity();
+    EdgeSubset best_tree;
+    for (unsigned n = 0; n < samples; ++n) {
+      EdgeSubset tree;
+      ts.SampleRandomSpanningTree(&tree, &rng);
+      SparseVector<double> qfeats, gfeats;
+      tree.ExtractFeatures(sent.ts, ffs, &qfeats);
+      if (global) gff.Features(sent.ts, tree, &gfeats);
+      const double score = global ? (qfeats + gfeats).dot(pweights) : qfeats.dot(qweights);
+      if (score > best_score) {
+        best_tree = tree;
+        best_score = score;
+      }
+    }
+    cerr << "BEST SCORE: " << best_score << endl;
+    cout << best_tree << endl;
+    if (sent.tree.h_m_pairs.empty()) continue;  // no reference arcs: nothing to score
+    map<pair<short,short>, bool> ref;
+    for (size_t i = 0; i < sent.tree.h_m_pairs.size(); ++i)
+      ref[sent.tree.h_m_pairs[i]] = true;
+    // front() on an empty vector is undefined (e.g. no samples were drawn), so a
+    // missing root on either side simply counts as a root error.
+    if (!sent.tree.roots.empty() && !best_tree.roots.empty() &&
+        sent.tree.roots.front() == best_tree.roots.front()) ++root_right;
+    ++totroot;
+    for (size_t i = 0; i < best_tree.h_m_pairs.size(); ++i) {
+      if (ref.count(best_tree.h_m_pairs[i])) ++cor;  // count() never inserts, unlike []
+      ++tot;
+    }
+  }
+  // With no reference trees in the input the ratio is 0/0; report 0 rather than NaN.
+  const int denom = tot + totroot;
+  cerr << "F = " << (denom > 0 ? double(cor + root_right) / denom : 0.0) << endl;
+  return 0;
+}
+
-- 
cgit v1.2.3