path: root/training/test_ngram.cc
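
// test_ngram: evaluates a KenLM n-gram model by full-vocabulary substitution.
//
// For every interior token of each input sentence read from stdin, the
// program scores the "Markov blanket" (all n-gram windows covering that
// position) of every vocabulary word substituted at that position,
// normalizes those scores into a posterior over the gold token, and reports
// perplexity and average rank error on stderr.
//
// Hypothetical invocation (file names are illustrative):
//   ./test_ngram -m model.klm < dev.txt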
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>

#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"

namespace po = boost::program_options;
using namespace std;

lm::ngram::ProbingModel* ngram;
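
// Vocabulary-enumeration hook: KenLM calls Add() once per vocabulary word
// while the model loads; we record every index for the substitution loop.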
struct GetVocab : public lm::EnumerateVocab {
  GetVocab(vector<lm::WordIndex>* out) : out_(out) { }
  void Add(lm::WordIndex index, const StringPiece &str) {
    out_->push_back(index);
  }
  vector<lm::WordIndex>* out_;
};

bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
  po::options_description opts("Configuration options");
  opts.add_options()
        ("model,m",po::value<string>(),"n-gram language model file (KLM)");
  po::options_description clo("Command line options");
  clo.add_options()
        ("config", po::value<string>(), "Configuration file")
        ("help,h", "Print this help message and exit");
  po::options_description dconfig_options, dcmdline_options;
  po::positional_options_description p;
  p.add("model", -1);  // treat a bare positional argument as the model file

  dconfig_options.add(opts);
  dcmdline_options.add(opts).add(clo);

  po::store(po::command_line_parser(argc, argv).options(dcmdline_options).positional(p).run(), *conf);
  if (conf->count("config")) {
    ifstream config((*conf)["config"].as<string>().c_str());
    po::store(po::parse_config_file(config, dconfig_options), *conf);
  }
  po::notify(*conf);

  if (conf->count("help")) {
    cerr << "Usage " << argv[0] << " [OPTIONS]\n";
    cerr << dcmdline_options << endl;
    return false;
  }
  return true;
}

// BlanketProb: total log10 probability of the n-gram windows that cover
// position subst_pos when `word` is substituted there. Near <s>, a few
// windows that do not contain subst_pos are also scored, but their
// contribution is identical for every candidate word and cancels once the
// scores are normalized into a posterior in main().
template <class Model> double BlanketProb(const vector<lm::WordIndex>& sentence, const lm::WordIndex word, const int subst_pos, const Model &model) {
  typename Model::State state, out;
  lm::FullScoreReturn ret;
  double total = 0;
  state = model.NullContextState();

  // Only positions within Order()-1 of subst_pos can share an n-gram with it.
  const int begin = max(subst_pos - model.Order() + 1, 0);
  const int end = min(subst_pos + model.Order(), (int)sentence.size());
  int lookups = 0;
  bool have_full_context = false;
  for (int i = begin; i < end; ++i) {
    if (i == 0) {
      // <s> supplies a complete left context immediately.
      state = model.BeginSentenceState();
      have_full_context = true;
    } else {
      lookups++;
      // The Order()-th lookup is the first scored with a full-length history.
      if (lookups == model.Order()) { have_full_context = true; }
      ret = model.FullScore(state, (subst_pos == i ? word : sentence[i]), out);
      if (have_full_context) { total += ret.prob; }
      state = out;
    }
  }
  return total;
}

// main: for each interior position of each input sentence, compare the gold
// token's blanket score against every vocabulary substitution.
int main(int argc, char** argv) {
  po::variables_map conf;
  if (!InitCommandLine(argc, argv, &conf)) return 1;
  lm::ngram::Config kconf;
  vector<lm::WordIndex> vocab;
  GetVocab gv(&vocab);
  kconf.enumerate_vocab = &gv;  // collect the vocabulary as the model loads
  ngram = new lm::ngram::ProbingModel(conf["model"].as<string>().c_str(), kconf);
  cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (vocab size=" << vocab.size() << ")\n";
  vector<int> exclude(vocab.size(), 0);
  exclude[0] = 1; // index 0 is <unk> in KenLM; never propose OOVs

  double prob_sum = 0;  // accumulated negative log10 posterior of gold tokens
  int counter = 0;      // number of scored token positions
  int rank_error = 0;   // candidates ranked above the gold token, summed over positions
  string line;
  // One sentence per line on stdin; wrap each with boundary markers.
  while (getline(cin, line)) {
    stringstream line_stream(line);
    vector<string> tokens;
    tokens.push_back("<s>");
    string token;
    while (line_stream >> token)
      tokens.push_back(token);
    tokens.push_back("</s>");

    vector<lm::WordIndex> sentence(tokens.size());
    for (int i = 0; i < (int)tokens.size(); ++i)
      sentence[i] = ngram->GetVocabulary().Index(tokens[i]);
    exclude[sentence[0]] = 1;     // never substitute <s>
    exclude[sentence.back()] = 1; // never substitute </s>
    for (int i = 1; i < (int)tokens.size() - 1; ++i) {
      cerr << tokens[i] << endl;
      ++counter;
      lm::WordIndex gold = sentence[i];
      // Blanket score of the word that actually occurred at position i.
      double blanket_prob = BlanketProb<lm::ngram::ProbingModel>(sentence, gold, i, *ngram);
      double z = 0;
      // Normalizer: sum 10^blanket over every candidate substitution.
      for (int v = 0; v < (int)vocab.size(); ++v) {
        if (exclude[v]) continue;
        double lp = BlanketProb<lm::ngram::ProbingModel>(sentence, v, i, *ngram);
        if (lp > blanket_prob) ++rank_error; // this candidate outranks the gold token
        z += pow(10.0, lp);
      }
      double post_prob = blanket_prob - log10(z); // log10 posterior of the gold token
      cerr << "  " << post_prob << endl;
      prob_sum -= post_prob;
    }
  }
  cerr << "perplexity=" << pow(10,prob_sum/(double)counter) << endl;
  cerr << "Rank error=" << rank_error/(double)counter << endl;

  return 0;
}
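
// Stderr output shape (values here are placeholders, not real results):
//   <token>
//     <log10 posterior>
//   ...
//   perplexity=<value>
//   Rank error=<value>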