summaryrefslogtreecommitdiff
path: root/klm/lm/ngram_query.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/ngram_query.cc')
-rw-r--r--klm/lm/ngram_query.cc27
1 files changed, 16 insertions, 11 deletions
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index 3fa8cb03..d6da02e3 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -6,6 +6,8 @@
#include <iostream>
#include <string>
+#include <ctype.h>
+
#include <sys/resource.h>
#include <sys/time.h>
@@ -43,35 +45,38 @@ template <class Model> void Query(const Model &model) {
state = model.BeginSentenceState();
float total = 0.0;
bool got = false;
+ unsigned int oov = 0;
while (std::cin >> word) {
got = true;
lm::WordIndex vocab = model.GetVocabulary().Index(word);
+ if (vocab == 0) ++oov;
ret = model.FullScore(state, vocab, out);
total += ret.prob;
std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
state = out;
- if (std::cin.get() == '\n') break;
+ char c;
+ while (true) {
+ c = std::cin.get();
+ if (!std::cin) break;
+ if (c == '\n') break;
+ if (!isspace(c)) {
+ std::cin.unget();
+ break;
+ }
+ }
+ if (c == '\n') break;
}
if (!got && !std::cin) break;
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
- std::cout << "Total: " << total << '\n';
+ std::cout << "Total: " << total << " OOV: " << oov << '\n';
}
PrintUsage("After queries:\n");
}
-class PrintVocab : public lm::ngram::EnumerateVocab {
- public:
- void Add(lm::WordIndex index, const StringPiece &str) {
- std::cerr << "vocab " << index << ' ' << str << '\n';
- }
-};
-
template <class Model> void Query(const char *name) {
lm::ngram::Config config;
- PrintVocab printer;
- config.enumerate_vocab = &printer;
Model model(name, config);
Query(model);
}