Diffstat (limited to 'klm/lm/ngram_query.cc')
-rw-r--r--  klm/lm/ngram_query.cc  45
1 file changed, 39 insertions, 6 deletions
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index 74457a74..3fa8cb03 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -1,3 +1,4 @@
+#include "lm/enumerate_vocab.hh"
 #include "lm/model.hh"
 
 #include <cstdlib>
@@ -44,29 +45,61 @@ template <class Model> void Query(const Model &model) {
     bool got = false;
     while (std::cin >> word) {
       got = true;
-      ret = model.FullScore(state, model.GetVocabulary().Index(word), out);
+      lm::WordIndex vocab = model.GetVocabulary().Index(word);
+      ret = model.FullScore(state, vocab, out);
       total += ret.prob;
-      std::cout << word << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
+      std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
       state = out;
       if (std::cin.get() == '\n') break;
     }
     if (!got && !std::cin) break;
     ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
     total += ret.prob;
-    std::cout << "</s> " << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
+    std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
     std::cout << "Total: " << total << '\n';
   }
   PrintUsage("After queries:\n");
 }
 
+class PrintVocab : public lm::ngram::EnumerateVocab {
+  public:
+    void Add(lm::WordIndex index, const StringPiece &str) {
+      std::cerr << "vocab " << index << ' ' << str << '\n';
+    }
+};
+
+template <class Model> void Query(const char *name) {
+  lm::ngram::Config config;
+  PrintVocab printer;
+  config.enumerate_vocab = &printer;
+  Model model(name, config);
+  Query(model);
+}
+
 int main(int argc, char *argv[]) {
   if (argc < 2) {
     std::cerr << "Pass language model name." << std::endl;
     return 0;
   }
-  {
-    lm::ngram::Model ngram(argv[1]);
-    Query(ngram);
+  lm::ngram::ModelType model_type;
+  if (lm::ngram::RecognizeBinary(argv[1], model_type)) {
+    switch(model_type) {
+      case lm::ngram::HASH_PROBING:
+        Query<lm::ngram::ProbingModel>(argv[1]);
+        break;
+      case lm::ngram::HASH_SORTED:
+        Query<lm::ngram::SortedModel>(argv[1]);
+        break;
+      case lm::ngram::TRIE_SORTED:
+        Query<lm::ngram::TrieModel>(argv[1]);
+        break;
+      default:
+        std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
+        abort();
+    }
+  } else {
+    Query<lm::ngram::ProbingModel>(argv[1]);
   }
+
   PrintUsage("Total time including destruction:\n");
 }
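
The enumerate_vocab hook wired up in this diff is a general callback: Config::enumerate_vocab may point at any lm::ngram::EnumerateVocab subclass, and the model invokes Add once per vocabulary entry while it loads, as PrintVocab demonstrates by echoing "vocab <index> <word>" lines to stderr. Below is a minimal sketch of the same hook (not part of this commit; the class CollectVocab and the driver main are hypothetical) that instead stores the surface strings, so the indices printed by the new word=index output can be mapped back to words:

#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"

#include <string>
#include <vector>

// Hypothetical EnumerateVocab implementation: stores each vocabulary
// string under its WordIndex instead of printing it like PrintVocab.
class CollectVocab : public lm::ngram::EnumerateVocab {
  public:
    // Called by the model once per vocabulary entry during loading.
    void Add(lm::WordIndex index, const StringPiece &str) {
      if (words_.size() <= index) words_.resize(index + 1);
      words_[index].assign(str.data(), str.size());
    }

    const std::string &Word(lm::WordIndex index) const { return words_[index]; }

  private:
    std::vector<std::string> words_;
};

int main(int argc, char *argv[]) {
  if (argc < 2) return 1;
  lm::ngram::Config config;
  CollectVocab vocab;
  config.enumerate_vocab = &vocab;  // only consulted while the constructor loads
  lm::ngram::ProbingModel model(argv[1], config);
  // From here on, vocab.Word(i) recovers the string for word index i.
  return 0;
}

Note that PrintVocab writes to stderr while the per-token "word=index ngram_length prob" records and the per-sentence "Total:" line go to stdout, so the vocabulary dump and the query output can be redirected separately (e.g. 2>vocab.txt).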