diff options
Diffstat (limited to 'klm/lm/wrappers')
-rw-r--r-- | klm/lm/wrappers/README | 3 | ||||
-rw-r--r-- | klm/lm/wrappers/nplm.cc | 90 | ||||
-rw-r--r-- | klm/lm/wrappers/nplm.hh | 83 |
3 files changed, 176 insertions, 0 deletions
diff --git a/klm/lm/wrappers/README b/klm/lm/wrappers/README new file mode 100644 index 00000000..56c34c23 --- /dev/null +++ b/klm/lm/wrappers/README @@ -0,0 +1,3 @@ +This directory is for wrappers around other people's LMs, presenting an interface similar to KenLM's. You will need to have their LM installed. + +NPLM is a work in progress. diff --git a/klm/lm/wrappers/nplm.cc b/klm/lm/wrappers/nplm.cc new file mode 100644 index 00000000..70622bd2 --- /dev/null +++ b/klm/lm/wrappers/nplm.cc @@ -0,0 +1,90 @@ +#include "lm/wrappers/nplm.hh" +#include "util/exception.hh" +#include "util/file.hh" + +#include <algorithm> + +#include <string.h> + +#include "neuralLM.h" + +namespace lm { +namespace np { + +Vocabulary::Vocabulary(const nplm::vocabulary &vocab) + : base::Vocabulary(vocab.lookup_word("<s>"), vocab.lookup_word("</s>"), vocab.lookup_word("<unk>")), + vocab_(vocab), null_word_(vocab.lookup_word("<null>")) {} + +Vocabulary::~Vocabulary() {} + +WordIndex Vocabulary::Index(const std::string &str) const { + return vocab_.lookup_word(str); +} + +bool Model::Recognize(const std::string &name) { + try { + util::scoped_fd file(util::OpenReadOrThrow(name.c_str())); + char magic_check[16]; + util::ReadOrThrow(file.get(), magic_check, sizeof(magic_check)); + const char nnlm_magic[] = "\\config\nversion "; + return !memcmp(magic_check, nnlm_magic, 16); + } catch (const util::Exception &) { + return false; + } +} + +Model::Model(const std::string &file, std::size_t cache) + : base_instance_(new nplm::neuralLM(file)), vocab_(base_instance_->get_vocabulary()), cache_size_(cache) { + UTIL_THROW_IF(base_instance_->get_order() > NPLM_MAX_ORDER, util::Exception, "This NPLM has order " << (unsigned int)base_instance_->get_order() << " but the KenLM wrapper was compiled with " << NPLM_MAX_ORDER << ". Change the defintion of NPLM_MAX_ORDER and recompile."); + // log10 compatible with backoff models. + base_instance_->set_log_base(10.0); + State begin_sentence, null_context; + std::fill(begin_sentence.words, begin_sentence.words + NPLM_MAX_ORDER - 1, base_instance_->lookup_word("<s>")); + null_word_ = base_instance_->lookup_word("<null>"); + std::fill(null_context.words, null_context.words + NPLM_MAX_ORDER - 1, null_word_); + + Init(begin_sentence, null_context, vocab_, base_instance_->get_order()); +} + +Model::~Model() {} + +FullScoreReturn Model::FullScore(const State &from, const WordIndex new_word, State &out_state) const { + nplm::neuralLM *lm = backend_.get(); + if (!lm) { + lm = new nplm::neuralLM(*base_instance_); + backend_.reset(lm); + lm->set_cache(cache_size_); + } + // State is in natural word order. + FullScoreReturn ret; + for (int i = 0; i < lm->get_order() - 1; ++i) { + lm->staging_ngram()(i) = from.words[i]; + } + lm->staging_ngram()(lm->get_order() - 1) = new_word; + ret.prob = lm->lookup_from_staging(); + // Always say full order. + ret.ngram_length = lm->get_order(); + // Shift everything down by one. + memcpy(out_state.words, from.words + 1, sizeof(WordIndex) * (lm->get_order() - 2)); + out_state.words[lm->get_order() - 2] = new_word; + // Fill in trailing words with zeros so state comparison works. + memset(out_state.words + lm->get_order() - 1, 0, sizeof(WordIndex) * (NPLM_MAX_ORDER - lm->get_order())); + return ret; +} + +// TODO: optimize with direct call? +FullScoreReturn Model::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const { + // State is in natural word order. The API here specifies reverse order. + std::size_t state_length = std::min<std::size_t>(Order() - 1, context_rend - context_rbegin); + State state; + // Pad with null words. + for (lm::WordIndex *i = state.words; i < state.words + Order() - 1 - state_length; ++i) { + *i = null_word_; + } + // Put new words at the end. + std::reverse_copy(context_rbegin, context_rbegin + state_length, state.words + Order() - 1 - state_length); + return FullScore(state, new_word, out_state); +} + +} // namespace np +} // namespace lm diff --git a/klm/lm/wrappers/nplm.hh b/klm/lm/wrappers/nplm.hh new file mode 100644 index 00000000..b7dd4a21 --- /dev/null +++ b/klm/lm/wrappers/nplm.hh @@ -0,0 +1,83 @@ +#ifndef LM_WRAPPERS_NPLM_H +#define LM_WRAPPERS_NPLM_H + +#include "lm/facade.hh" +#include "lm/max_order.hh" +#include "util/string_piece.hh" + +#include <boost/thread/tss.hpp> +#include <boost/scoped_ptr.hpp> + +/* Wrapper to NPLM "by Ashish Vaswani, with contributions from David Chiang + * and Victoria Fossum." + * http://nlg.isi.edu/software/nplm/ + */ + +namespace nplm { +class vocabulary; +class neuralLM; +} // namespace nplm + +namespace lm { +namespace np { + +class Vocabulary : public base::Vocabulary { + public: + Vocabulary(const nplm::vocabulary &vocab); + + ~Vocabulary(); + + WordIndex Index(const std::string &str) const; + + // TODO: lobby them to support StringPiece + WordIndex Index(const StringPiece &str) const { + return Index(std::string(str.data(), str.size())); + } + + lm::WordIndex NullWord() const { return null_word_; } + + private: + const nplm::vocabulary &vocab_; + + const lm::WordIndex null_word_; +}; + +// Sorry for imposing my limitations on your code. +#define NPLM_MAX_ORDER 7 + +struct State { + WordIndex words[NPLM_MAX_ORDER - 1]; +}; + +class Model : public lm::base::ModelFacade<Model, State, Vocabulary> { + private: + typedef lm::base::ModelFacade<Model, State, Vocabulary> P; + + public: + // Does this look like an NPLM? + static bool Recognize(const std::string &file); + + explicit Model(const std::string &file, std::size_t cache_size = 1 << 20); + + ~Model(); + + FullScoreReturn FullScore(const State &from, const WordIndex new_word, State &out_state) const; + + FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; + + private: + boost::scoped_ptr<nplm::neuralLM> base_instance_; + + mutable boost::thread_specific_ptr<nplm::neuralLM> backend_; + + Vocabulary vocab_; + + lm::WordIndex null_word_; + + const std::size_t cache_size_; +}; + +} // namespace np +} // namespace lm + +#endif // LM_WRAPPERS_NPLM_H |