diff options
Diffstat (limited to 'klm/lm/ngram_query.cc')
-rw-r--r-- | klm/lm/ngram_query.cc | 45 |
1 files changed, 39 insertions, 6 deletions
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 74457a74..3fa8cb03 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -1,3 +1,4 @@ +#include "lm/enumerate_vocab.hh" #include "lm/model.hh" #include <cstdlib> @@ -44,29 +45,61 @@ template <class Model> void Query(const Model &model) { bool got = false; while (std::cin >> word) { got = true; - ret = model.FullScore(state, model.GetVocabulary().Index(word), out); + lm::WordIndex vocab = model.GetVocabulary().Index(word); + ret = model.FullScore(state, vocab, out); total += ret.prob; - std::cout << word << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' '; + std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n'; state = out; if (std::cin.get() == '\n') break; } if (!got && !std::cin) break; ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; - std::cout << "</s> " << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' '; + std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n'; std::cout << "Total: " << total << '\n'; } PrintUsage("After queries:\n"); } +class PrintVocab : public lm::ngram::EnumerateVocab { + public: + void Add(lm::WordIndex index, const StringPiece &str) { + std::cerr << "vocab " << index << ' ' << str << '\n'; + } +}; + +template <class Model> void Query(const char *name) { + lm::ngram::Config config; + PrintVocab printer; + config.enumerate_vocab = &printer; + Model model(name, config); + Query(model); +} + int main(int argc, char *argv[]) { if (argc < 2) { std::cerr << "Pass language model name." << std::endl; return 0; } - { - lm::ngram::Model ngram(argv[1]); - Query(ngram); + lm::ngram::ModelType model_type; + if (lm::ngram::RecognizeBinary(argv[1], model_type)) { + switch(model_type) { + case lm::ngram::HASH_PROBING: + Query<lm::ngram::ProbingModel>(argv[1]); + break; + case lm::ngram::HASH_SORTED: + Query<lm::ngram::SortedModel>(argv[1]); + break; + case lm::ngram::TRIE_SORTED: + Query<lm::ngram::TrieModel>(argv[1]); + break; + default: + std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; + abort(); + } + } else { + Query<lm::ngram::ProbingModel>(argv[1]); } + PrintUsage("Total time including destruction:\n"); } |