diff options
author | Kenneth Heafield <kheafiel@cluster10.lti.ece.cmu.local> | 2011-03-09 13:40:23 -0500 |
---|---|---|
committer | Kenneth Heafield <kheafiel@cluster10.lti.ece.cmu.local> | 2011-03-09 13:40:23 -0500 |
commit | 6c923d45f2aaf960806429d36ca58a41b3a39740 (patch) | |
tree | 9d8c5bf26189e9e8e6c12c199a5925c5ca6046a9 /klm/lm | |
parent | 95ea293005f74a627fdd2aae318d5746fa8c4e6c (diff) |
kenlm sync
Diffstat (limited to 'klm/lm')
-rw-r--r-- | klm/lm/binary_format.cc | 35 | ||||
-rw-r--r-- | klm/lm/binary_format.hh | 6 | ||||
-rw-r--r-- | klm/lm/build_binary.cc | 26 | ||||
-rw-r--r-- | klm/lm/config.cc | 1 | ||||
-rw-r--r-- | klm/lm/config.hh | 7 | ||||
-rw-r--r-- | klm/lm/exception.cc | 21 | ||||
-rw-r--r-- | klm/lm/exception.hh | 40 | ||||
-rw-r--r-- | klm/lm/lm_exception.cc | 4 | ||||
-rw-r--r-- | klm/lm/lm_exception.hh | 2 | ||||
-rw-r--r-- | klm/lm/model.cc | 22 | ||||
-rw-r--r-- | klm/lm/ngram.hh | 226 | ||||
-rw-r--r-- | klm/lm/ngram_build_binary.cc | 13 | ||||
-rw-r--r-- | klm/lm/ngram_config.hh | 58 | ||||
-rw-r--r-- | klm/lm/ngram_query.cc | 33 | ||||
-rw-r--r-- | klm/lm/ngram_test.cc | 91 | ||||
-rw-r--r-- | klm/lm/read_arpa.cc | 83 | ||||
-rw-r--r-- | klm/lm/read_arpa.hh | 3 | ||||
-rw-r--r-- | klm/lm/search_hashed.cc | 5 | ||||
-rw-r--r-- | klm/lm/search_trie.cc | 50 | ||||
-rw-r--r-- | klm/lm/test.binary | bin | 1660 -> 0 bytes | |||
-rw-r--r-- | klm/lm/virtual_interface.cc | 2 | ||||
-rw-r--r-- | klm/lm/vocab.cc | 24 | ||||
-rw-r--r-- | klm/lm/vocab.hh | 10 |
23 files changed, 163 insertions, 599 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 2a6aff34..9be0bc8e 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -9,6 +9,7 @@ #include <fcntl.h> #include <errno.h> #include <stdlib.h> +#include <string.h> #include <sys/mman.h> #include <sys/types.h> #include <sys/stat.h> @@ -19,6 +20,8 @@ namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0"; +// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). +const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; const long int kMagicVersion = 4; // Test values. @@ -81,6 +84,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED); + strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); } else { backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); @@ -88,17 +92,8 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ } } -uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) { +uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { - // header and vocab share the same mmap. The header is written here because we know the counts. - Parameters params; - params.counts = counts; - params.fixed.order = counts.size(); - params.fixed.probing_multiplier = config.probing_multiplier; - params.fixed.model_type = model_type; - params.fixed.has_vocabulary = config.include_vocab; - WriteHeader(backing.vocab.get(), params); - // Grow the file to accomodate the search, using zeros. if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size)) UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed"); @@ -115,6 +110,19 @@ uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::ve } } +void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing) { + if (config.write_mmap) { + // header and vocab share the same mmap. The header is written here because we know the counts. + Parameters params; + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + WriteHeader(backing.vocab.get(), params); + } +} + namespace detail { bool IsBinaryFormat(int fd) { @@ -130,14 +138,17 @@ bool IsBinaryFormat(int fd) { Sanity reference_header = Sanity(); reference_header.SetToReference(); if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true; + if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) { + UTIL_THROW(FormatLoadException, "This binary file did not finish building"); + } if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) { char *end_ptr; const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion); long int version = strtol(begin_version, &end_ptr, 10); if ((end_ptr != begin_version) && version != kMagicVersion) { - UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA. Sorry."); + UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary"); } - UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture."); + UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture"); } return false; } diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 2d66f813..72d8c159 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -43,7 +43,9 @@ struct Backing { uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. -uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing); +uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing); + +void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing); namespace detail { @@ -81,7 +83,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to) to.InitializeFromARPA(file, config); } } catch (util::Exception &e) { - e << " in file " << file; + e << " File: " << file; throw; } diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 144c57e0..d6dd5994 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -1,6 +1,8 @@ #include "lm/model.hh" #include "util/file_piece.hh" +#include <cstdlib> +#include <exception> #include <iostream> #include <iomanip> @@ -13,8 +15,10 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u unknown_probability] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" -"Where type is one of probing, trie, or sorted:\n\n" + std::cerr << "Usage: " << name << " [-u unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" +"-u sets the default probability for <unk> if the ARPA file does not have one.\n" +"-s allows models to be built even if they do not have <s> and </s>.\n\n" +"type is one of probing, trie, or sorted:\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" @@ -65,12 +69,25 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { } // namespace lm } // namespace +void terminate_handler() { + try { throw; } + catch(const std::exception& e) { + std::cerr << e.what() << std::endl; + } + catch(...) { + std::cerr << "A non-standard exception was thrown." << std::endl; + } + std::abort(); +} + int main(int argc, char *argv[]) { using namespace lm::ngram; + std::set_terminate(terminate_handler); + lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "u:p:t:m:")) != -1) { + while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) { switch(opt) { case 'u': config.unknown_missing_prob = ParseFloat(optarg); @@ -84,6 +101,9 @@ int main(int argc, char *argv[]) { case 'm': config.building_memory = ParseUInt(optarg) * 1048576; break; + case 's': + config.sentence_marker_missing = lm::ngram::Config::SILENT; + break; default: Usage(argv[0]); } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 2831d578..d8773fe5 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -9,6 +9,7 @@ Config::Config() : messages(&std::cerr), enumerate_vocab(NULL), unknown_missing(COMPLAIN), + sentence_marker_missing(THROW_UP), unknown_missing_prob(0.0), probing_multiplier(1.5), building_memory(1073741824ULL), // 1 GB diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 767fa5f9..17f67df3 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -27,9 +27,12 @@ struct Config { // ONLY EFFECTIVE WHEN READING ARPA + typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction; // What to do when <unk> isn't in the provided model. - typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing; - UnknownMissing unknown_missing; + WarningAction unknown_missing; + // What to do when <s> or </s> is missing from the model. + // If THROW_UP, the exception will be of type util::SpecialWordMissingException. + WarningAction sentence_marker_missing; // The probability to substitute for <unk> if it's missing from the model. // No effect if the model has <unk> or unknown_missing == THROW_UP. diff --git a/klm/lm/exception.cc b/klm/lm/exception.cc deleted file mode 100644 index 59a1650d..00000000 --- a/klm/lm/exception.cc +++ /dev/null @@ -1,21 +0,0 @@ -#include "lm/exception.hh" - -#include<errno.h> -#include<stdio.h> - -namespace lm { - -LoadException::LoadException() throw() {} -LoadException::~LoadException() throw() {} -VocabLoadException::VocabLoadException() throw() {} -VocabLoadException::~VocabLoadException() throw() {} - -FormatLoadException::FormatLoadException() throw() {} -FormatLoadException::~FormatLoadException() throw() {} - -SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() { - *this << "Missing special word " << which; -} -SpecialWordMissingException::~SpecialWordMissingException() throw() {} - -} // namespace lm diff --git a/klm/lm/exception.hh b/klm/lm/exception.hh deleted file mode 100644 index 95109012..00000000 --- a/klm/lm/exception.hh +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef LM_EXCEPTION__ -#define LM_EXCEPTION__ - -#include "util/exception.hh" -#include "util/string_piece.hh" - -#include <exception> -#include <string> - -namespace lm { - -class LoadException : public util::Exception { - public: - virtual ~LoadException() throw(); - - protected: - LoadException() throw(); -}; - -class VocabLoadException : public LoadException { - public: - virtual ~VocabLoadException() throw(); - VocabLoadException() throw(); -}; - -class FormatLoadException : public LoadException { - public: - FormatLoadException() throw(); - ~FormatLoadException() throw(); -}; - -class SpecialWordMissingException : public VocabLoadException { - public: - explicit SpecialWordMissingException(StringPiece which) throw(); - ~SpecialWordMissingException() throw(); -}; - -} // namespace lm - -#endif diff --git a/klm/lm/lm_exception.cc b/klm/lm/lm_exception.cc index 473849d1..0b572e98 100644 --- a/klm/lm/lm_exception.cc +++ b/klm/lm/lm_exception.cc @@ -17,9 +17,7 @@ FormatLoadException::~FormatLoadException() throw() {} VocabLoadException::VocabLoadException() throw() {} VocabLoadException::~VocabLoadException() throw() {} -SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() { - *this << "Missing special word " << which; -} +SpecialWordMissingException::SpecialWordMissingException() throw() {} SpecialWordMissingException::~SpecialWordMissingException() throw() {} } // namespace lm diff --git a/klm/lm/lm_exception.hh b/klm/lm/lm_exception.hh index 3773c572..aa3ca886 100644 --- a/klm/lm/lm_exception.hh +++ b/klm/lm/lm_exception.hh @@ -39,7 +39,7 @@ class VocabLoadException : public LoadException { class SpecialWordMissingException : public VocabLoadException { public: - explicit SpecialWordMissingException(StringPiece which) throw(); + explicit SpecialWordMissingException() throw(); ~SpecialWordMissingException() throw(); }; diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 146fe07b..14949e97 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -82,25 +82,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); } - // TODO: fail faster? if (!vocab_.SawUnk()) { - switch(config.unknown_missing) { - case Config::THROW_UP: - { - SpecialWordMissingException e("<unk>"); - e << " and configuration was set to throw if unknown is missing"; - throw e; - } - case Config::COMPLAIN: - if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl; - // There's no break;. This is by design. - case Config::SILENT: - // Default probabilities for unknown. - search_.unigram.Unknown().backoff = 0.0; - search_.unigram.Unknown().prob = config.unknown_missing_prob; - break; - } + assert(config.unknown_missing != Config::THROW_UP); + // Default probabilities for unknown. + search_.unigram.Unknown().backoff = 0.0; + search_.unigram.Unknown().prob = config.unknown_missing_prob; } + FinishFile(config, kModelType, counts, backing_); } template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { diff --git a/klm/lm/ngram.hh b/klm/lm/ngram.hh deleted file mode 100644 index 899a80e8..00000000 --- a/klm/lm/ngram.hh +++ /dev/null @@ -1,226 +0,0 @@ -#ifndef LM_NGRAM__ -#define LM_NGRAM__ - -#include "lm/facade.hh" -#include "lm/ngram_config.hh" -#include "util/key_value_packing.hh" -#include "util/mmap.hh" -#include "util/probing_hash_table.hh" -#include "util/scoped.hh" -#include "util/sorted_uniform.hh" -#include "util/string_piece.hh" - -#include <algorithm> -#include <memory> -#include <vector> - -namespace util { class FilePiece; } - -namespace lm { -namespace ngram { - -// If you need higher order, change this and recompile. -// Having this limit means that State can be -// (kMaxOrder - 1) * sizeof(float) bytes instead of -// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead -const std::size_t kMaxOrder = 6; - -// This is a POD. -class State { - public: - bool operator==(const State &other) const { - if (valid_length_ != other.valid_length_) return false; - const WordIndex *end = history_ + valid_length_; - for (const WordIndex *first = history_, *second = other.history_; - first != end; ++first, ++second) { - if (*first != *second) return false; - } - // If the histories are equal, so are the backoffs. - return true; - } - - // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. - // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. - WordIndex history_[kMaxOrder - 1]; - float backoff_[kMaxOrder - 1]; - unsigned char valid_length_; -}; - -size_t hash_value(const State &state); - -namespace detail { - -uint64_t HashForVocab(const char *str, std::size_t len); -inline uint64_t HashForVocab(const StringPiece &str) { - return HashForVocab(str.data(), str.length()); -} - -struct Prob { - float prob; - void SetBackoff(float to); - void ZeroBackoff() {} -}; -// No inheritance so this will be a POD. -struct ProbBackoff { - float prob; - float backoff; - void SetBackoff(float to) { backoff = to; } - void ZeroBackoff() { backoff = 0.0; } -}; - -} // namespace detail - -// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. -class SortedVocabulary : public base::Vocabulary { - private: - // Sorted uniform requires a GetKey function. - struct Entry { - uint64_t GetKey() const { return key; } - uint64_t key; - bool operator<(const Entry &other) const { - return key < other.key; - } - }; - - public: - SortedVocabulary(); - - WordIndex Index(const StringPiece &str) const { - const Entry *found; - if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) { - return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table. - } else { - return 0; - } - } - - // Ignores second argument for consistency with probing hash which has a float here. - static size_t Size(std::size_t entries, float ignored = 0.0); - - // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. - void Init(void *start, std::size_t allocated, std::size_t entries); - - WordIndex Insert(const StringPiece &str); - - // Returns true if unknown was seen. Reorders reorder_vocab so that the IDs are sorted. - bool FinishedLoading(detail::ProbBackoff *reorder_vocab); - - void LoadedBinary(); - - private: - Entry *begin_, *end_; - - bool saw_unk_; -}; - -namespace detail { - -// Vocabulary storing a map from uint64_t to WordIndex. -template <class Search> class MapVocabulary : public base::Vocabulary { - public: - MapVocabulary(); - - WordIndex Index(const StringPiece &str) const { - typename Lookup::ConstIterator i; - return lookup_.Find(HashForVocab(str), i) ? i->GetValue() : 0; - } - - static size_t Size(std::size_t entries, float probing_multiplier) { - return Lookup::Size(entries, probing_multiplier); - } - - // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. - void Init(void *start, std::size_t allocated, std::size_t entries); - - WordIndex Insert(const StringPiece &str); - - // Returns true if unknown was seen. Does nothing with reorder_vocab. - bool FinishedLoading(ProbBackoff *reorder_vocab); - - void LoadedBinary(); - - private: - typedef typename Search::template Table<WordIndex>::T Lookup; - Lookup lookup_; - - bool saw_unk_; -}; - -// std::identity is an SGI extension :-( -struct IdentityHash : public std::unary_function<uint64_t, size_t> { - size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); } -}; - -// Should return the same results as SRI. -// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary. -template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> { - private: - typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P; - public: - // Get the size of memory that will be mapped given ngram counts. This - // does not include small non-mapped control structures, such as this class - // itself. - static size_t Size(const std::vector<size_t> &counts, const Config &config = Config()); - - GenericModel(const char *file, Config config = Config()); - - FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; - - private: - // Appears after Size in the cc. - void SetupMemory(char *start, const std::vector<size_t> &counts, const Config &config); - - void LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config); - - util::scoped_fd mapped_file_; - - // memory_ is the raw block of memory backing vocab_, unigram_, [middle.begin(), middle.end()), and longest_. - util::scoped_mmap memory_; - - VocabularyT vocab_; - - ProbBackoff *unigram_; - - typedef typename Search::template Table<ProbBackoff>::T Middle; - std::vector<Middle> middle_; - - typedef typename Search::template Table<Prob>::T Longest; - Longest longest_; -}; - -struct ProbingSearch { - typedef float Init; - - static const unsigned char kBinaryTag = 1; - - template <class Value> struct Table { - typedef util::ByteAlignedPacking<uint64_t, Value> Packing; - typedef util::ProbingHashTable<Packing, IdentityHash> T; - }; -}; - -struct SortedUniformSearch { - // This is ignored. - typedef float Init; - - static const unsigned char kBinaryTag = 2; - - template <class Value> struct Table { - typedef util::ByteAlignedPacking<uint64_t, Value> Packing; - typedef util::SortedUniformMap<Packing> T; - }; -}; - -} // namespace detail - -// These must also be instantiated in the cc file. -typedef detail::MapVocabulary<detail::ProbingSearch> Vocabulary; -typedef detail::GenericModel<detail::ProbingSearch, Vocabulary> Model; - -// SortedVocabulary was defined above. -typedef detail::GenericModel<detail::SortedUniformSearch, SortedVocabulary> SortedModel; - -} // namespace ngram -} // namespace lm - -#endif // LM_NGRAM__ diff --git a/klm/lm/ngram_build_binary.cc b/klm/lm/ngram_build_binary.cc deleted file mode 100644 index 9dab30a1..00000000 --- a/klm/lm/ngram_build_binary.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "lm/ngram.hh" - -#include <iostream> - -int main(int argc, char *argv[]) { - if (argc != 3) { - std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl; - return 1; - } - lm::ngram::Config config; - config.write_mmap = argv[2]; - lm::ngram::Model(argv[1], config); -} diff --git a/klm/lm/ngram_config.hh b/klm/lm/ngram_config.hh deleted file mode 100644 index a7b3afae..00000000 --- a/klm/lm/ngram_config.hh +++ /dev/null @@ -1,58 +0,0 @@ -#ifndef LM_NGRAM_CONFIG__ -#define LM_NGRAM_CONFIG__ - -/* Configuration for ngram model. Separate header to reduce pollution. */ - -#include <iostream> - -namespace lm { namespace ngram { - -struct Config { - /* EFFECTIVE FOR BOTH ARPA AND BINARY READS */ - // Where to log messages including the progress bar. Set to NULL for - // silence. - std::ostream *messages; - - - - /* ONLY EFFECTIVE WHEN READING ARPA */ - - // What to do when <unk> isn't in the provided model. - typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing; - UnknownMissing unknown_missing; - - // The probability to substitute for <unk> if it's missing from the model. - // No effect if the model has <unk> or unknown_missing == THROW_UP. - float unknown_missing_prob; - - // Size multiplier for probing hash table. Must be > 1. Space is linear in - // this. Time is probing_multiplier / (probing_multiplier - 1). No effect - // for sorted variant. - // If you find yourself setting this to a low number, consider using the - // Sorted version instead which has lower memory consumption. - float probing_multiplier; - - // While loading an ARPA file, also write out this binary format file. Set - // to NULL to disable. - const char *write_mmap; - - - - /* ONLY EFFECTIVE WHEN READING BINARY */ - bool prefault; - - - - // Defaults. - Config() : - messages(&std::cerr), - unknown_missing(COMPLAIN), - unknown_missing_prob(0.0), - probing_multiplier(1.5), - write_mmap(NULL), - prefault(false) {} -}; - -} /* namespace ngram */ } /* namespace lm */ - -#endif // LM_NGRAM_CONFIG__ diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index d6da02e3..9454a6d1 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -35,14 +35,14 @@ void PrintUsage(const char *message) { } } -template <class Model> void Query(const Model &model) { +template <class Model> void Query(const Model &model, bool sentence_context) { PrintUsage("Loading statistics:\n"); typename Model::State state, out; lm::FullScoreReturn ret; std::string word; while (std::cin) { - state = model.BeginSentenceState(); + state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); float total = 0.0; bool got = false; unsigned int oov = 0; @@ -52,7 +52,7 @@ template <class Model> void Query(const Model &model) { if (vocab == 0) ++oov; ret = model.FullScore(state, vocab, out); total += ret.prob; - std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n'; + std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; state = out; char c; while (true) { @@ -67,9 +67,11 @@ template <class Model> void Query(const Model &model) { if (c == '\n') break; } if (!got && !std::cin) break; - ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); - total += ret.prob; - std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n'; + if (sentence_context) { + ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); + total += ret.prob; + std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; + } std::cout << "Total: " << total << " OOV: " << oov << '\n'; } PrintUsage("After queries:\n"); @@ -82,29 +84,30 @@ template <class Model> void Query(const char *name) { } int main(int argc, char *argv[]) { - if (argc < 2) { - std::cerr << "Pass language model name." << std::endl; - return 0; + if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) { + std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl; + std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl; + return 1; } + bool sentence_context = (argc == 2); lm::ngram::ModelType model_type; if (lm::ngram::RecognizeBinary(argv[1], model_type)) { switch(model_type) { case lm::ngram::HASH_PROBING: - Query<lm::ngram::ProbingModel>(argv[1]); - break; - case lm::ngram::HASH_SORTED: - Query<lm::ngram::SortedModel>(argv[1]); + Query<lm::ngram::ProbingModel>(argv[1], sentence_context); break; case lm::ngram::TRIE_SORTED: - Query<lm::ngram::TrieModel>(argv[1]); + Query<lm::ngram::TrieModel>(argv[1], sentence_context); break; + case lm::ngram::HASH_SORTED: default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; abort(); } } else { - Query<lm::ngram::ProbingModel>(argv[1]); + Query<lm::ngram::ProbingModel>(argv[1], sentence_context); } PrintUsage("Total time including destruction:\n"); + return 0; } diff --git a/klm/lm/ngram_test.cc b/klm/lm/ngram_test.cc deleted file mode 100644 index 031e0348..00000000 --- a/klm/lm/ngram_test.cc +++ /dev/null @@ -1,91 +0,0 @@ -#include "lm/ngram.hh" - -#include <stdlib.h> - -#define BOOST_TEST_MODULE NGramTest -#include <boost/test/unit_test.hpp> - -namespace lm { -namespace ngram { -namespace { - -#define StartTest(word, ngram, score) \ - ret = model.FullScore( \ - state, \ - model.GetVocabulary().Index(word), \ - out);\ - BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ - BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ - BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); - -#define AppendTest(word, ngram, score) \ - StartTest(word, ngram, score) \ - state = out; - -template <class M> void Starters(M &model) { - FullScoreReturn ret; - Model::State state(model.BeginSentenceState()); - Model::State out; - - StartTest("looking", 2, -0.4846522); - - // , probability plus <s> backoff - StartTest(",", 1, -1.383514 + -0.4149733); - // <unk> probability plus <s> backoff - StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); -} - -template <class M> void Continuation(M &model) { - FullScoreReturn ret; - Model::State state(model.BeginSentenceState()); - Model::State out; - - AppendTest("looking", 2, -0.484652); - AppendTest("on", 3, -0.348837); - AppendTest("a", 4, -0.0155266); - AppendTest("little", 5, -0.00306122); - State preserve = state; - AppendTest("the", 1, -4.04005); - AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 0, -2.29666); - AppendTest("more", 1, -1.20632); - AppendTest(".", 2, -0.51363); - AppendTest("</s>", 3, -0.0191651); - - state = preserve; - AppendTest("more", 5, -0.00181395); - AppendTest("loin", 5, -0.0432557); -} - -BOOST_AUTO_TEST_CASE(starters_probing) { Model m("test.arpa"); Starters(m); } -BOOST_AUTO_TEST_CASE(continuation_probing) { Model m("test.arpa"); Continuation(m); } -BOOST_AUTO_TEST_CASE(starters_sorted) { SortedModel m("test.arpa"); Starters(m); } -BOOST_AUTO_TEST_CASE(continuation_sorted) { SortedModel m("test.arpa"); Continuation(m); } - -BOOST_AUTO_TEST_CASE(write_and_read_probing) { - Config config; - config.write_mmap = "test.binary"; - { - Model copy_model("test.arpa", config); - } - Model binary("test.binary"); - Starters(binary); - Continuation(binary); -} - -BOOST_AUTO_TEST_CASE(write_and_read_sorted) { - Config config; - config.write_mmap = "test.binary"; - config.prefault = true; - { - SortedModel copy_model("test.arpa", config); - } - SortedModel binary("test.binary"); - Starters(binary); - Continuation(binary); -} - - -} // namespace -} // namespace ngram -} // namespace lm diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index d0fe67f0..0e90196d 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -6,6 +6,7 @@ #include <vector> #include <ctype.h> +#include <string.h> #include <inttypes.h> namespace lm { @@ -22,14 +23,20 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) { return true; } -template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &number) { +const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; + +} // namespace + +void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { number.clear(); StringPiece line; if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) { - UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, run\nzcat " << in.FileName() << " |kenlm/build_binary /dev/stdin " << in.FileName() << ".binary\nIf this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); + UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); } - UTIL_THROW(FormatLoadException, "First line was \"" << static_cast<int>(line.data()[1]) << "\" not blank"); + if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) + UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); + UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank"); } if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { @@ -49,66 +56,14 @@ template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &numb } } -template <class F> void GenericReadNGramHeader(F &in, unsigned int length) { - StringPiece line; +void ReadNGramHeader(util::FilePiece &in, unsigned int length) { + StringPiece line; while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} std::stringstream expected; expected << '\\' << length << "-grams:"; if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead"); } -template <class F> void GenericReadEnd(F &in) { - StringPiece line; - do { - line = in.ReadLine(); - } while (IsEntirelyWhiteSpace(line)); - if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line); -} - -class FakeFilePiece { - public: - explicit FakeFilePiece(std::istream &in) : in_(in) { - in_.exceptions(std::ios::failbit | std::ios::badbit | std::ios::eofbit); - } - - StringPiece ReadLine() throw(util::EndOfFileException) { - getline(in_, buffer_); - return StringPiece(buffer_); - } - - float ReadFloat() { - float ret; - in_ >> ret; - return ret; - } - - const char *FileName() const { - // This only used for error messages and we don't know the file name. . . - return "$file"; - } - - private: - std::istream &in_; - std::string buffer_; -}; - -} // namespace - -void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { - GenericReadARPACounts(in, number); -} -void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number) { - FakeFilePiece fake(in); - GenericReadARPACounts(fake, number); -} -void ReadNGramHeader(util::FilePiece &in, unsigned int length) { - GenericReadNGramHeader(in, length); -} -void ReadNGramHeader(std::istream &in, unsigned int length) { - FakeFilePiece fake(in); - GenericReadNGramHeader(fake, length); -} - void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) { switch (in.get()) { case '\t': @@ -146,20 +101,18 @@ void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) { } void ReadEnd(util::FilePiece &in) { - GenericReadEnd(in); StringPiece line; + do { + line = in.ReadLine(); + } while (IsEntirelyWhiteSpace(line)); + if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line); + try { while (true) { line = in.ReadLine(); if (!IsEntirelyWhiteSpace(line)) UTIL_THROW(FormatLoadException, "Trailing line " << line); } - } catch (const util::EndOfFileException &e) { - return; - } -} -void ReadEnd(std::istream &in) { - FakeFilePiece fake(in); - GenericReadEnd(fake); + } catch (const util::EndOfFileException &e) {} } } // namespace lm diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index 4efdd29d..4953d40e 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -13,15 +13,12 @@ namespace lm { void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number); -void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number); void ReadNGramHeader(util::FilePiece &in, unsigned int length); -void ReadNGramHeader(std::istream &in, unsigned int length); void ReadBackoff(util::FilePiece &in, Prob &weights); void ReadBackoff(util::FilePiece &in, ProbBackoff &weights); void ReadEnd(util::FilePiece &in); -void ReadEnd(std::istream &in); extern const bool kARPASpaces[256]; diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index f97ec790..bb3b955a 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -83,9 +83,10 @@ namespace detail { template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. - SetupMemory(GrowForSearch(config, HASH_PROBING, counts, Size(counts, config), backing), counts, config); + SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); - Read1Grams(f, counts[0], vocab, unigram.Raw()); + Read1Grams(f, counts[0], vocab, unigram.Raw()); + CheckSpecials(config, vocab); try { if (counts.size() > 2) { diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1060ddef..63631223 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -11,6 +11,7 @@ #include "lm/word_index.hh" #include "util/ersatz_progress.hh" #include "util/file_piece.hh" +#include "util/have.hh" #include "util/proxy_iterator.hh" #include "util/scoped.hh" @@ -20,7 +21,6 @@ #include <cstdio> #include <deque> #include <limits> -//#include <parallel/algorithm> #include <vector> #include <sys/mman.h> @@ -170,7 +170,7 @@ template <class Proxy> class CompareRecords : public std::binary_function<const return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices()); } bool operator()(const std::string &first, const std::string &second) const { - return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(first.data())); + return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data())); } private: @@ -384,7 +384,6 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); - // TODO: __gnu_parallel::sort here. std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1)); std::string name(ngram_file_name + kContextSuffix); @@ -406,16 +405,16 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil class ContextReader { public: - ContextReader() : length_(0) {} + ContextReader() : valid_(false) {} - ContextReader(const char *name, size_t length) : file_(OpenOrThrow(name, "r")), length_(length), words_(length), valid_(true) { - ++*this; + ContextReader(const char *name, unsigned char order) { + Reset(name, order); } - void Reset(const char *name, size_t length) { + void Reset(const char *name, unsigned char order) { file_.reset(OpenOrThrow(name, "r")); - length_ = length; - words_.resize(length); + length_ = sizeof(WordIndex) * static_cast<size_t>(order); + words_.resize(order); valid_ = true; ++*this; } @@ -449,14 +448,14 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_ const size_t context_size = sizeof(WordIndex) * (order - 1); std::string first_name(first_base + kContextSuffix); std::string second_name(second_base + kContextSuffix); - ContextReader first(first_name.c_str(), context_size), second(second_name.c_str(), context_size); + ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1); RemoveOrThrow(first_name.c_str()); RemoveOrThrow(second_name.c_str()); std::string out_name(out_base + kContextSuffix); util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w")); while (first && second) { for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) { - if (f == *first + order) { + if (f == *first + order - 1) { // Equal. WriteOrThrow(out.get(), *first, context_size); ++first; @@ -475,7 +474,10 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_ } } } - CopyRestOrThrow((first ? first : second).GetFile(), out.get()); + ContextReader &remaining = first ? first : second; + if (!remaining) return; + WriteOrThrow(out.get(), *remaining, context_size); + CopyRestOrThrow(remaining.GetFile(), out.get()); } void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { @@ -502,7 +504,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st } // Sort full records by full n-gram. EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - // TODO: __gnu_parallel::sort here. + // parallel_sort uses too much RAM std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order)); files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); WriteContextFile(begin, out_end, files.back(), entry_size, order); @@ -533,21 +535,22 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st } } -void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, const std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { { std::string unigram_name = file_prefix + "unigrams"; util::scoped_fd unigram_file; util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); + CheckSpecials(config, vocab); } // Only use as much buffer as we need. size_t buffer_use = 0; for (unsigned int order = 2; order < counts.size(); ++order) { - buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); + buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); } - buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); - buffer = std::min(buffer, buffer_use); + buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); + buffer = std::min<size_t>(buffer, buffer_use); util::scoped_memory mem; mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); @@ -767,7 +770,7 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u } } -void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) { +void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) { SortedFileReader inputs[counts.size() - 1]; ContextReader contexts[counts.size() - 1]; @@ -777,7 +780,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun inputs[i-2].Init(assembled.str(), i); RemoveOrThrow(assembled.str().c_str()); assembled << kContextSuffix; - contexts[i-2].Reset(assembled.str().c_str(), (i-1) * sizeof(WordIndex)); + contexts[i-2].Reset(assembled.str().c_str(), i-1); RemoveOrThrow(assembled.str().c_str()); } @@ -787,8 +790,9 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } SanityCheckCounts(counts, fixed_counts); + counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::kModelType, fixed_counts, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -811,7 +815,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun ++contexts[0]; } } - unlink(name.c_str()); + RemoveOrThrow(name.c_str()); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -823,7 +827,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun for (const WordIndex *i = *context; i != *context + order - 1; ++i) { e << ' ' << *i; } - e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not."; + e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not"; throw e; } } @@ -868,7 +872,7 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v // Add directory delimiter. Assumes a real operating system. temporary_directory += '/'; // At least 1MB sorting memory. - ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); + ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); BuildTrie(temporary_directory, counts, config, *this, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { diff --git a/klm/lm/test.binary b/klm/lm/test.binary Binary files differdeleted file mode 100644 index 90bd2b76..00000000 --- a/klm/lm/test.binary +++ /dev/null diff --git a/klm/lm/virtual_interface.cc b/klm/lm/virtual_interface.cc index c5a64972..17a74c3c 100644 --- a/klm/lm/virtual_interface.cc +++ b/klm/lm/virtual_interface.cc @@ -11,8 +11,6 @@ void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, Wo begin_sentence_ = begin_sentence; end_sentence_ = end_sentence; not_found_ = not_found; - if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>"); - if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>"); } Model::~Model() {} diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index ae79c727..415f8331 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -187,5 +187,29 @@ void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { SetSpecial(Index("<s>"), Index("</s>"), 0); } +void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { + switch(config.unknown_missing) { + case Config::SILENT: + return; + case Config::COMPLAIN: + if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl; + break; + case Config::THROW_UP: + UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception."); + } +} + +void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) { + switch (config.sentence_marker_missing) { + case Config::SILENT: + return; + case Config::COMPLAIN: + if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>."; + break; + case Config::THROW_UP: + UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models. Run build_binary -s to disable this check."); + } +} + } // namespace ngram } // namespace lm diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index b584c82f..546c1649 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -2,6 +2,7 @@ #define LM_VOCAB__ #include "lm/enumerate_vocab.hh" +#include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" @@ -134,6 +135,15 @@ class ProbingVocabulary : public base::Vocabulary { EnumerateVocab *enumerate_; }; +void MissingUnknown(const Config &config) throw(SpecialWordMissingException); +void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException); + +template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) { + if (!vocab.SawUnk()) MissingUnknown(config); + if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>"); + if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>"); +} + } // namespace ngram } // namespace lm |