author      Chris Dyer <cdyer@cs.cmu.edu>   2011-05-23 17:57:28 -0400
committer   Chris Dyer <cdyer@cs.cmu.edu>   2011-05-23 17:57:28 -0400
commit      83ae8c0d5b445d049b9b3f353f408756ba23ce3b
tree        4742da4057959f1669233eda8b188e2cffef2c94 /training
parent      bc95fedbaa083d557840db6ac2cbf14e2a3eccce
calculate word perplexities
Diffstat (limited to 'training')
-rw-r--r--  training/Makefile.am     4
-rw-r--r--  training/test_ngram.cc   137
2 files changed, 141 insertions, 0 deletions
diff --git a/training/Makefile.am b/training/Makefile.am
index 5697043b..0d9085e4 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -1,5 +1,6 @@
 bin_PROGRAMS = \
   model1 \
+  test_ngram \
   mr_em_map_adapter \
   mr_em_adapted_reduce \
   mr_reduce_to_weights \
@@ -39,6 +40,9 @@ cllh_filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval
 augment_grammar_SOURCES = augment_grammar.cc
 augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
+test_ngram_SOURCES = test_ngram.cc
+test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+
 atools_SOURCES = atools.cc
 atools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
diff --git a/training/test_ngram.cc b/training/test_ngram.cc
new file mode 100644
index 00000000..c481b564
--- /dev/null
+++ b/training/test_ngram.cc
@@ -0,0 +1,137 @@
+#include <algorithm>
+#include <cmath>
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <vector>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "lm/model.hh"
+#include "lm/enumerate_vocab.hh"
+
+namespace po = boost::program_options;
+using namespace std;
+
+lm::ngram::ProbingModel* ngram;
+
+// Collects the word ids KenLM assigns while the model is loaded.
+struct GetVocab : public lm::ngram::EnumerateVocab {
+  GetVocab(vector<lm::WordIndex>* out) : out_(out) { }
+  void Add(lm::WordIndex index, const StringPiece& str) {
+    out_->push_back(index);
+  }
+  vector<lm::WordIndex>* out_;
+};
+
+bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+  po::options_description opts("Configuration options");
+  opts.add_options()
+    ("model,m", po::value<string>(), "n-gram language model file (KLM)");
+  po::options_description clo("Command line options");
+  clo.add_options()
+    ("config", po::value<string>(), "Configuration file")
+    ("help,h", "Print this help message and exit");
+  po::options_description dconfig_options, dcmdline_options;
+  po::positional_options_description p;
+  p.add("grammar", -1);
+
+  dconfig_options.add(opts);
+  dcmdline_options.add(opts).add(clo);
+
+  po::store(po::command_line_parser(argc, argv).options(dcmdline_options).positional(p).run(), *conf);
+  if (conf->count("config")) {
+    ifstream config((*conf)["config"].as<string>().c_str());
+    po::store(po::parse_config_file(config, dconfig_options), *conf);
+  }
+  po::notify(*conf);
+
+  if (conf->count("help")) {
+    cerr << "Usage: " << argv[0] << " [OPTIONS]\n";
+    cerr << dcmdline_options << endl;
+    return false;
+  }
+  return true;
+}
+
+// Sum of the log10 scores of every full-context n-gram that covers position
+// subst_pos, with `word` substituted at that position.
+template <class Model> double BlanketProb(const vector<lm::WordIndex>& sentence, const lm::WordIndex word, const int subst_pos, const Model& model) {
+  typename Model::State state, out;
+  lm::FullScoreReturn ret;
+  double total = 0;
+  state = model.NullContextState();
+
+  const int begin = max(subst_pos - model.Order() + 1, 0);
+  const int end = min(subst_pos + model.Order(), (int)sentence.size());
+  int lookups = 0;
+  bool have_full_context = false;
+  for (int i = begin; i < end; ++i) {
+    if (i == 0) {
+      state = model.BeginSentenceState();
+      have_full_context = true;
+    } else {
+      lookups++;
+      if (lookups == model.Order()) { have_full_context = true; }
+      ret = model.FullScore(state, (subst_pos == i ? word : sentence[i]), out);
+      if (have_full_context) { total += ret.prob; }
+      state = out;
+    }
+  }
+  return total;
+}
+
+int main(int argc, char** argv) {
+  po::variables_map conf;
+  if (!InitCommandLine(argc, argv, &conf)) return 1;
+  lm::ngram::Config kconf;
+  vector<lm::WordIndex> vocab;
+  GetVocab gv(&vocab);
+  kconf.enumerate_vocab = &gv;
+  ngram = new lm::ngram::ProbingModel(conf["model"].as<string>().c_str(), kconf);
+  cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (vocab size=" << vocab.size() << ")\n";
+  vector<int> exclude(vocab.size(), 0);
+  exclude[0] = 1;  // exclude OOVs
+
+  double prob_sum = 0;
+  int counter = 0;
+  int rank_error = 0;
+  string line;
+  while (getline(cin, line)) {
+    stringstream line_stream(line);
+    vector<string> tokens;
+    tokens.push_back("<s>");
+    string token;
+    while (line_stream >> token)
+      tokens.push_back(token);
+    tokens.push_back("</s>");
+
+    vector<lm::WordIndex> sentence(tokens.size());
+    for (int i = 0; i < (int)tokens.size(); ++i)
+      sentence[i] = ngram->GetVocabulary().Index(tokens[i]);
+    exclude[sentence[0]] = 1;
+    exclude[sentence.back()] = 1;
+    for (int i = 1; i < (int)tokens.size() - 1; ++i) {
+      cerr << tokens[i] << endl;
+      ++counter;
+      lm::WordIndex gold = sentence[i];
+      double blanket_prob = BlanketProb<lm::ngram::ProbingModel>(sentence, gold, i, *ngram);
+      // Normalize over the non-excluded vocabulary to get a posterior.
+      double z = 0;
+      for (int v = 0; v < (int)vocab.size(); ++v) {
+        if (exclude[v]) continue;
+        double lp = BlanketProb<lm::ngram::ProbingModel>(sentence, v, i, *ngram);
+        if (lp > blanket_prob) ++rank_error;
+        z += pow(10.0, lp);
+      }
+      double post_prob = blanket_prob - log10(z);
+      cerr << "  " << post_prob << endl;
+      prob_sum -= post_prob;
+    }
+  }
+  cerr << "perplexity=" << pow(10, prob_sum / (double)counter) << endl;
+  cerr << "Rank error=" << rank_error / (double)counter << endl;

+  return 0;
+}
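What the new tool computes, as read from the code above: for a sentence $w_0 \dots w_n$ (with $w_0$ = `<s>` and $w_n$ = `</s>`) and a model of order $k$, `BlanketProb` returns the summed log10 score of every full-context $k$-gram that covers position $i$ when a candidate $v$ is substituted there,

$$b_i(v) \;=\; \sum_{j=i}^{\min(i+k-1,\,n)} \log_{10} p\big(w'_j \mid w'_{j-k+1} \dots w'_{j-1}\big), \qquad w'_i = v,\ \ w'_m = w_m \text{ for } m \neq i$$

(up to the special-cased sentence-initial context). The per-token value printed to stderr is the gold word's posterior under its blanket, normalized over the non-excluded vocabulary $V$:

$$\log_{10} p(w_i \mid \text{blanket}) \;=\; b_i(w_i) \;-\; \log_{10} \sum_{v \in V} 10^{b_i(v)}.$$

The reported perplexity is then $10^{-\frac{1}{N}\sum_i \log_{10} p(w_i \mid \text{blanket})}$ over the $N$ scored tokens, and "Rank error" is the average number of candidates per token whose blanket score beats the gold word's.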
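`BlanketProb` relies on KenLM's state-chaining scoring idiom: each `FullScore` call scores one word given an input state and writes the successor state for the next call. A minimal standalone sketch of that idiom, assuming a binary KenLM model at a placeholder path (`model.klm` and the example tokens are illustrative, not part of the commit):

```cpp
#include <iostream>

#include "lm/model.hh"

int main() {
  lm::ngram::ProbingModel model("model.klm");  // placeholder model path
  const char* words[] = {"the", "cat", "sat", "</s>"};
  lm::ngram::State state = model.BeginSentenceState();  // context = <s>
  lm::ngram::State out;
  double total = 0;
  for (int i = 0; i < 4; ++i) {
    lm::WordIndex idx = model.GetVocabulary().Index(words[i]);
    total += model.FullScore(state, idx, out).prob;  // .prob is a log10 value
    state = out;  // chain the context forward
  }
  std::cout << "log10 p(sentence) = " << total << std::endl;
  return 0;
}
```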
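Usage, assuming automake is re-run so the new `test_ngram` target gets built: the tool takes the KenLM model via `-m` and reads one pre-tokenized sentence per line on stdin, e.g. `./test_ngram -m model.klm < dev.tok` (`model.klm` and `dev.tok` are placeholder names). All output, including the per-token posteriors and the final `perplexity=`/`Rank error=` lines, goes to stderr.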