diff options
| -rw-r--r-- | rst_parser/Makefile.am | 5 | ||||
| -rw-r--r-- | rst_parser/dep_training.cc | 4 | ||||
| -rw-r--r-- | rst_parser/rst_parse.cc | 111 | 
3 files changed, 119 insertions, 1 deletions
| diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index 876c2237..4977f584 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -1,5 +1,5 @@  bin_PROGRAMS = \ -  mst_train rst_train +  mst_train rst_train rst_parse  noinst_LIBRARIES = librst.a @@ -11,4 +11,7 @@ mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/  rst_train_SOURCES = rst_train.cc  rst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +rst_parse_SOURCES = rst_parse.cc +rst_parse_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +  AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/training -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc index e26505ec..ef97798b 100644 --- a/rst_parser/dep_training.cc +++ b/rst_parser/dep_training.cc @@ -18,6 +18,10 @@ static void ParseInstance(const string& line, int start, TrainingInstance* out,    TrainingInstance& cur = *out;    TaggedSentence& ts = cur.ts;    EdgeSubset& tree = cur.tree; +  ts.pos.clear(); +  ts.words.clear(); +  tree.roots.clear(); +  tree.h_m_pairs.clear();    assert(obj.is<picojson::object>());    const picojson::object& d = obj.get<picojson::object>();    const picojson::array& ta = d.find("tokens")->second.get<picojson::array>(); diff --git a/rst_parser/rst_parse.cc b/rst_parser/rst_parse.cc new file mode 100644 index 00000000..9c42a8f4 --- /dev/null +++ b/rst_parser/rst_parse.cc @@ -0,0 +1,111 @@ +#include "arc_factored.h" + +#include <vector> +#include <iostream> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "timing_stats.h" +#include "arc_ff.h" +#include "dep_training.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "weights.h" +#include "rst.h" +#include "global_ff.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +  po::options_description opts("Configuration options"); +  string cfg_file; +  opts.add_options() +        ("input,i",po::value<string>()->default_value("-"), "File containing test data (jsent format)") +        ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution (mandatory)") +        ("p_weights,p",po::value<string>(), "Weights for target distribution (optional)") +        ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples"); +  po::options_description clo("Command line options"); +  clo.add_options() +        ("config,c", po::value<string>(&cfg_file), "Configuration file") +        ("help,?", "Print this help message and exit"); + +  po::options_description dconfig_options, dcmdline_options; +  dconfig_options.add(opts); +  dcmdline_options.add(dconfig_options).add(clo); +  po::store(parse_command_line(argc, argv, dcmdline_options), *conf); +  if (cfg_file.size() > 0) { +    ReadFile rf(cfg_file); +    po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); +  } +  if (conf->count("help") || conf->count("q_weights") == 0) { +    cerr << dcmdline_options << endl; +    exit(1); +  } +} + +int main(int argc, char** argv) { +  po::variables_map conf; +  InitCommandLine(argc, argv, &conf); +  vector<weight_t> qweights, pweights; +  Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights); +  if (conf.count("p_weights")) +    Weights::InitFromFile(conf["p_weights"].as<string>(), &pweights); +  const bool global = pweights.size() > 0; +  ArcFeatureFunctions ffs; +  GlobalFeatureFunctions gff; +  ReadFile rf(conf["input"].as<string>()); +  istream* in = rf.stream(); +  TrainingInstance sent; +  MT19937 rng; +  int samples = conf["samples"].as<unsigned>(); +  int totroot = 0, root_right = 0, tot = 0, cor = 0; +  while(TrainingInstance::ReadInstance(in, &sent)) { +    ffs.PrepareForInput(sent.ts); +    if (global) gff.PrepareForInput(sent.ts); +    ArcFactoredForest forest(sent.ts.pos.size()); +    forest.ExtractFeatures(sent.ts, ffs); +    forest.Reweight(qweights); +    TreeSampler ts(forest); +    double best_score = -numeric_limits<double>::infinity(); +    EdgeSubset best_tree; +    for (int n = 0; n < samples; ++n) { +      EdgeSubset tree; +      ts.SampleRandomSpanningTree(&tree, &rng); +      SparseVector<double> qfeats, gfeats; +      tree.ExtractFeatures(sent.ts, ffs, &qfeats); +      double score = 0; +      if (global) { +        gff.Features(sent.ts, tree, &gfeats); +        score = (qfeats + gfeats).dot(pweights); +      } else { +        score = qfeats.dot(qweights); +      } +      if (score > best_score) { +        best_tree = tree; +        best_score = score; +      } +    } +    cerr << "BEST SCORE: " << best_score << endl; +    cout << best_tree << endl; +    const bool sent_has_ref = sent.tree.h_m_pairs.size() > 0; +    if (sent_has_ref) { +      map<pair<short,short>, bool> ref; +      for (int i = 0; i < sent.tree.h_m_pairs.size(); ++i) +        ref[sent.tree.h_m_pairs[i]] = true; +      int ref_root = sent.tree.roots.front(); +      if (ref_root == best_tree.roots.front()) { ++root_right; } +      ++totroot; +      for (int i = 0; i < best_tree.h_m_pairs.size(); ++i) { +        if (ref[best_tree.h_m_pairs[i]]) { +          ++cor; +        } +        ++tot; +      } +    } +  } +  cerr << "F = " << (double(cor + root_right) / (tot + totroot)) << endl; +  return 0; +} + | 
