diff options
Diffstat (limited to 'klm/lm/ngram_query.cc')
| -rw-r--r-- | klm/lm/ngram_query.cc | 45 | 
1 files changed, 39 insertions, 6 deletions
| diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 74457a74..3fa8cb03 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -1,3 +1,4 @@ +#include "lm/enumerate_vocab.hh"  #include "lm/model.hh"  #include <cstdlib> @@ -44,29 +45,61 @@ template <class Model> void Query(const Model &model) {      bool got = false;      while (std::cin >> word) {        got = true; -      ret = model.FullScore(state, model.GetVocabulary().Index(word), out); +      lm::WordIndex vocab = model.GetVocabulary().Index(word); +      ret = model.FullScore(state, vocab, out);        total += ret.prob; -      std::cout << word << ' ' << static_cast<unsigned int>(ret.ngram_length)  << ' ' << ret.prob << ' '; +      std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length)  << ' ' << ret.prob << '\n';        state = out;        if (std::cin.get() == '\n') break;      }      if (!got && !std::cin) break;      ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);      total += ret.prob; -    std::cout << "</s> " << static_cast<unsigned int>(ret.ngram_length)  << ' ' << ret.prob << ' '; +    std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length)  << ' ' << ret.prob << '\n';      std::cout << "Total: " << total << '\n';    }    PrintUsage("After queries:\n");  } +class PrintVocab : public lm::ngram::EnumerateVocab { +  public: +    void Add(lm::WordIndex index, const StringPiece &str) { +      std::cerr << "vocab " << index << ' ' << str << '\n'; +    } +}; + +template <class Model> void Query(const char *name) { +  lm::ngram::Config config; +  PrintVocab printer; +  config.enumerate_vocab = &printer; +  Model model(name, config); +  Query(model); +} +  int main(int argc, char *argv[]) {    if (argc < 2) {      std::cerr << "Pass language model name." << std::endl;      return 0;    } -  { -    lm::ngram::Model ngram(argv[1]); -    Query(ngram); +  lm::ngram::ModelType model_type; +  if (lm::ngram::RecognizeBinary(argv[1], model_type)) { +    switch(model_type) { +      case lm::ngram::HASH_PROBING: +        Query<lm::ngram::ProbingModel>(argv[1]); +        break; +      case lm::ngram::HASH_SORTED: +        Query<lm::ngram::SortedModel>(argv[1]); +        break; +      case lm::ngram::TRIE_SORTED: +        Query<lm::ngram::TrieModel>(argv[1]); +        break; +      default: +        std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; +        abort(); +    } +  } else { +    Query<lm::ngram::ProbingModel>(argv[1]);    } +    PrintUsage("Total time including destruction:\n");  } | 
