diff options
Diffstat (limited to 'klm/lm/ngram_query.cc')
-rw-r--r-- | klm/lm/ngram_query.cc | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 3fa8cb03..d6da02e3 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -6,6 +6,8 @@ #include <iostream> #include <string> +#include <ctype.h> + #include <sys/resource.h> #include <sys/time.h> @@ -43,35 +45,38 @@ template <class Model> void Query(const Model &model) { state = model.BeginSentenceState(); float total = 0.0; bool got = false; + unsigned int oov = 0; while (std::cin >> word) { got = true; lm::WordIndex vocab = model.GetVocabulary().Index(word); + if (vocab == 0) ++oov; ret = model.FullScore(state, vocab, out); total += ret.prob; std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n'; state = out; - if (std::cin.get() == '\n') break; + char c; + while (true) { + c = std::cin.get(); + if (!std::cin) break; + if (c == '\n') break; + if (!isspace(c)) { + std::cin.unget(); + break; + } + } + if (c == '\n') break; } if (!got && !std::cin) break; ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n'; - std::cout << "Total: " << total << '\n'; + std::cout << "Total: " << total << " OOV: " << oov << '\n'; } PrintUsage("After queries:\n"); } -class PrintVocab : public lm::ngram::EnumerateVocab { - public: - void Add(lm::WordIndex index, const StringPiece &str) { - std::cerr << "vocab " << index << ' ' << str << '\n'; - } -}; - template <class Model> void Query(const char *name) { lm::ngram::Config config; - PrintVocab printer; - config.enumerate_vocab = &printer; Model model(name, config); Query(model); } |