From bee6a3c3f6c54cf7449229488c6124dddc7e2f31 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Tue, 18 Jan 2011 15:55:40 -0500
Subject: new version of klm

---
 klm/lm/ngram_query.cc | 45 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 39 insertions(+), 6 deletions(-)

(limited to 'klm/lm/ngram_query.cc')

diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index 74457a74..3fa8cb03 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -1,3 +1,4 @@
+#include "lm/enumerate_vocab.hh"
 #include "lm/model.hh"
 
 #include <cstdlib>
@@ -44,29 +45,61 @@ template <class Model> void Query(const Model &model) {
     bool got = false;
     while (std::cin >> word) {
       got = true;
-      ret = model.FullScore(state, model.GetVocabulary().Index(word), out);
+      lm::WordIndex vocab = model.GetVocabulary().Index(word);
+      ret = model.FullScore(state, vocab, out);
       total += ret.prob;
-      std::cout << word << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
+      std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
       state = out;
       if (std::cin.get() == '\n') break;
     }
     if (!got && !std::cin) break;
     ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
     total += ret.prob;
-    std::cout << "</s> " << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
+    std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
     std::cout << "Total: " << total << '\n';
   }
   PrintUsage("After queries:\n");
 }
 
+class PrintVocab : public lm::ngram::EnumerateVocab {
+  public:
+    void Add(lm::WordIndex index, const StringPiece &str) {
+      std::cerr << "vocab " << index << ' ' << str << '\n';
+    }
+};
+
+template <class Model> void Query(const char *name) {
+  lm::ngram::Config config;
+  PrintVocab printer;
+  config.enumerate_vocab = &printer;
+  Model model(name, config);
+  Query<Model>(model);
+}
+
 int main(int argc, char *argv[]) {
   if (argc < 2) {
     std::cerr << "Pass language model name." << std::endl;
     return 0;
   }
-  {
-    lm::ngram::Model ngram(argv[1]);
-    Query(ngram);
+  lm::ngram::ModelType model_type;
+  if (lm::ngram::RecognizeBinary(argv[1], model_type)) {
+    switch(model_type) {
+      case lm::ngram::HASH_PROBING:
+        Query<lm::ngram::ProbingModel>(argv[1]);
+        break;
+      case lm::ngram::HASH_SORTED:
+        Query<lm::ngram::SortedModel>(argv[1]);
+        break;
+      case lm::ngram::TRIE_SORTED:
+        Query<lm::ngram::TrieModel>(argv[1]);
+        break;
+      default:
+        std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
+        abort();
+    }
+  } else {
+    Query<lm::ngram::ProbingModel>(argv[1]);
   }
+
   PrintUsage("Total time including destruction:\n");
 }
-- 
cgit v1.2.3