summary refs log tree commit diff
path: root/klm/lm/ngram_query.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/ngram_query.hh')
-rw-r--r-- klm/lm/ngram_query.hh 17
1 files changed, 12 insertions, 5 deletions
diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh
index dfcda170..ec2590f4 100644
--- a/klm/lm/ngram_query.hh
+++ b/klm/lm/ngram_query.hh
@@ -11,21 +11,25 @@
#include <istream>
#include <string>
+#include <math.h>
+
namespace lm {
namespace ngram {
template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
- std::cerr << "Loading statistics:\n";
- util::PrintUsage(std::cerr);
typename Model::State state, out;
lm::FullScoreReturn ret;
std::string word;
+ double corpus_total = 0.0;
+ uint64_t corpus_oov = 0;
+ uint64_t corpus_tokens = 0;
+
while (in_stream) {
state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
float total = 0.0;
bool got = false;
- unsigned int oov = 0;
+ uint64_t oov = 0;
while (in_stream >> word) {
got = true;
lm::WordIndex vocab = model.GetVocabulary().Index(word);
@@ -33,6 +37,7 @@ template <class Model> void Query(const Model &model, bool sentence_context, std
ret = model.FullScore(state, vocab, out);
total += ret.prob;
out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
+ ++corpus_tokens;
state = out;
char c;
while (true) {
@@ -50,12 +55,14 @@ template <class Model> void Query(const Model &model, bool sentence_context, std
if (sentence_context) {
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
+ ++corpus_tokens;
out_stream << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
}
out_stream << "Total: " << total << " OOV: " << oov << '\n';
+ corpus_total += total;
+ corpus_oov += oov;
}
- std::cerr << "After queries:\n";
- util::PrintUsage(std::cerr);
+ out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl;
}
template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {