summaryrefslogtreecommitdiff
path: root/src/collapse_weights.cc
blob: 5e0f3f7240882c430223c5f8b364f5f3ecdce3e2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <iostream>
#include <fstream>
#include <tr1/unordered_map>

#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
#include <boost/functional/hash.hpp>

#include "prob.h"
#include "filelib.h"
#include "trule.h"
#include "weights.h"

namespace po = boost::program_options;
using namespace std;

// Hash map from one side of a grammar rule (a sequence of word ids) to the
// accumulated score of all rules sharing that side; main() keeps one instance
// keyed by each rule's e_ and one keyed by its f_.
typedef std::tr1::unordered_map<std::vector<WordID>,
                                prob_t,
                                boost::hash<std::vector<WordID> > >
    MarginalMap;

// Parses the command line and, if --config is given, the named configuration
// file into *conf.
//
// --grammar/-g and --weights/-w may come from either source; --config/-c and
// --help/-h are command-line only. The command line is stored first, so (per
// Boost.Program_options' po::store semantics) its values win over the config
// file's values for the same option.
//
// Exits with status 1 after printing a message to stderr if:
//   * the config file named by --config cannot be opened,
//   * --help is requested, or
//   * --grammar or --weights is missing after both sources are read.
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
  po::options_description opts("Configuration options");
  opts.add_options()
        ("grammar,g", po::value<string>(), "Grammar file")
        ("weights,w", po::value<string>(), "Weights file");
  po::options_description clo("Command line options");
  clo.add_options()
        ("config,c", po::value<string>(), "Configuration file")
        ("help,h", "Print this help message and exit");
  // The config file accepts only `opts`; the command line accepts `opts` + `clo`.
  po::options_description dconfig_options, dcmdline_options;
  dconfig_options.add(opts);
  dcmdline_options.add(opts).add(clo);

  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
  if (conf->count("config")) {
    const string cfg = (*conf)["config"].as<string>();
    cerr << "Configuration file: " << cfg << endl;
    ifstream config(cfg.c_str());
    // Without this check an unreadable file parses as an empty config and the
    // user only sees the generic usage text below, with no hint of the cause.
    if (!config) {
      cerr << "Failed to open configuration file: " << cfg << endl;
      exit(1);
    }
    po::store(po::parse_config_file(config, dconfig_options), *conf);
  }
  po::notify(*conf);

  if (conf->count("help") || !conf->count("grammar") || !conf->count("weights")) {
    cerr << dcmdline_options << endl;
    exit(1);
  }
}

int main(int argc, char** argv) {
  po::variables_map conf;
  InitCommandLine(argc, argv, &conf);
  const string wfile = conf["weights"].as<string>();
  const string gfile = conf["grammar"].as<string>();
  Weights wm;
  wm.InitFromFile(wfile);
  vector<double> w;
  wm.InitVector(&w);
  MarginalMap e_tots;
  MarginalMap f_tots;
  prob_t tot;
  {
    ReadFile rf(gfile);
    assert(*rf.stream());
    istream& in = *rf.stream();
    cerr << "Computing marginals...\n";
    int lc = 0;
    while(in) {
      string line;
      getline(in, line);
      ++lc;
      if (line.empty()) continue;
      TRule tr(line, true);
      if (tr.GetFeatureValues().empty())
        cerr << "Line " << lc << ": empty features - may introduce bias\n";
      prob_t prob;
      prob.logeq(tr.GetFeatureValues().dot(w));
      e_tots[tr.e_] += prob;
      f_tots[tr.f_] += prob;
      tot += prob;
    }
  }
  bool normalized = (fabs(log(tot)) < 0.001);
  cerr << "Total: " << tot << (normalized ? " [normalized]" : " [scaled]") << endl;
  ReadFile rf(gfile);
  istream&in = *rf.stream();
  while(in) {
    string line;
    getline(in, line);
    if (line.empty()) continue;
    TRule tr(line, true);
    const double lp = tr.GetFeatureValues().dot(w);
    if (isinf(lp)) { continue; }
    tr.scores_.clear();

    cout << tr.AsString() << " ||| F_and_E=" << lp - log(tot);
    if (!normalized) {
      cout << ";ZF_and_E=" << lp;
    }
    cout << ";F_given_E=" << lp - log(e_tots[tr.e_])
         << ";E_given_F=" << lp - log(f_tots[tr.f_]) << endl;
  }
  return 0;
}