summaryrefslogtreecommitdiff
path: root/rst_parser/rst_parse.cc
blob: 9c42a8f49717165efbafa7948e688f057347c5d8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include "arc_factored.h"

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>

#include "timing_stats.h"
#include "arc_ff.h"
#include "dep_training.h"
#include "stringlib.h"
#include "filelib.h"
#include "tdict.h"
#include "weights.h"
#include "rst.h"
#include "global_ff.h"

using namespace std;
namespace po = boost::program_options;

void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
  po::options_description opts("Configuration options");
  string cfg_file;
  opts.add_options()
        ("input,i",po::value<string>()->default_value("-"), "File containing test data (jsent format)")
        ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution (mandatory)")
        ("p_weights,p",po::value<string>(), "Weights for target distribution (optional)")
        ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples");
  po::options_description clo("Command line options");
  clo.add_options()
        ("config,c", po::value<string>(&cfg_file), "Configuration file")
        ("help,?", "Print this help message and exit");

  po::options_description dconfig_options, dcmdline_options;
  dconfig_options.add(opts);
  dcmdline_options.add(dconfig_options).add(clo);
  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
  if (cfg_file.size() > 0) {
    ReadFile rf(cfg_file);
    po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf);
  }
  if (conf->count("help") || conf->count("q_weights") == 0) {
    cerr << dcmdline_options << endl;
    exit(1);
  }
}

// Entry point.
//
// For every instance read from the input, draws `samples` random spanning
// trees from a forest reweighted with the q weights, scores each tree, and
// keeps the highest-scoring one. When p weights were supplied the score is the
// dot product of the arc features plus the global features with the p
// weights; otherwise it is the dot product of the arc features with the q
// weights.
//
// The best score is written to stderr and the best tree to stdout. Instances
// whose reference tree has at least one head/modifier pair are also evaluated,
// and an aggregate figure (correct attachments plus correct roots over total
// attachments plus total roots) is written to stderr at the end.
int main(int argc, char** argv) {
  po::variables_map conf;
  InitCommandLine(argc, argv, &conf);
  vector<weight_t> qweights, pweights;
  Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights);
  if (conf.count("p_weights"))
    Weights::InitFromFile(conf["p_weights"].as<string>(), &pweights);
  // Global features are only used when target-distribution weights exist.
  const bool global = !pweights.empty();
  ArcFeatureFunctions ffs;
  GlobalFeatureFunctions gff;
  ReadFile rf(conf["input"].as<string>());
  istream* in = rf.stream();
  TrainingInstance sent;
  MT19937 rng;
  // Kept unsigned, matching the option's declared type, so that large values
  // are not turned into a negative loop bound.
  const unsigned samples = conf["samples"].as<unsigned>();
  int totroot = 0, root_right = 0, tot = 0, cor = 0;
  while(TrainingInstance::ReadInstance(in, &sent)) {
    ffs.PrepareForInput(sent.ts);
    if (global) gff.PrepareForInput(sent.ts);
    ArcFactoredForest forest(sent.ts.pos.size());
    forest.ExtractFeatures(sent.ts, ffs);
    forest.Reweight(qweights);
    TreeSampler ts(forest);
    double best_score = -numeric_limits<double>::infinity();
    EdgeSubset best_tree;
    for (unsigned n = 0; n < samples; ++n) {
      EdgeSubset tree;
      ts.SampleRandomSpanningTree(&tree, &rng);
      SparseVector<double> qfeats, gfeats;
      tree.ExtractFeatures(sent.ts, ffs, &qfeats);
      double score = 0;
      if (global) {
        gff.Features(sent.ts, tree, &gfeats);
        score = (qfeats + gfeats).dot(pweights);
      } else {
        score = qfeats.dot(qweights);
      }
      if (score > best_score) {
        best_tree = tree;
        best_score = score;
      }
    }
    cerr << "BEST SCORE: " << best_score << endl;
    cout << best_tree << endl;
    const bool sent_has_ref = !sent.tree.h_m_pairs.empty();
    if (sent_has_ref) {
      // Membership table of the reference head/modifier pairs.
      map<pair<short,short>, bool> ref;
      for (size_t i = 0; i < sent.tree.h_m_pairs.size(); ++i)
        ref[sent.tree.h_m_pairs[i]] = true;
      // best_tree is still empty when no sample was drawn or no sampled score
      // exceeded the initial -infinity, and the reference may lack a root
      // entry; front() on an empty vector is undefined, so such cases are
      // simply counted as a root miss.
      if (!sent.tree.roots.empty() && !best_tree.roots.empty() &&
          sent.tree.roots.front() == best_tree.roots.front()) {
        ++root_right;
      }
      ++totroot;
      for (size_t i = 0; i < best_tree.h_m_pairs.size(); ++i) {
        // count() rather than operator[]: a lookup must not insert entries.
        if (ref.count(best_tree.h_m_pairs[i]) > 0) {
          ++cor;
        }
        ++tot;
      }
    }
  }
  // When no instance carried a reference both sums are zero and this
  // floating-point division yields NaN.
  cerr << "F = " << (double(cor + root_right) / (tot + totroot)) << endl;
  return 0;
}