summaryrefslogtreecommitdiff
path: root/training/test_ngram.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/test_ngram.cc')
-rw-r--r--training/test_ngram.cc130
1 files changed, 130 insertions, 0 deletions
diff --git a/training/test_ngram.cc b/training/test_ngram.cc
new file mode 100644
index 00000000..c481b564
--- /dev/null
+++ b/training/test_ngram.cc
@@ -0,0 +1,130 @@
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "lm/model.hh"
+#include "lm/enumerate_vocab.hh"
+
+namespace po = boost::program_options;
+using namespace std;
+
+lm::ngram::ProbingModel* ngram;
+struct GetVocab : public lm::ngram::EnumerateVocab {
+ GetVocab(vector<lm::WordIndex>* out) : out_(out) { }
+ void Add(lm::WordIndex index, const StringPiece &str) {
+ out_->push_back(index);
+ }
+ vector<lm::WordIndex>* out_;
+};
+
+bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("model,m",po::value<string>(),"n-gram language model file (KLM)");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ po::positional_options_description p;
+ p.add("grammar", -1);
+
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(po::command_line_parser(argc, argv).options(dcmdline_options).positional(p).run(), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help")) {
+ cerr << "Usage " << argv[0] << " [OPTIONS]\n";
+ cerr << dcmdline_options << endl;
+ return false;
+ }
+ return true;
+}
+
+template <class Model> double BlanketProb(const vector<lm::WordIndex>& sentence, const lm::WordIndex word, const int subst_pos, const Model &model) {
+ typename Model::State state, out;
+ lm::FullScoreReturn ret;
+ double total = 0;
+ state = model.NullContextState();
+
+ const int begin = max(subst_pos - model.Order() + 1, 0);
+ const int end = min(subst_pos + model.Order(), (int)sentence.size());
+ int lookups = 0;
+ bool have_full_context = false;
+ for (int i = begin; i < end; ++i) {
+ if (i == 0) {
+ state = model.BeginSentenceState();
+ have_full_context = true;
+ } else {
+ lookups++;
+ if (lookups == model.Order()) { have_full_context = true; }
+ ret = model.FullScore(state, (subst_pos == i ? word : sentence[i]), out);
+ if (have_full_context) { total += ret.prob; }
+ state = out;
+ }
+ }
+ return total;
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ if (!InitCommandLine(argc, argv, &conf)) return 1;
+ lm::ngram::Config kconf;
+ vector<lm::WordIndex> vocab;
+ GetVocab gv(&vocab);
+ kconf.enumerate_vocab = &gv;
+ ngram = new lm::ngram::ProbingModel(conf["model"].as<string>().c_str(), kconf);
+ cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (vocab size=" << vocab.size() << ")\n";
+ vector<int> exclude(vocab.size(), 0);
+ exclude[0] = 1; // exclude OOVs
+
+ double prob_sum = 0;
+ int counter = 0;
+ int rank_error = 0;
+ string line;
+ while (getline(cin, line)) {
+ stringstream line_stream(line);
+ vector<string> tokens;
+ tokens.push_back("<s>");
+ string token;
+ while (line_stream >> token)
+ tokens.push_back(token);
+ tokens.push_back("</s>");
+
+ vector<lm::WordIndex> sentence(tokens.size());
+ for (int i = 0; i < tokens.size(); ++i)
+ sentence[i] = ngram->GetVocabulary().Index(tokens[i]);
+ exclude[sentence[0]] = 1;
+ exclude[sentence.back()] = 1;
+ for (int i = 1; i < tokens.size()-1; ++i) {
+ cerr << tokens[i] << endl;
+ ++counter;
+ lm::WordIndex gold = sentence[i];
+ double blanket_prob = BlanketProb<lm::ngram::ProbingModel>(sentence, gold, i, *ngram);
+ double z = 0;
+ for (int v = 0; v < vocab.size(); ++v) {
+ if (exclude[v]) continue;
+ double lp = BlanketProb<lm::ngram::ProbingModel>(sentence, v, i, *ngram);
+ if (lp > blanket_prob) ++rank_error;
+ z += pow(10.0, lp);
+ }
+ double post_prob = blanket_prob - log10(z);
+ cerr << " " << post_prob << endl;
+ prob_sum -= post_prob;
+ }
+ }
+ cerr << "perplexity=" << pow(10,prob_sum/(double)counter) << endl;
+ cerr << "Rank error=" << rank_error/(double)counter << endl;
+
+ return 0;
+}
+